Browse Source

lnchannel: start using "latest" and "next" instead of "current" and "pending"

"current" used to be "oldest_unrevoked"; and pending was "oldest_unrevoked + 1"
but this was very confusing...
so now we have "oldest_unrevoked", "latest", and "next"
where "next" is "latest + 1"
"oldest_unrevoked" and "latest" are either the same or are offset by 1
(but caller should know which one they need)

rm "got_sig_for_next" - it was a redundant sanity check, that really
just complicated things

rm "local_commitment", "remote_commitment", "set_local_commitment",
"set_remote_commitment" - just use "get_latest_commitment" instead
dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
SomberNight 5 years ago
committed by ThomasV
parent
commit
b1f606eaed
  1. 78
      electrum/lnchannel.py
  2. 41
      electrum/lnpeer.py
  3. 11
      electrum/lnsweep.py
  4. 1
      electrum/lnutil.py
  5. 6
      electrum/lnworker.py
  6. 143
      electrum/tests/test_lnchannel.py

78
electrum/lnchannel.py

@ -146,8 +146,6 @@ class Channel(Logger):
self._is_funding_txo_spent = None # "don't know" self._is_funding_txo_spent = None # "don't know"
self._state = None self._state = None
self.set_state('DISCONNECTED') self.set_state('DISCONNECTED')
self.local_commitment = None
self.remote_commitment = None
self.sweep_info = {} self.sweep_info = {}
def get_feerate(self, subject, ctn): def get_feerate(self, subject, ctn):
@ -175,14 +173,6 @@ class Channel(Logger):
out[rhash] = (self.channel_id, htlc, direction, status) out[rhash] = (self.channel_id, htlc, direction, status)
return out return out
def set_local_commitment(self, ctx):
ctn = extract_ctn_from_tx_and_chan(ctx, self)
assert self.signature_fits(ctx), (self.hm.log[LOCAL])
self.local_commitment = ctx
def set_remote_commitment(self):
self.remote_commitment = self.current_commitment(REMOTE)
def open_with_first_pcp(self, remote_pcp, remote_sig): def open_with_first_pcp(self, remote_pcp, remote_sig):
self.config[REMOTE] = self.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_pcp, next_per_commitment_point=None) self.config[REMOTE] = self.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_pcp, next_per_commitment_point=None)
self.config[LOCAL] = self.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) self.config[LOCAL] = self.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig)
@ -285,10 +275,10 @@ class Channel(Logger):
This docstring was adapted from LND. This docstring was adapted from LND.
""" """
next_remote_ctn = self.get_current_ctn(REMOTE) + 1 next_remote_ctn = self.get_next_ctn(REMOTE)
self.logger.info(f"sign_next_commitment {next_remote_ctn}") self.logger.info(f"sign_next_commitment {next_remote_ctn}")
self.hm.send_ctx()
pending_remote_commitment = self.pending_commitment(REMOTE) pending_remote_commitment = self.get_next_commitment(REMOTE)
sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE]) sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
their_remote_htlc_privkey_number = derive_privkey( their_remote_htlc_privkey_number = derive_privkey(
@ -317,8 +307,7 @@ class Channel(Logger):
htlcsigs.sort() htlcsigs.sort()
htlcsigs = [x[1] for x in htlcsigs] htlcsigs = [x[1] for x in htlcsigs]
# TODO should add remote_commitment here and handle self.hm.send_ctx()
# both valid ctx'es in lnwatcher at the same time...
return sig_64, htlcsigs return sig_64, htlcsigs
@ -335,22 +324,20 @@ class Channel(Logger):
This docstring is from LND. This docstring is from LND.
""" """
next_local_ctn = self.get_next_ctn(LOCAL)
self.logger.info("receive_new_commitment") self.logger.info("receive_new_commitment")
self.hm.recv_ctx()
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
pending_local_commitment = self.pending_commitment(LOCAL) pending_local_commitment = self.get_next_commitment(LOCAL)
preimage_hex = pending_local_commitment.serialize_preimage(0) preimage_hex = pending_local_commitment.serialize_preimage(0)
pre_hash = sha256d(bfh(preimage_hex)) pre_hash = sha256d(bfh(preimage_hex))
if not ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, sig, pre_hash): if not ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, sig, pre_hash):
raise Exception('failed verifying signature of our updated commitment transaction: ' + bh2u(sig) + ' preimage is ' + preimage_hex) raise Exception(f'failed verifying signature of our updated commitment transaction: {bh2u(sig)} preimage is {preimage_hex}')
htlc_sigs_string = b''.join(htlc_sigs) htlc_sigs_string = b''.join(htlc_sigs)
htlc_sigs = htlc_sigs[:] # copy cause we will delete now htlc_sigs = htlc_sigs[:] # copy cause we will delete now
next_local_ctn = self.get_current_ctn(LOCAL) + 1
for htlcs, we_receive in [(self.included_htlcs(LOCAL, SENT, ctn=next_local_ctn), False), for htlcs, we_receive in [(self.included_htlcs(LOCAL, SENT, ctn=next_local_ctn), False),
(self.included_htlcs(LOCAL, RECEIVED, ctn=next_local_ctn), True)]: (self.included_htlcs(LOCAL, RECEIVED, ctn=next_local_ctn), True)]:
for htlc in htlcs: for htlc in htlcs:
@ -359,12 +346,10 @@ class Channel(Logger):
if len(htlc_sigs) != 0: # all sigs should have been popped above if len(htlc_sigs) != 0: # all sigs should have been popped above
raise Exception('failed verifying HTLC signatures: invalid amount of correct signatures') raise Exception('failed verifying HTLC signatures: invalid amount of correct signatures')
self.hm.recv_ctx()
self.config[LOCAL]=self.config[LOCAL]._replace( self.config[LOCAL]=self.config[LOCAL]._replace(
current_commitment_signature=sig, current_commitment_signature=sig,
current_htlc_signatures=htlc_sigs_string, current_htlc_signatures=htlc_sigs_string)
got_sig_for_next=True)
self.set_local_commitment(pending_local_commitment)
def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int: def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int:
ctn = extract_ctn_from_tx_and_chan(ctx, self) ctn = extract_ctn_from_tx_and_chan(ctx, self)
@ -394,15 +379,14 @@ class Channel(Logger):
def revoke_current_commitment(self): def revoke_current_commitment(self):
self.logger.info("revoke_current_commitment") self.logger.info("revoke_current_commitment")
assert self.config[LOCAL].got_sig_for_next
new_ctn = self.config[LOCAL].ctn + 1 new_ctn = self.config[LOCAL].ctn + 1
new_ctx = self.pending_commitment(LOCAL) new_ctx = self.get_latest_commitment(LOCAL)
assert self.signature_fits(new_ctx) if not self.signature_fits(new_ctx):
self.set_local_commitment(new_ctx) # this should never fail; as receive_new_commitment already did this test
raise Exception("refusing to revoke as remote sig does not fit")
self.hm.send_rev() self.hm.send_rev()
self.config[LOCAL]=self.config[LOCAL]._replace( self.config[LOCAL]=self.config[LOCAL]._replace(
ctn=new_ctn, ctn=new_ctn,
got_sig_for_next=False,
) )
received = self.hm.received_in_ctn(new_ctn) received = self.hm.received_in_ctn(new_ctn)
sent = self.hm.sent_in_ctn(new_ctn) sent = self.hm.sent_in_ctn(new_ctn)
@ -427,14 +411,12 @@ class Channel(Logger):
self.config[REMOTE].revocation_store.add_next_entry(revocation.per_commitment_secret) self.config[REMOTE].revocation_store.add_next_entry(revocation.per_commitment_secret)
##### start applying fee/htlc changes ##### start applying fee/htlc changes
next_point = self.config[REMOTE].next_per_commitment_point
self.hm.recv_rev() self.hm.recv_rev()
self.config[REMOTE]=self.config[REMOTE]._replace( self.config[REMOTE]=self.config[REMOTE]._replace(
ctn=self.config[REMOTE].ctn + 1, ctn=self.config[REMOTE].ctn + 1,
current_per_commitment_point=next_point, current_per_commitment_point=self.config[REMOTE].next_per_commitment_point,
next_per_commitment_point=revocation.next_per_commitment_point, next_per_commitment_point=revocation.next_per_commitment_point,
) )
self.set_remote_commitment()
def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None): def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None):
""" """
@ -512,7 +494,7 @@ class Channel(Logger):
def get_secret_and_point(self, subject, ctn) -> Tuple[Optional[bytes], bytes]: def get_secret_and_point(self, subject, ctn) -> Tuple[Optional[bytes], bytes]:
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
offset = ctn - self.get_current_ctn(subject) offset = ctn - self.get_oldest_unrevoked_ctn(subject)
if subject == REMOTE: if subject == REMOTE:
if offset > 1: if offset > 1:
raise RemoteCtnTooFarInFuture(f"offset: {offset}") raise RemoteCtnTooFarInFuture(f"offset: {offset}")
@ -540,12 +522,16 @@ class Channel(Logger):
secret, ctx = self.get_secret_and_commitment(subject, ctn) secret, ctx = self.get_secret_and_commitment(subject, ctn)
return ctx return ctx
def pending_commitment(self, subject): def get_next_commitment(self, subject: HTLCOwner) -> Transaction:
ctn = self.get_current_ctn(subject) ctn = self.get_next_ctn(subject)
return self.get_commitment(subject, ctn + 1) return self.get_commitment(subject, ctn)
def get_latest_commitment(self, subject: HTLCOwner) -> Transaction:
ctn = self.get_latest_ctn(subject)
return self.get_commitment(subject, ctn)
def current_commitment(self, subject): def get_oldest_unrevoked_commitment(self, subject: HTLCOwner) -> Transaction:
ctn = self.get_current_ctn(subject) ctn = self.get_oldest_unrevoked_ctn(subject)
return self.get_commitment(subject, ctn) return self.get_commitment(subject, ctn)
def create_sweeptxs(self, ctn): def create_sweeptxs(self, ctn):
@ -553,9 +539,15 @@ class Channel(Logger):
secret, ctx = self.get_secret_and_commitment(REMOTE, ctn) secret, ctx = self.get_secret_and_commitment(REMOTE, ctn)
return create_sweeptxs_for_watchtower(self, ctx, secret, self.sweep_address) return create_sweeptxs_for_watchtower(self, ctx, secret, self.sweep_address)
def get_current_ctn(self, subject): def get_oldest_unrevoked_ctn(self, subject: HTLCOwner) -> int:
return self.config[subject].ctn return self.config[subject].ctn
def get_latest_ctn(self, subject: HTLCOwner) -> int:
return self.hm.ctn_latest(subject)
def get_next_ctn(self, subject: HTLCOwner) -> int:
return self.hm.ctn_latest(subject) + 1
def total_msat(self, direction): def total_msat(self, direction):
"""Return the cumulative total msat amount received/sent so far.""" """Return the cumulative total msat amount received/sent so far."""
assert type(direction) is Direction assert type(direction) is Direction
@ -597,12 +589,8 @@ class Channel(Logger):
self.logger.info("receive_fail_htlc") self.logger.info("receive_fail_htlc")
self.hm.recv_fail(htlc_id) self.hm.recv_fail(htlc_id)
@property
def current_height(self):
return {LOCAL: self.config[LOCAL].ctn, REMOTE: self.config[REMOTE].ctn}
def pending_local_fee(self): def pending_local_fee(self):
return self.constraints.capacity - sum(x[2] for x in self.pending_commitment(LOCAL).outputs()) return self.constraints.capacity - sum(x[2] for x in self.get_next_commitment(LOCAL).outputs())
def update_fee(self, feerate: int, from_us: bool): def update_fee(self, feerate: int, from_us: bool):
# feerate uses sat/kw # feerate uses sat/kw
@ -751,7 +739,7 @@ class Channel(Logger):
return res return res
def force_close_tx(self): def force_close_tx(self):
tx = self.local_commitment tx = self.get_latest_commitment(LOCAL)
assert self.signature_fits(tx) assert self.signature_fits(tx)
tx = Transaction(str(tx)) tx = Transaction(str(tx))
tx.deserialize(True) tx.deserialize(True)

41
electrum/lnpeer.py

@ -458,7 +458,6 @@ class Peer(Logger):
was_announced=False, was_announced=False,
current_commitment_signature=None, current_commitment_signature=None,
current_htlc_signatures=[], current_htlc_signatures=[],
got_sig_for_next=False,
) )
return local_config return local_config
@ -577,8 +576,6 @@ class Peer(Logger):
# broadcast funding tx # broadcast funding tx
await asyncio.wait_for(self.network.broadcast_transaction(funding_tx), 5) await asyncio.wait_for(self.network.broadcast_transaction(funding_tx), 5)
chan.open_with_first_pcp(remote_per_commitment_point, remote_sig) chan.open_with_first_pcp(remote_per_commitment_point, remote_sig)
chan.set_remote_commitment()
chan.set_local_commitment(chan.current_commitment(LOCAL))
return chan return chan
async def on_open_channel(self, payload): async def on_open_channel(self, payload):
@ -713,12 +710,12 @@ class Peer(Logger):
# BOLT-02: "A node [...] upon disconnection [...] MUST reverse any uncommitted updates sent by the other side" # BOLT-02: "A node [...] upon disconnection [...] MUST reverse any uncommitted updates sent by the other side"
chan.hm.discard_unsigned_remote_updates() chan.hm.discard_unsigned_remote_updates()
# ctns # ctns
oldest_unrevoked_local_ctn = chan.config[LOCAL].ctn oldest_unrevoked_local_ctn = chan.get_oldest_unrevoked_ctn(LOCAL)
latest_local_ctn = chan.hm.ctn_latest(LOCAL) latest_local_ctn = chan.get_latest_ctn(LOCAL)
next_local_ctn = latest_local_ctn + 1 next_local_ctn = chan.get_next_ctn(LOCAL)
oldest_unrevoked_remote_ctn = chan.config[REMOTE].ctn oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE)
latest_remote_ctn = chan.hm.ctn_latest(REMOTE) latest_remote_ctn = chan.get_latest_ctn(REMOTE)
next_remote_ctn = latest_remote_ctn + 1 next_remote_ctn = chan.get_next_ctn(REMOTE)
# send message # send message
dlp_enabled = self.localfeatures & LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT dlp_enabled = self.localfeatures & LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
if dlp_enabled: if dlp_enabled:
@ -1016,7 +1013,7 @@ class Peer(Logger):
htlc_id = int.from_bytes(payload["id"], "big") htlc_id = int.from_bytes(payload["id"], "big")
chan = self.channels[channel_id] chan = self.channels[channel_id]
chan.receive_fail_htlc(htlc_id) chan.receive_fail_htlc(htlc_id)
local_ctn = chan.get_current_ctn(LOCAL) local_ctn = chan.get_latest_ctn(LOCAL)
asyncio.ensure_future(self._handle_error_code_from_failed_htlc(payload, channel_id, htlc_id)) asyncio.ensure_future(self._handle_error_code_from_failed_htlc(payload, channel_id, htlc_id))
asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn)) asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn))
@ -1087,7 +1084,7 @@ class Peer(Logger):
self.network.path_finder.add_to_blacklist(short_chan_id) self.network.path_finder.add_to_blacklist(short_chan_id)
def maybe_send_commitment(self, chan: Channel): def maybe_send_commitment(self, chan: Channel):
ctn_to_sign = chan.get_current_ctn(REMOTE) + 1 ctn_to_sign = chan.get_next_ctn(REMOTE)
# if there are no changes, we will not (and must not) send a new commitment # if there are no changes, we will not (and must not) send a new commitment
next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE) next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE)
if (next_htlcs == latest_htlcs if (next_htlcs == latest_htlcs
@ -1101,12 +1098,12 @@ class Peer(Logger):
async def await_remote(self, chan: Channel, ctn: int): async def await_remote(self, chan: Channel, ctn: int):
self.maybe_send_commitment(chan) self.maybe_send_commitment(chan)
while chan.get_current_ctn(REMOTE) <= ctn: while chan.get_latest_ctn(REMOTE) <= ctn:
await self._remote_changed_events[chan.channel_id].wait() await self._remote_changed_events[chan.channel_id].wait()
async def await_local(self, chan: Channel, ctn: int): async def await_local(self, chan: Channel, ctn: int):
self.maybe_send_commitment(chan) self.maybe_send_commitment(chan)
while chan.get_current_ctn(LOCAL) <= ctn: while chan.get_latest_ctn(LOCAL) <= ctn:
await self._local_changed_events[chan.channel_id].wait() await self._local_changed_events[chan.channel_id].wait()
async def pay(self, route: List['RouteEdge'], chan: Channel, amount_msat: int, async def pay(self, route: List['RouteEdge'], chan: Channel, amount_msat: int,
@ -1122,7 +1119,7 @@ class Peer(Logger):
# create htlc # create htlc
htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv, timestamp=int(time.time())) htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv, timestamp=int(time.time()))
htlc = chan.add_htlc(htlc) htlc = chan.add_htlc(htlc)
remote_ctn = chan.get_current_ctn(REMOTE) remote_ctn = chan.get_latest_ctn(REMOTE)
chan.onion_keys[htlc.htlc_id] = secret_key chan.onion_keys[htlc.htlc_id] = secret_key
self.attempted_route[(chan.channel_id, htlc.htlc_id)] = route self.attempted_route[(chan.channel_id, htlc.htlc_id)] = route
self.logger.info(f"starting payment. route: {route}. htlc: {htlc}") self.logger.info(f"starting payment. route: {route}. htlc: {htlc}")
@ -1156,7 +1153,7 @@ class Peer(Logger):
and chan.get_next_feerate(LOCAL) == chan.get_latest_feerate(LOCAL)): and chan.get_next_feerate(LOCAL) == chan.get_latest_feerate(LOCAL)):
raise RemoteMisbehaving('received commitment_signed without pending changes') raise RemoteMisbehaving('received commitment_signed without pending changes')
# make sure ctn is new # make sure ctn is new
ctn_to_recv = chan.get_current_ctn(LOCAL) + 1 ctn_to_recv = chan.get_next_ctn(LOCAL)
if ctn_to_recv == self.recv_commitment_for_ctn_last[chan]: if ctn_to_recv == self.recv_commitment_for_ctn_last[chan]:
raise RemoteMisbehaving('received commitment_signed with same ctn') raise RemoteMisbehaving('received commitment_signed with same ctn')
self.recv_commitment_for_ctn_last[chan] = ctn_to_recv self.recv_commitment_for_ctn_last[chan] = ctn_to_recv
@ -1172,7 +1169,7 @@ class Peer(Logger):
preimage = update_fulfill_htlc_msg["payment_preimage"] preimage = update_fulfill_htlc_msg["payment_preimage"]
htlc_id = int.from_bytes(update_fulfill_htlc_msg["id"], "big") htlc_id = int.from_bytes(update_fulfill_htlc_msg["id"], "big")
chan.receive_htlc_settle(preimage, htlc_id) chan.receive_htlc_settle(preimage, htlc_id)
local_ctn = chan.get_current_ctn(LOCAL) local_ctn = chan.get_latest_ctn(LOCAL)
asyncio.ensure_future(self._on_update_fulfill_htlc(chan, htlc_id, preimage, local_ctn)) asyncio.ensure_future(self._on_update_fulfill_htlc(chan, htlc_id, preimage, local_ctn))
@log_exceptions @log_exceptions
@ -1206,8 +1203,8 @@ class Peer(Logger):
timestamp=int(time.time()), timestamp=int(time.time()),
htlc_id=htlc_id) htlc_id=htlc_id)
htlc = chan.receive_htlc(htlc) htlc = chan.receive_htlc(htlc)
local_ctn = chan.get_current_ctn(LOCAL) local_ctn = chan.get_latest_ctn(LOCAL)
remote_ctn = chan.get_current_ctn(REMOTE) remote_ctn = chan.get_latest_ctn(REMOTE)
if processed_onion.are_we_final: if processed_onion.are_we_final:
asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan, asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan,
htlc=htlc, htlc=htlc,
@ -1243,7 +1240,7 @@ class Peer(Logger):
next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big') next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big')
next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry, timestamp=int(time.time())) next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry, timestamp=int(time.time()))
next_htlc = next_chan.add_htlc(next_htlc) next_htlc = next_chan.add_htlc(next_htlc)
next_remote_ctn = next_chan.get_current_ctn(REMOTE) next_remote_ctn = next_chan.get_latest_ctn(REMOTE)
next_peer.send_message( next_peer.send_message(
"update_add_htlc", "update_add_htlc",
channel_id=next_chan.channel_id, channel_id=next_chan.channel_id,
@ -1301,7 +1298,7 @@ class Peer(Logger):
async def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): async def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
chan.settle_htlc(preimage, htlc_id) chan.settle_htlc(preimage, htlc_id)
remote_ctn = chan.get_current_ctn(REMOTE) remote_ctn = chan.get_latest_ctn(REMOTE)
self.send_message("update_fulfill_htlc", self.send_message("update_fulfill_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
id=htlc_id, id=htlc_id,
@ -1313,7 +1310,7 @@ class Peer(Logger):
reason: OnionRoutingFailureMessage): reason: OnionRoutingFailureMessage):
self.logger.info(f"failing received htlc {(bh2u(chan.channel_id), htlc_id)}. reason: {reason}") self.logger.info(f"failing received htlc {(bh2u(chan.channel_id), htlc_id)}. reason: {reason}")
chan.fail_htlc(htlc_id) chan.fail_htlc(htlc_id)
remote_ctn = chan.get_current_ctn(REMOTE) remote_ctn = chan.get_latest_ctn(REMOTE)
error_packet = construct_onion_error(reason, onion_packet, our_onion_private_key=self.privkey) error_packet = construct_onion_error(reason, onion_packet, our_onion_private_key=self.privkey)
self.send_message("update_fail_htlc", self.send_message("update_fail_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
@ -1357,7 +1354,7 @@ class Peer(Logger):
else: else:
return return
chan.update_fee(feerate_per_kw, True) chan.update_fee(feerate_per_kw, True)
remote_ctn = chan.get_current_ctn(REMOTE) remote_ctn = chan.get_latest_ctn(REMOTE)
self.send_message("update_fee", self.send_message("update_fee",
channel_id=chan.channel_id, channel_id=chan.channel_id,
feerate_per_kw=feerate_per_kw) feerate_per_kw=feerate_per_kw)

11
electrum/lnsweep.py

@ -188,7 +188,7 @@ def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
# other outputs are htlcs # other outputs are htlcs
# if they are spent, we need to generate the script # if they are spent, we need to generate the script
# so, second-stage htlc sweep should not be returned here # so, second-stage htlc sweep should not be returned here
if ctn != our_conf.ctn: if ctn < chan.get_oldest_unrevoked_ctn(LOCAL):
_logger.info("we breached.") _logger.info("we breached.")
return {} return {}
txs = {} txs = {}
@ -247,17 +247,18 @@ def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
def analyze_ctx(chan: 'Channel', ctx: Transaction): def analyze_ctx(chan: 'Channel', ctx: Transaction):
# note: the remote sometimes has two valid non-revoked commitment transactions, # note: the remote sometimes has two valid non-revoked commitment transactions,
# either of which could be broadcast (their_conf.ctn, their_conf.ctn+1) # either of which could be broadcast
our_conf, their_conf = get_ordered_channel_configs(chan=chan, for_us=True) our_conf, their_conf = get_ordered_channel_configs(chan=chan, for_us=True)
ctn = extract_ctn_from_tx_and_chan(ctx, chan) ctn = extract_ctn_from_tx_and_chan(ctx, chan)
per_commitment_secret = None per_commitment_secret = None
if ctn == their_conf.ctn: oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE)
if ctn == oldest_unrevoked_remote_ctn:
their_pcp = their_conf.current_per_commitment_point their_pcp = their_conf.current_per_commitment_point
is_revocation = False is_revocation = False
elif ctn == their_conf.ctn + 1: elif ctn == oldest_unrevoked_remote_ctn + 1:
their_pcp = their_conf.next_per_commitment_point their_pcp = their_conf.next_per_commitment_point
is_revocation = False is_revocation = False
elif ctn < their_conf.ctn: # breach elif ctn < oldest_unrevoked_remote_ctn: # breach
try: try:
per_commitment_secret = their_conf.revocation_store.retrieve_secret(RevocationStore.START_INDEX - ctn) per_commitment_secret = their_conf.revocation_store.retrieve_secret(RevocationStore.START_INDEX - ctn)
except UnableToDeriveSecret: except UnableToDeriveSecret:

1
electrum/lnutil.py

@ -52,7 +52,6 @@ class LocalConfig(NamedTuple):
was_announced: bool was_announced: bool
current_commitment_signature: Optional[bytes] current_commitment_signature: Optional[bytes]
current_htlc_signatures: List[bytes] current_htlc_signatures: List[bytes]
got_sig_for_next: bool
class RemoteConfig(NamedTuple): class RemoteConfig(NamedTuple):

6
electrum/lnworker.py

@ -311,8 +311,6 @@ class LNWallet(LNWorker):
for x in wallet.storage.get("channels", []): for x in wallet.storage.get("channels", []):
c = Channel(x, sweep_address=self.sweep_address, lnworker=self) c = Channel(x, sweep_address=self.sweep_address, lnworker=self)
self.channels[c.channel_id] = c self.channels[c.channel_id] = c
c.set_remote_commitment()
c.set_local_commitment(c.current_commitment(LOCAL))
# timestamps of opening and closing transactions # timestamps of opening and closing transactions
self.channel_timestamps = self.storage.get('lightning_channel_timestamps', {}) self.channel_timestamps = self.storage.get('lightning_channel_timestamps', {})
self.pending_payments = defaultdict(asyncio.Future) self.pending_payments = defaultdict(asyncio.Future)
@ -348,10 +346,10 @@ class LNWallet(LNWorker):
self.logger.info(f'could not contact remote watchtower {watchtower_url}') self.logger.info(f'could not contact remote watchtower {watchtower_url}')
await asyncio.sleep(5) await asyncio.sleep(5)
async def sync_channel_with_watchtower(self, chan, watchtower): async def sync_channel_with_watchtower(self, chan: Channel, watchtower):
outpoint = chan.funding_outpoint.to_str() outpoint = chan.funding_outpoint.to_str()
addr = chan.get_funding_address() addr = chan.get_funding_address()
current_ctn = chan.get_current_ctn(REMOTE) current_ctn = chan.get_oldest_unrevoked_ctn(REMOTE)
watchtower_ctn = await watchtower.get_ctn(outpoint, addr) watchtower_ctn = await watchtower.get_ctn(outpoint, addr)
for ctn in range(watchtower_ctn + 1, current_ctn): for ctn in range(watchtower_ctn + 1, current_ctn):
sweeptxs = chan.create_sweeptxs(ctn) sweeptxs = chan.create_sweeptxs(ctn)

143
electrum/tests/test_lnchannel.py

@ -85,7 +85,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
was_announced=False, was_announced=False,
current_commitment_signature=None, current_commitment_signature=None,
current_htlc_signatures=None, current_htlc_signatures=None,
got_sig_for_next=False,
), ),
"constraints":lnpeer.ChannelConstraints( "constraints":lnpeer.ChannelConstraints(
capacity=funding_sat, capacity=funding_sat,
@ -93,7 +92,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
funding_txn_minimum_depth=3, funding_txn_minimum_depth=3,
), ),
"node_id":other_node_id, "node_id":other_node_id,
"remote_commitment_to_be_revoked": None,
'onion_keys': {}, 'onion_keys': {},
} }
@ -137,8 +135,8 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice.set_state('OPEN') alice.set_state('OPEN')
bob.set_state('OPEN') bob.set_state('OPEN')
a_out = alice.current_commitment(LOCAL).outputs() a_out = alice.get_latest_commitment(LOCAL).outputs()
b_out = bob.pending_commitment(REMOTE).outputs() b_out = bob.get_next_commitment(REMOTE).outputs()
assert a_out == b_out, "\n" + pformat((a_out, b_out)) assert a_out == b_out, "\n" + pformat((a_out, b_out))
sig_from_bob, a_htlc_sigs = bob.sign_next_commitment() sig_from_bob, a_htlc_sigs = bob.sign_next_commitment()
@ -150,21 +148,12 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice.config[LOCAL] = alice.config[LOCAL]._replace(current_commitment_signature=sig_from_bob) alice.config[LOCAL] = alice.config[LOCAL]._replace(current_commitment_signature=sig_from_bob)
bob.config[LOCAL] = bob.config[LOCAL]._replace(current_commitment_signature=sig_from_alice) bob.config[LOCAL] = bob.config[LOCAL]._replace(current_commitment_signature=sig_from_alice)
alice.set_local_commitment(alice.current_commitment(LOCAL))
bob.set_local_commitment(bob.current_commitment(LOCAL))
alice_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) alice_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
bob_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) bob_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first) alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first)
bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first) bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first)
alice.set_remote_commitment()
bob.set_remote_commitment()
alice.remote_commitment_to_be_revoked = alice.remote_commitment
bob.remote_commitment_to_be_revoked = bob.remote_commitment
alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0) alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0)
bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0) bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0)
alice.hm.channel_open_finished() alice.hm.channel_open_finished()
@ -179,7 +168,7 @@ class TestFee(unittest.TestCase):
""" """
def test_fee(self): def test_fee(self):
alice_channel, bob_channel = create_test_channels(253, 10000000000, 5000000000) alice_channel, bob_channel = create_test_channels(253, 10000000000, 5000000000)
self.assertIn(9999817, [x[2] for x in alice_channel.local_commitment.outputs()]) self.assertIn(9999817, [x[2] for x in alice_channel.get_latest_commitment(LOCAL).outputs()])
class TestChannel(unittest.TestCase): class TestChannel(unittest.TestCase):
maxDiff = 999 maxDiff = 999
@ -228,31 +217,43 @@ class TestChannel(unittest.TestCase):
self.htlc_dict['amount_msat'] += 1000 self.htlc_dict['amount_msat'] += 1000
self.bob_channel.add_htlc(self.htlc_dict) self.bob_channel.add_htlc(self.htlc_dict)
self.alice_channel.receive_htlc(self.htlc_dict) self.alice_channel.receive_htlc(self.htlc_dict)
self.assertEqual(len(self.alice_channel.get_latest_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(self.alice_channel.get_next_commitment(LOCAL).outputs()), 3)
self.assertEqual(len(self.alice_channel.get_latest_commitment(REMOTE).outputs()), 2)
self.assertEqual(len(self.alice_channel.get_next_commitment(REMOTE).outputs()), 3)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment()) self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3)
self.assertEqual(len(self.alice_channel.get_latest_commitment(LOCAL).outputs()), 3)
self.assertEqual(len(self.alice_channel.get_next_commitment(LOCAL).outputs()), 3)
self.assertEqual(len(self.alice_channel.get_latest_commitment(REMOTE).outputs()), 2)
self.assertEqual(len(self.alice_channel.get_next_commitment(REMOTE).outputs()), 3)
self.alice_channel.revoke_current_commitment() self.alice_channel.revoke_current_commitment()
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4)
self.assertEqual(len(self.alice_channel.get_latest_commitment(LOCAL).outputs()), 3)
self.assertEqual(len(self.alice_channel.get_next_commitment(LOCAL).outputs()), 3)
self.assertEqual(len(self.alice_channel.get_latest_commitment(REMOTE).outputs()), 2)
self.assertEqual(len(self.alice_channel.get_next_commitment(REMOTE).outputs()), 4)
def test_SimpleAddSettleWorkflow(self): def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel alice_channel, bob_channel = self.alice_channel, self.bob_channel
htlc = self.htlc htlc = self.htlc
alice_out = alice_channel.current_commitment(LOCAL).outputs() alice_out = alice_channel.get_latest_commitment(LOCAL).outputs()
short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42] short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42]
long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62] long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62]
self.assertLess(alice_out[long_idx].value, 5 * 10**8, alice_out) self.assertLess(alice_out[long_idx].value, 5 * 10**8, alice_out)
self.assertEqual(alice_out[short_idx].value, 5 * 10**8, alice_out) self.assertEqual(alice_out[short_idx].value, 5 * 10**8, alice_out)
alice_out = alice_channel.current_commitment(REMOTE).outputs() alice_out = alice_channel.get_latest_commitment(REMOTE).outputs()
short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42] short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42]
long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62] long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62]
self.assertLess(alice_out[short_idx].value, 5 * 10**8) self.assertLess(alice_out[short_idx].value, 5 * 10**8)
self.assertEqual(alice_out[long_idx].value, 5 * 10**8) self.assertEqual(alice_out[long_idx].value, 5 * 10**8)
def com(): self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL)))
return alice_channel.local_commitment
self.assertTrue(alice_channel.signature_fits(com()))
self.assertNotEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), []) self.assertNotEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [])
@ -270,9 +271,9 @@ class TestChannel(unittest.TestCase):
from electrum.lnutil import extract_ctn_from_tx_and_chan from electrum.lnutil import extract_ctn_from_tx_and_chan
tx0 = str(alice_channel.force_close_tx()) tx0 = str(alice_channel.force_close_tx())
self.assertEqual(alice_channel.config[LOCAL].ctn, 0) self.assertEqual(alice_channel.get_oldest_unrevoked_ctn(LOCAL), 0)
self.assertEqual(extract_ctn_from_tx_and_chan(alice_channel.force_close_tx(), alice_channel), 0) self.assertEqual(extract_ctn_from_tx_and_chan(alice_channel.force_close_tx(), alice_channel), 0)
self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL))) self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL)))
# Next alice commits this change by sending a signature message. Since # Next alice commits this change by sending a signature message. Since
# we expect the messages to be ordered, Bob will receive the HTLC we # we expect the messages to be ordered, Bob will receive the HTLC we
@ -281,21 +282,20 @@ class TestChannel(unittest.TestCase):
aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment() aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment()
self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature") self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature")
self.assertTrue(alice_channel.signature_fits(com())) self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL)))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(next(iter(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE)))[0], RECEIVED) self.assertEqual(next(iter(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE)))[0], RECEIVED)
self.assertEqual(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE), bob_channel.hm.get_htlcs_in_next_ctx(LOCAL)) self.assertEqual(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE), bob_channel.hm.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs()) self.assertEqual(alice_channel.get_latest_commitment(REMOTE).outputs(), bob_channel.get_next_commitment(LOCAL).outputs())
# Bob receives this signature message, and checks that this covers the # Bob receives this signature message, and checks that this covers the
# state he has in his remote log. This includes the HTLC just sent # state he has in his remote log. This includes the HTLC just sent
# from Alice. # from Alice.
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
bob_channel.receive_new_commitment(aliceSig, aliceHtlcSigs) bob_channel.receive_new_commitment(aliceSig, aliceHtlcSigs)
self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL))) self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
self.assertEqual(bob_channel.config[REMOTE].ctn, 0) self.assertEqual(bob_channel.get_oldest_unrevoked_ctn(REMOTE), 0)
self.assertEqual(bob_channel.included_htlcs(LOCAL, RECEIVED, 1), [htlc])# self.assertEqual(bob_channel.included_htlcs(LOCAL, RECEIVED, 1), [htlc])#
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 0), []) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 0), [])
@ -311,57 +311,50 @@ class TestChannel(unittest.TestCase):
# has a valid signature for a newer commitment. # has a valid signature for a newer commitment.
bobRevocation, _ = bob_channel.revoke_current_commitment() bobRevocation, _ = bob_channel.revoke_current_commitment()
bob_channel.serialize() bob_channel.serialize()
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
# Bob finally sends a signature for Alice's commitment transaction. # Bob finally sends a signature for Alice's commitment transaction.
# This signature will cover the HTLC, since Bob will first send the # This signature will cover the HTLC, since Bob will first send the
# revocation just created. The revocation also acks every received # revocation just created. The revocation also acks every received
# HTLC up to the point where Alice sent her signature. # HTLC up to the point where Alice sent her signature.
bobSig, bobHtlcSigs = bob_channel.sign_next_commitment() bobSig, bobHtlcSigs = bob_channel.sign_next_commitment()
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
self.assertEqual(len(bobHtlcSigs), 1) self.assertEqual(len(bobHtlcSigs), 1)
self.assertTrue(alice_channel.signature_fits(com())) self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL)))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
# so far: Alice added htlc, Alice signed. # so far: Alice added htlc, Alice signed.
self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2) self.assertEqual(len(alice_channel.get_latest_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2) self.assertEqual(len(alice_channel.get_next_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 2) # oldest unrevoked self.assertEqual(len(alice_channel.get_oldest_unrevoked_commitment(REMOTE).outputs()), 2)
self.assertEqual(len(alice_channel.pending_commitment(REMOTE).outputs()), 3) # latest self.assertEqual(len(alice_channel.get_latest_commitment(REMOTE).outputs()), 3)
# Alice then processes this revocation, sending her own revocation for # Alice then processes this revocation, sending her own revocation for
# her prior commitment transaction. Alice shouldn't have any HTLCs to # her prior commitment transaction. Alice shouldn't have any HTLCs to
# forward since she's sending an outgoing HTLC. # forward since she's sending an outgoing HTLC.
alice_channel.receive_revocation(bobRevocation) alice_channel.receive_revocation(bobRevocation)
alice_channel.serialize() alice_channel.serialize()
self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
self.assertTrue(alice_channel.signature_fits(com())) self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL)))
self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL)))
alice_channel.serialize() alice_channel.serialize()
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2) self.assertEqual(len(alice_channel.get_latest_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3) self.assertEqual(len(alice_channel.get_latest_commitment(REMOTE).outputs()), 3)
self.assertEqual(len(com().outputs()), 2)
self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2) self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2)
self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1) self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
alice_channel.serialize() alice_channel.serialize()
self.assertEqual(alice_channel.pending_commitment(LOCAL).outputs(), self.assertEqual(alice_channel.get_next_commitment(LOCAL).outputs(),
bob_channel.pending_commitment(REMOTE).outputs()) bob_channel.get_latest_commitment(REMOTE).outputs())
# Alice then processes bob's signature, and since she just received # Alice then processes bob's signature, and since she just received
# the revocation, she expect this signature to cover everything up to # the revocation, she expect this signature to cover everything up to
# the point where she sent her signature, including the HTLC. # the point where she sent her signature, including the HTLC.
alice_channel.receive_new_commitment(bobSig, bobHtlcSigs) alice_channel.receive_new_commitment(bobSig, bobHtlcSigs)
self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3) self.assertEqual(len(alice_channel.get_latest_commitment(REMOTE).outputs()), 3)
self.assertEqual(len(com().outputs()), 3)
self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3) self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3)
self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1) self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
@ -371,10 +364,8 @@ class TestChannel(unittest.TestCase):
self.assertNotEqual(tx0, tx1) self.assertNotEqual(tx0, tx1)
# Alice then generates a revocation for bob. # Alice then generates a revocation for bob.
self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
aliceRevocation, _ = alice_channel.revoke_current_commitment() aliceRevocation, _ = alice_channel.revoke_current_commitment()
alice_channel.serialize() alice_channel.serialize()
#self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
tx2 = str(alice_channel.force_close_tx()) tx2 = str(alice_channel.force_close_tx())
# since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one) # since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one)
@ -384,7 +375,7 @@ class TestChannel(unittest.TestCase):
# is fully locked in within both commitment transactions. Bob should # is fully locked in within both commitment transactions. Bob should
# also be able to forward an HTLC now that the HTLC has been locked # also be able to forward an HTLC now that the HTLC has been locked
# into both commitment transactions. # into both commitment transactions.
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
bob_channel.receive_revocation(aliceRevocation) bob_channel.receive_revocation(aliceRevocation)
bob_channel.serialize() bob_channel.serialize()
@ -398,13 +389,13 @@ class TestChannel(unittest.TestCase):
self.assertEqual(alice_channel.total_msat(RECEIVED), bobSent, "alice has incorrect milli-satoshis received") self.assertEqual(alice_channel.total_msat(RECEIVED), bobSent, "alice has incorrect milli-satoshis received")
self.assertEqual(bob_channel.total_msat(SENT), bobSent, "bob has incorrect milli-satoshis sent") self.assertEqual(bob_channel.total_msat(SENT), bobSent, "bob has incorrect milli-satoshis sent")
self.assertEqual(bob_channel.total_msat(RECEIVED), aliceSent, "bob has incorrect milli-satoshis received") self.assertEqual(bob_channel.total_msat(RECEIVED), aliceSent, "bob has incorrect milli-satoshis received")
self.assertEqual(bob_channel.config[LOCAL].ctn, 1, "bob has incorrect commitment height") self.assertEqual(bob_channel.get_oldest_unrevoked_ctn(LOCAL), 1, "bob has incorrect commitment height")
self.assertEqual(alice_channel.config[LOCAL].ctn, 1, "alice has incorrect commitment height") self.assertEqual(alice_channel.get_oldest_unrevoked_ctn(LOCAL), 1, "alice has incorrect commitment height")
# Both commitment transactions should have three outputs, and one of # Both commitment transactions should have three outputs, and one of
# them should be exactly the amount of the HTLC. # them should be exactly the amount of the HTLC.
alice_ctx = alice_channel.pending_commitment(LOCAL) alice_ctx = alice_channel.get_next_commitment(LOCAL)
bob_ctx = bob_channel.pending_commitment(LOCAL) bob_ctx = bob_channel.get_next_commitment(LOCAL)
self.assertEqual(len(alice_ctx.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_ctx.outputs())) self.assertEqual(len(alice_ctx.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_ctx.outputs()))
self.assertEqual(len(bob_ctx.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_ctx.outputs())) self.assertEqual(len(bob_ctx.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_ctx.outputs()))
self.assertOutputExistsByValue(alice_ctx, htlc.amount_msat // 1000) self.assertOutputExistsByValue(alice_ctx, htlc.amount_msat // 1000)
@ -415,7 +406,6 @@ class TestChannel(unittest.TestCase):
preimage = self.paymentPreimage preimage = self.paymentPreimage
bob_channel.settle_htlc(preimage, self.bobHtlcIndex) bob_channel.settle_htlc(preimage, self.bobHtlcIndex)
#self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
alice_channel.receive_htlc_settle(preimage, self.aliceHtlcIndex) alice_channel.receive_htlc_settle(preimage, self.aliceHtlcIndex)
tx3 = str(alice_channel.force_close_tx()) tx3 = str(alice_channel.force_close_tx())
@ -426,7 +416,7 @@ class TestChannel(unittest.TestCase):
self.assertEqual(len(bobHtlcSigs2), 0) self.assertEqual(len(bobHtlcSigs2), 0)
self.assertEqual(alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED), [htlc]) self.assertEqual(alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED), [htlc])
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, alice_channel.config[REMOTE].ctn), [htlc]) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, alice_channel.get_oldest_unrevoked_ctn(REMOTE)), [htlc])
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc]) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc])
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 2), [htlc]) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 2), [htlc])
@ -440,8 +430,8 @@ class TestChannel(unittest.TestCase):
self.assertEqual(bob_channel.included_htlcs(REMOTE, RECEIVED, 1), []) self.assertEqual(bob_channel.included_htlcs(REMOTE, RECEIVED, 1), [])
self.assertEqual(bob_channel.included_htlcs(REMOTE, RECEIVED, 2), []) self.assertEqual(bob_channel.included_htlcs(REMOTE, RECEIVED, 2), [])
alice_ctx_bob_version = bob_channel.pending_commitment(REMOTE).outputs() alice_ctx_bob_version = bob_channel.get_latest_commitment(REMOTE).outputs()
alice_ctx_alice_version = alice_channel.pending_commitment(LOCAL).outputs() alice_ctx_alice_version = alice_channel.get_next_commitment(LOCAL).outputs()
self.assertEqual(alice_ctx_alice_version, alice_ctx_bob_version) self.assertEqual(alice_ctx_alice_version, alice_ctx_bob_version)
alice_channel.receive_new_commitment(bobSig2, bobHtlcSigs2) alice_channel.receive_new_commitment(bobSig2, bobHtlcSigs2)
@ -450,14 +440,13 @@ class TestChannel(unittest.TestCase):
self.assertNotEqual(tx3, tx4) self.assertNotEqual(tx3, tx4)
self.assertEqual(alice_channel.balance(LOCAL), 500000000000) self.assertEqual(alice_channel.balance(LOCAL), 500000000000)
self.assertEqual(1, alice_channel.config[LOCAL].ctn) self.assertEqual(1, alice_channel.get_oldest_unrevoked_ctn(LOCAL))
self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0) self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0)
aliceRevocation2, _ = alice_channel.revoke_current_commitment() aliceRevocation2, _ = alice_channel.revoke_current_commitment()
alice_channel.serialize() alice_channel.serialize()
aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment() aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment()
self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures") self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures")
self.assertEqual(len(bob_channel.current_commitment(LOCAL).outputs()), 3) self.assertEqual(len(bob_channel.get_latest_commitment(LOCAL).outputs()), 3)
#self.assertEqual(len(bob_channel.pending_commitment(LOCAL).outputs()), 3)
bob_channel.receive_revocation(aliceRevocation2) bob_channel.receive_revocation(aliceRevocation2)
bob_channel.serialize() bob_channel.serialize()
@ -478,14 +467,14 @@ class TestChannel(unittest.TestCase):
self.assertEqual(alice_channel.total_msat(RECEIVED), 0, "alice satoshis received incorrect") self.assertEqual(alice_channel.total_msat(RECEIVED), 0, "alice satoshis received incorrect")
self.assertEqual(bob_channel.total_msat(RECEIVED), mSatTransferred, "bob satoshis received incorrect") self.assertEqual(bob_channel.total_msat(RECEIVED), mSatTransferred, "bob satoshis received incorrect")
self.assertEqual(bob_channel.total_msat(SENT), 0, "bob satoshis sent incorrect") self.assertEqual(bob_channel.total_msat(SENT), 0, "bob satoshis sent incorrect")
self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height") self.assertEqual(bob_channel.get_latest_ctn(LOCAL), 2, "bob has incorrect commitment height")
self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height") self.assertEqual(alice_channel.get_latest_ctn(LOCAL), 2, "alice has incorrect commitment height")
alice_channel.update_fee(100000, True) alice_channel.update_fee(100000, True)
alice_outputs = alice_channel.pending_commitment(REMOTE).outputs() alice_outputs = alice_channel.get_next_commitment(REMOTE).outputs()
old_outputs = bob_channel.pending_commitment(LOCAL).outputs() old_outputs = bob_channel.get_next_commitment(LOCAL).outputs()
bob_channel.update_fee(100000, False) bob_channel.update_fee(100000, False)
new_outputs = bob_channel.pending_commitment(LOCAL).outputs() new_outputs = bob_channel.get_next_commitment(LOCAL).outputs()
self.assertNotEqual(old_outputs, new_outputs) self.assertNotEqual(old_outputs, new_outputs)
self.assertEqual(alice_outputs, new_outputs) self.assertEqual(alice_outputs, new_outputs)
@ -517,13 +506,13 @@ class TestChannel(unittest.TestCase):
def alice_to_bob_fee_update(self, fee=111): def alice_to_bob_fee_update(self, fee=111):
aoldctx = self.alice_channel.pending_commitment(REMOTE).outputs() aoldctx = self.alice_channel.get_next_commitment(REMOTE).outputs()
self.alice_channel.update_fee(fee, True) self.alice_channel.update_fee(fee, True)
anewctx = self.alice_channel.pending_commitment(REMOTE).outputs() anewctx = self.alice_channel.get_next_commitment(REMOTE).outputs()
self.assertNotEqual(aoldctx, anewctx) self.assertNotEqual(aoldctx, anewctx)
boldctx = self.bob_channel.pending_commitment(LOCAL).outputs() boldctx = self.bob_channel.get_next_commitment(LOCAL).outputs()
self.bob_channel.update_fee(fee, False) self.bob_channel.update_fee(fee, False)
bnewctx = self.bob_channel.pending_commitment(LOCAL).outputs() bnewctx = self.bob_channel.get_next_commitment(LOCAL).outputs()
self.assertNotEqual(boldctx, bnewctx) self.assertNotEqual(boldctx, bnewctx)
self.assertEqual(anewctx, bnewctx) self.assertEqual(anewctx, bnewctx)
return fee return fee
@ -805,12 +794,12 @@ class TestDust(unittest.TestCase):
'timestamp' : 0, 'timestamp' : 0,
} }
old_values = [x.value for x in bob_channel.current_commitment(LOCAL).outputs() ] old_values = [x.value for x in bob_channel.get_latest_commitment(LOCAL).outputs() ]
aliceHtlcIndex = alice_channel.add_htlc(htlc).htlc_id aliceHtlcIndex = alice_channel.add_htlc(htlc).htlc_id
bobHtlcIndex = bob_channel.receive_htlc(htlc).htlc_id bobHtlcIndex = bob_channel.receive_htlc(htlc).htlc_id
force_state_transition(alice_channel, bob_channel) force_state_transition(alice_channel, bob_channel)
alice_ctx = alice_channel.current_commitment(LOCAL) alice_ctx = alice_channel.get_latest_commitment(LOCAL)
bob_ctx = bob_channel.current_commitment(LOCAL) bob_ctx = bob_channel.get_latest_commitment(LOCAL)
new_values = [x.value for x in bob_ctx.outputs() ] new_values = [x.value for x in bob_ctx.outputs() ]
self.assertNotEqual(old_values, new_values) self.assertNotEqual(old_values, new_values)
self.assertEqual(len(alice_ctx.outputs()), 3) self.assertEqual(len(alice_ctx.outputs()), 3)
@ -820,7 +809,7 @@ class TestDust(unittest.TestCase):
bob_channel.settle_htlc(paymentPreimage, bobHtlcIndex) bob_channel.settle_htlc(paymentPreimage, bobHtlcIndex)
alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex) alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex)
force_state_transition(bob_channel, alice_channel) force_state_transition(bob_channel, alice_channel)
self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2) self.assertEqual(len(alice_channel.get_next_commitment(LOCAL).outputs()), 2)
self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt) self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt)
def force_state_transition(chanA, chanB): def force_state_transition(chanA, chanB):

Loading…
Cancel
Save