diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 20e45b9d4..634e40e33 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -65,7 +65,7 @@ class Peer(Logger): LOGGING_SHORTCUT = 'P' ORDERED_MESSAGES = ( - 'accept_channel', 'funding_signed', 'funding_created', 'accept_channel', 'channel_reestablish', 'closing_signed') + 'accept_channel', 'funding_signed', 'funding_created', 'accept_channel', 'closing_signed') SPAMMY_MESSAGES = ( 'ping', 'pong', 'channel_announcement', 'node_announcement', 'channel_update',) @@ -103,6 +103,7 @@ class Peer(Logger): self.funding_signed_sent = set() # for channels in PREOPENING self.shutdown_received = {} # chan_id -> asyncio.Future() self.announcement_signatures = defaultdict(asyncio.Queue) + self.channel_reestablish_msg = defaultdict(asyncio.Queue) self.orphan_channel_updates = OrderedDict() # type: OrderedDict[ShortChannelID, dict] Logger.__init__(self) self.taskgroup = OldTaskGroup() @@ -907,103 +908,40 @@ class Peer(Logger): your_last_per_commitment_secret=0, my_current_per_commitment_point=latest_point) - async def reestablish_channel(self, chan: Channel): - await self.initialized - chan_id = chan.channel_id - if chan.should_request_force_close: - await self.trigger_force_close(chan_id) - chan.should_request_force_close = False - return - assert ChannelState.PREOPENING < chan.get_state() < ChannelState.FORCE_CLOSING - if chan.peer_state != PeerState.DISCONNECTED: - self.logger.info(f'reestablish_channel was called but channel {chan.get_id_for_log()} ' - f'already in peer_state {chan.peer_state!r}') - return - chan.peer_state = PeerState.REESTABLISHING - util.trigger_callback('channel', self.lnworker.wallet, chan) - # BOLT-02: "A node [...] upon disconnection [...] MUST reverse any uncommitted updates sent by the other side" - chan.hm.discard_unsigned_remote_updates() - # ctns - oldest_unrevoked_local_ctn = chan.get_oldest_unrevoked_ctn(LOCAL) - latest_local_ctn = chan.get_latest_ctn(LOCAL) - next_local_ctn = chan.get_next_ctn(LOCAL) - oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE) - latest_remote_ctn = chan.get_latest_ctn(REMOTE) - next_remote_ctn = chan.get_next_ctn(REMOTE) - assert self.features.supports(LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT) - # send message - if chan.is_static_remotekey_enabled(): - latest_secret, latest_point = chan.get_secret_and_point(LOCAL, 0) - else: - latest_secret, latest_point = chan.get_secret_and_point(LOCAL, latest_local_ctn) - if oldest_unrevoked_remote_ctn == 0: - last_rev_secret = 0 - else: - last_rev_index = oldest_unrevoked_remote_ctn - 1 - last_rev_secret = chan.revocation_store.retrieve_secret(RevocationStore.START_INDEX - last_rev_index) - self.send_message( - "channel_reestablish", - channel_id=chan_id, - next_commitment_number=next_local_ctn, - next_revocation_number=oldest_unrevoked_remote_ctn, - your_last_per_commitment_secret=last_rev_secret, - my_current_per_commitment_point=latest_point) - self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): sent channel_reestablish with ' - f'(next_local_ctn={next_local_ctn}, ' - f'oldest_unrevoked_remote_ctn={oldest_unrevoked_remote_ctn})') - while True: - try: - msg = await self.wait_for_message('channel_reestablish', chan_id) - break - except asyncio.TimeoutError: - self.logger.info('waiting to receive channel_reestablish...') - continue - except Exception as e: - # do not kill connection, because we might have other channels with that peer - self.logger.info(f'channel_reestablish failed, {e}') - return + def on_channel_reestablish(self, chan, msg): their_next_local_ctn = msg["next_commitment_number"] their_oldest_unrevoked_remote_ctn = msg["next_revocation_number"] their_local_pcp = msg.get("my_current_per_commitment_point") their_claim_of_our_last_per_commitment_secret = msg.get("your_last_per_commitment_secret") - self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): received channel_reestablish with ' - f'(their_next_local_ctn={their_next_local_ctn}, ' - f'their_oldest_unrevoked_remote_ctn={their_oldest_unrevoked_remote_ctn})') + self.logger.info( + f'channel_reestablish ({chan.get_id_for_log()}): received channel_reestablish with ' + f'(their_next_local_ctn={their_next_local_ctn}, ' + f'their_oldest_unrevoked_remote_ctn={their_oldest_unrevoked_remote_ctn})') # sanity checks of received values if their_next_local_ctn < 0: raise RemoteMisbehaving(f"channel reestablish: their_next_local_ctn < 0") if their_oldest_unrevoked_remote_ctn < 0: raise RemoteMisbehaving(f"channel reestablish: their_oldest_unrevoked_remote_ctn < 0") - # Replay un-acked local updates (including commitment_signed) byte-for-byte. - # If we have sent them a commitment signature that they "lost" (due to disconnect), - # we need to make sure we replay the same local updates, as otherwise they could - # end up with two (or more) signed valid commitment transactions at the same ctn. - # Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns, - # e.g. for watchtowers, hence we must ensure these ctxs coincide. - # We replay the local updates even if they were not yet committed. - unacked = chan.hm.get_unacked_local_updates() - n_replayed_msgs = 0 - for ctn, messages in unacked.items(): - if ctn < their_next_local_ctn: - # They claim to have received these messages and the corresponding - # commitment_signed, hence we must not replay them. - continue - for raw_upd_msg in messages: - self.transport.send_bytes(raw_upd_msg) - n_replayed_msgs += 1 - self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): replayed {n_replayed_msgs} unacked messages') - + # ctns + oldest_unrevoked_local_ctn = chan.get_oldest_unrevoked_ctn(LOCAL) + latest_local_ctn = chan.get_latest_ctn(LOCAL) + next_local_ctn = chan.get_next_ctn(LOCAL) + oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE) + latest_remote_ctn = chan.get_latest_ctn(REMOTE) + next_remote_ctn = chan.get_next_ctn(REMOTE) + # compare remote ctns we_are_ahead = False they_are_ahead = False - # compare remote ctns + we_must_resend_revoke_and_ack = False if next_remote_ctn != their_next_local_ctn: if their_next_local_ctn == latest_remote_ctn and chan.hm.is_revack_pending(REMOTE): - # We replayed the local updates (see above), which should have contained a commitment_signed - # (due to is_revack_pending being true), and this should have remedied this situation. + # We will replay the local updates (see reestablish_channel), which should contain a commitment_signed + # (due to is_revack_pending being true), and this should remedy this situation. pass else: - self.logger.warning(f"channel_reestablish ({chan.get_id_for_log()}): " - f"expected remote ctn {next_remote_ctn}, got {their_next_local_ctn}") + self.logger.warning( + f"channel_reestablish ({chan.get_id_for_log()}): " + f"expected remote ctn {next_remote_ctn}, got {their_next_local_ctn}") if their_next_local_ctn < next_remote_ctn: we_are_ahead = True else: @@ -1011,25 +949,17 @@ class Peer(Logger): # compare local ctns if oldest_unrevoked_local_ctn != their_oldest_unrevoked_remote_ctn: if oldest_unrevoked_local_ctn - 1 == their_oldest_unrevoked_remote_ctn: - # A node: - # if next_revocation_number is equal to the commitment number of the last revoke_and_ack - # the receiving node sent, AND the receiving node hasn't already received a closing_signed: - # MUST re-send the revoke_and_ack. - last_secret, last_point = chan.get_secret_and_point(LOCAL, oldest_unrevoked_local_ctn - 1) - next_secret, next_point = chan.get_secret_and_point(LOCAL, oldest_unrevoked_local_ctn + 1) - self.send_message( - "revoke_and_ack", - channel_id=chan.channel_id, - per_commitment_secret=last_secret, - next_per_commitment_point=next_point) + we_must_resend_revoke_and_ack = True else: - self.logger.warning(f"channel_reestablish ({chan.get_id_for_log()}): " - f"expected local ctn {oldest_unrevoked_local_ctn}, got {their_oldest_unrevoked_remote_ctn}") + self.logger.warning( + f"channel_reestablish ({chan.get_id_for_log()}): " + f"expected local ctn {oldest_unrevoked_local_ctn}, got {their_oldest_unrevoked_remote_ctn}") if their_oldest_unrevoked_remote_ctn < oldest_unrevoked_local_ctn: we_are_ahead = True else: they_are_ahead = True # option_data_loss_protect + assert self.features.supports(LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT) def are_datalossprotect_fields_valid() -> bool: if their_local_pcp is None or their_claim_of_our_last_per_commitment_secret is None: return False @@ -1039,8 +969,9 @@ class Peer(Logger): assert their_oldest_unrevoked_remote_ctn == 0 our_pcs = bytes(32) if our_pcs != their_claim_of_our_last_per_commitment_secret: - self.logger.error(f"channel_reestablish ({chan.get_id_for_log()}): " - f"(DLP) local PCS mismatch: {bh2u(our_pcs)} != {bh2u(their_claim_of_our_last_per_commitment_secret)}") + self.logger.error( + f"channel_reestablish ({chan.get_id_for_log()}): " + f"(DLP) local PCS mismatch: {bh2u(our_pcs)} != {bh2u(their_claim_of_our_last_per_commitment_secret)}") return False if chan.is_static_remotekey_enabled(): return True @@ -1050,27 +981,108 @@ class Peer(Logger): pass else: if our_remote_pcp != their_local_pcp: - self.logger.error(f"channel_reestablish ({chan.get_id_for_log()}): " - f"(DLP) remote PCP mismatch: {bh2u(our_remote_pcp)} != {bh2u(their_local_pcp)}") + self.logger.error( + f"channel_reestablish ({chan.get_id_for_log()}): " + f"(DLP) remote PCP mismatch: {bh2u(our_remote_pcp)} != {bh2u(their_local_pcp)}") return False return True - if not are_datalossprotect_fields_valid(): raise RemoteMisbehaving("channel_reestablish: data loss protect fields invalid") - if they_are_ahead: - self.logger.warning(f"channel_reestablish ({chan.get_id_for_log()}): " - f"remote is ahead of us! They should force-close. Remote PCP: {bh2u(their_local_pcp)}") + self.logger.warning( + f"channel_reestablish ({chan.get_id_for_log()}): " + f"remote is ahead of us! They should force-close. Remote PCP: {bh2u(their_local_pcp)}") # data_loss_protect_remote_pcp is used in lnsweep chan.set_data_loss_protect_remote_pcp(their_next_local_ctn - 1, their_local_pcp) self.lnworker.save_channel(chan) chan.peer_state = PeerState.BAD - return - elif we_are_ahead: + raise RemoteMisbehaving("remote ahead of us") + if we_are_ahead: self.logger.warning(f"channel_reestablish ({chan.get_id_for_log()}): we are ahead of remote! trying to force-close.") - await self.lnworker.try_force_closing(chan_id) + asyncio.ensure_future(self.lnworker.try_force_closing(chan.channel_id)) + raise RemoteMisbehaving("we are ahead of remote") + # all good + self.channel_reestablish_msg[chan.channel_id].put_nowait((we_must_resend_revoke_and_ack, their_next_local_ctn)) + + async def reestablish_channel(self, chan: Channel): + await self.initialized + chan_id = chan.channel_id + if chan.should_request_force_close: + await self.trigger_force_close(chan_id) + chan.should_request_force_close = False return + assert ChannelState.PREOPENING < chan.get_state() < ChannelState.FORCE_CLOSING + if chan.peer_state != PeerState.DISCONNECTED: + self.logger.info( + f'reestablish_channel was called but channel {chan.get_id_for_log()} ' + f'already in peer_state {chan.peer_state!r}') + return + chan.peer_state = PeerState.REESTABLISHING + util.trigger_callback('channel', self.lnworker.wallet, chan) + # ctns + oldest_unrevoked_local_ctn = chan.get_oldest_unrevoked_ctn(LOCAL) + latest_local_ctn = chan.get_latest_ctn(LOCAL) + next_local_ctn = chan.get_next_ctn(LOCAL) + oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE) + latest_remote_ctn = chan.get_latest_ctn(REMOTE) + next_remote_ctn = chan.get_next_ctn(REMOTE) + # BOLT-02: "A node [...] upon disconnection [...] MUST reverse any uncommitted updates sent by the other side" + chan.hm.discard_unsigned_remote_updates() + # send message + if chan.is_static_remotekey_enabled(): + latest_secret, latest_point = chan.get_secret_and_point(LOCAL, 0) + else: + latest_secret, latest_point = chan.get_secret_and_point(LOCAL, latest_local_ctn) + if oldest_unrevoked_remote_ctn == 0: + last_rev_secret = 0 + else: + last_rev_index = oldest_unrevoked_remote_ctn - 1 + last_rev_secret = chan.revocation_store.retrieve_secret(RevocationStore.START_INDEX - last_rev_index) + self.send_message( + "channel_reestablish", + channel_id=chan_id, + next_commitment_number=next_local_ctn, + next_revocation_number=oldest_unrevoked_remote_ctn, + your_last_per_commitment_secret=last_rev_secret, + my_current_per_commitment_point=latest_point) + self.logger.info( + f'channel_reestablish ({chan.get_id_for_log()}): sent channel_reestablish with ' + f'(next_local_ctn={next_local_ctn}, ' + f'oldest_unrevoked_remote_ctn={oldest_unrevoked_remote_ctn})') + + # wait until we receive their channel_reestablish + we_must_resend_revoke_and_ack, their_next_local_ctn = await self.channel_reestablish_msg[chan_id].get() + # Replay un-acked local updates (including commitment_signed) byte-for-byte. + # If we have sent them a commitment signature that they "lost" (due to disconnect), + # we need to make sure we replay the same local updates, as otherwise they could + # end up with two (or more) signed valid commitment transactions at the same ctn. + # Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns, + # e.g. for watchtowers, hence we must ensure these ctxs coincide. + # We replay the local updates even if they were not yet committed. + unacked = chan.hm.get_unacked_local_updates() + n_replayed_msgs = 0 + for ctn, messages in unacked.items(): + if ctn < their_next_local_ctn: + # They claim to have received these messages and the corresponding + # commitment_signed, hence we must not replay them. + continue + for raw_upd_msg in messages: + self.transport.send_bytes(raw_upd_msg) + n_replayed_msgs += 1 + self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): replayed {n_replayed_msgs} unacked messages') + # A node: + # if next_revocation_number is equal to the commitment number of the last revoke_and_ack + # the receiving node sent, AND the receiving node hasn't already received a closing_signed: + # MUST re-send the revoke_and_ack. + if we_must_resend_revoke_and_ack: + last_secret, last_point = chan.get_secret_and_point(LOCAL, oldest_unrevoked_local_ctn - 1) + next_secret, next_point = chan.get_secret_and_point(LOCAL, oldest_unrevoked_local_ctn + 1) + self.send_message( + "revoke_and_ack", + channel_id=chan.channel_id, + per_commitment_secret=last_secret, + next_per_commitment_point=next_point) chan.peer_state = PeerState.GOOD if chan.is_funded() and their_next_local_ctn == next_local_ctn == 1: self.send_funding_locked(chan) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 9a5f1d540..052f695f8 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -2283,16 +2283,10 @@ class LNWallet(LNWorker): peer_addr = LNPeerAddr(host, port, node_id) transport = LNTransport(privkey, peer_addr, proxy=self.network.proxy) peer = Peer(self, node_id, transport, is_channel_backup=True) - async def trigger_force_close_and_wait(): - # wait before closing the transport, so that - # remote has time to process the message. - await peer.trigger_force_close(channel_id) - await asyncio.sleep(1) - peer.transport.close() try: async with OldTaskGroup(wait=any) as group: await group.spawn(peer._message_loop()) - await group.spawn(trigger_force_close_and_wait()) + await group.spawn(peer.trigger_force_close(channel_id)) return except Exception as e: self.logger.info(f'failed to connect {host} {e}') diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index a3667342d..7d1f5df81 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -540,7 +540,6 @@ class TestPeer(TestCaseForTestnet): await gath with self.assertRaises(concurrent.futures.CancelledError): run(f()) - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel_0, bob_channel) for chan in (alice_channel_0, bob_channel): chan.peer_state = PeerState.DISCONNECTED @@ -548,16 +547,13 @@ class TestPeer(TestCaseForTestnet): await asyncio.gather( p1.reestablish_channel(alice_channel_0), p2.reestablish_channel(bob_channel)) - self.assertEqual(alice_channel_0.peer_state, PeerState.BAD) - self.assertEqual(bob_channel._state, ChannelState.FORCE_CLOSING) - # wait so that pending messages are processed - #await asyncio.sleep(1) - gath.cancel() gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) async def f(): await gath - with self.assertRaises(concurrent.futures.CancelledError): + with self.assertRaises(electrum.lnutil.RemoteMisbehaving): run(f()) + self.assertEqual(alice_channel_0.peer_state, PeerState.BAD) + self.assertEqual(bob_channel._state, ChannelState.FORCE_CLOSING) @needs_test_with_all_chacha20_implementations def test_payment(self):