diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index bf7a3d199..e71482642 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -28,6 +28,7 @@ class HTLCManager: log[LOCAL] = deepcopy(initial) log[REMOTE] = deepcopy(initial) log[LOCAL]['unacked_updates'] = {} + log[LOCAL]['was_revoke_last'] = False # maybe bootstrap fee_updates if initial_feerate was provided if initial_feerate is not None: @@ -155,6 +156,7 @@ class HTLCManager: def send_ctx(self) -> None: assert self.ctn_latest(REMOTE) == self.ctn_oldest_unrevoked(REMOTE), (self.ctn_latest(REMOTE), self.ctn_oldest_unrevoked(REMOTE)) self._set_revack_pending(REMOTE, True) + self.log[LOCAL]['was_revoke_last'] = False @with_lock def recv_ctx(self) -> None: @@ -165,6 +167,7 @@ class HTLCManager: def send_rev(self) -> None: self.log[LOCAL]['ctn'] += 1 self._set_revack_pending(LOCAL, False) + self.log[LOCAL]['was_revoke_last'] = True # htlcs for htlc_id in self._maybe_active_htlc_ids[REMOTE]: ctns = self.log[REMOTE]['locked_in'][htlc_id] @@ -287,6 +290,11 @@ class HTLCManager: return {ctn: [bfh(msg) for msg in messages] for ctn, messages in self.log[LOCAL]['unacked_updates'].items()} + @with_lock + def was_revoke_last(self) -> bool: + """Whether we sent a revoke_and_ack after the last commitment_signed we sent.""" + return self.log[LOCAL].get('was_revoke_last') or False + ##### Queries re HTLCs: def get_htlc_by_id(self, htlc_proposer: HTLCOwner, htlc_id: int) -> UpdateAddHtlc: diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index a5678574e..0366fcdc5 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -115,6 +115,7 @@ class Peer(Logger): self._htlc_switch_iterstart_event = asyncio.Event() self._htlc_switch_iterdone_event = asyncio.Event() self._received_revack_event = asyncio.Event() + self.received_commitsig_event = asyncio.Event() self.downstream_htlc_resolved_event = asyncio.Event() def send_message(self, message_name: str, **kwargs): @@ -1221,25 +1222,28 @@ class Peer(Logger): await fut we_must_resend_revoke_and_ack, their_next_local_ctn = fut.result() - # Replay un-acked local updates (including commitment_signed) byte-for-byte. - # If we have sent them a commitment signature that they "lost" (due to disconnect), - # we need to make sure we replay the same local updates, as otherwise they could - # end up with two (or more) signed valid commitment transactions at the same ctn. - # Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns, - # e.g. for watchtowers, hence we must ensure these ctxs coincide. - # We replay the local updates even if they were not yet committed. - unacked = chan.hm.get_unacked_local_updates() - n_replayed_msgs = 0 - for ctn, messages in unacked.items(): - if ctn < their_next_local_ctn: - # They claim to have received these messages and the corresponding - # commitment_signed, hence we must not replay them. - continue - for raw_upd_msg in messages: - self.transport.send_bytes(raw_upd_msg) - n_replayed_msgs += 1 - self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): replayed {n_replayed_msgs} unacked messages') - if we_must_resend_revoke_and_ack: + def replay_updates_and_commitsig(): + # Replay un-acked local updates (including commitment_signed) byte-for-byte. + # If we have sent them a commitment signature that they "lost" (due to disconnect), + # we need to make sure we replay the same local updates, as otherwise they could + # end up with two (or more) signed valid commitment transactions at the same ctn. + # Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns, + # e.g. for watchtowers, hence we must ensure these ctxs coincide. + # We replay the local updates even if they were not yet committed. + unacked = chan.hm.get_unacked_local_updates() + replayed_msgs = [] + for ctn, messages in unacked.items(): + if ctn < their_next_local_ctn: + # They claim to have received these messages and the corresponding + # commitment_signed, hence we must not replay them. + continue + for raw_upd_msg in messages: + self.transport.send_bytes(raw_upd_msg) + replayed_msgs.append(raw_upd_msg) + self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): replayed {len(replayed_msgs)} unacked messages. ' + f'{[decode_msg(raw_upd_msg)[0] for raw_upd_msg in replayed_msgs]}') + + def resend_revoke_and_ack(): last_secret, last_point = chan.get_secret_and_point(LOCAL, oldest_unrevoked_local_ctn - 1) next_secret, next_point = chan.get_secret_and_point(LOCAL, oldest_unrevoked_local_ctn + 1) self.send_message( @@ -1247,6 +1251,21 @@ class Peer(Logger): channel_id=chan.channel_id, per_commitment_secret=last_secret, next_per_commitment_point=next_point) + + # We need to preserve relative order of last revack and commitsig. + # note: it is not possible to recover and reestablish a channel if we are out-of-sync by + # more than one ctns, i.e. we will only ever retransmit up to one commitment_signed message. + # Hence, if we need to retransmit a revack, without loss of generality, we can either replay + # it as the first message or as the last message. + was_revoke_last = chan.hm.was_revoke_last() + if we_must_resend_revoke_and_ack and not was_revoke_last: + self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): replaying a revoke_and_ack first.') + resend_revoke_and_ack() + replay_updates_and_commitsig() + if we_must_resend_revoke_and_ack and was_revoke_last: + self.logger.info(f'channel_reestablish ({chan.get_id_for_log()}): replaying a revoke_and_ack last.') + resend_revoke_and_ack() + chan.peer_state = PeerState.GOOD if chan.is_funded() and their_next_local_ctn == next_local_ctn == 1: self.send_funding_locked(chan) @@ -1478,6 +1497,8 @@ class Peer(Logger): htlc_sigs = list(chunks(data, 64)) chan.receive_new_commitment(payload["signature"], htlc_sigs) self.send_revoke_and_ack(chan) + self.received_commitsig_event.set() + self.received_commitsig_event.clear() def on_update_fulfill_htlc(self, chan: Channel, payload): preimage = payload["payment_preimage"] diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 2aceae9ff..23f367bf2 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -35,7 +35,7 @@ from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED from electrum.lnonion import OnionFailureCode -from electrum.lnutil import derive_payment_secret_from_payment_preimage +from electrum.lnutil import derive_payment_secret_from_payment_preimage, UpdateAddHtlc from electrum.lnutil import LOCAL, REMOTE from electrum.invoices import PR_PAID, PR_UNPAID from electrum.interface import GracefulDisconnect @@ -256,7 +256,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): class MockTransport: def __init__(self, name): - self.queue = asyncio.Queue() + self.queue = asyncio.Queue() # incoming messages self._name = name self.peer_addr = None @@ -265,7 +265,11 @@ class MockTransport: async def read_messages(self): while True: - yield await self.queue.get() + data = await self.queue.get() + if isinstance(data, asyncio.Event): # to artificially delay messages + await data.wait() + continue + yield data class NoFeaturesTransport(MockTransport): """ @@ -382,8 +386,14 @@ class TestPeer(TestCaseForTestnet): super().tearDown() - def prepare_peers(self, alice_channel: Channel, bob_channel: Channel): - k1, k2 = keypair(), keypair() + def prepare_peers( + self, alice_channel: Channel, bob_channel: Channel, + *, k1: Keypair = None, k2: Keypair = None, + ): + if k1 is None: + k1 = keypair() + if k2 is None: + k2 = keypair() alice_channel.node_id = k2.pubkey bob_channel.node_id = k1.pubkey t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name) @@ -557,6 +567,130 @@ class TestPeer(TestCaseForTestnet): self.assertEqual(alice_channel_0.peer_state, PeerState.BAD) self.assertEqual(bob_channel._state, ChannelState.FORCE_CLOSING) + @staticmethod + def _send_fake_htlc(peer: Peer, chan: Channel) -> UpdateAddHtlc: + htlc = UpdateAddHtlc(amount_msat=10000, payment_hash=os.urandom(32), cltv_expiry=999, timestamp=1) + htlc = chan.add_htlc(htlc) + peer.send_message( + "update_add_htlc", + channel_id=chan.channel_id, + id=htlc.htlc_id, + cltv_expiry=htlc.cltv_expiry, + amount_msat=htlc.amount_msat, + payment_hash=htlc.payment_hash, + onion_routing_packet=1366 * b"0", + ) + return htlc + + def test_reestablish_replay_messages_rev_then_sig(self): + """ + See https://github.com/lightning/bolts/pull/810#issue-728299277 + + Rev then Sig + A B + <---add----- + ----add----> + <---sig----- + ----rev----x + ----sig----x + + A needs to retransmit: + ----rev--> (note that 'add' can be first too) + ----add--> + ----sig--> + """ + chan_AB, chan_BA = create_test_channels() + k1, k2 = keypair(), keypair() + # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. + async def f(): + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p2._message_loop()) + await p1.initialized + await p2.initialized + self._send_fake_htlc(p2, chan_BA) + self._send_fake_htlc(p1, chan_AB) + p2.transport.queue.put_nowait(asyncio.Event()) # break Bob's incoming pipe + self.assertTrue(p2.maybe_send_commitment(chan_BA)) + await p1.received_commitsig_event.wait() + await group.cancel_remaining() + # simulating disconnection. recreate transports. + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + for chan in (chan_AB, chan_BA): + chan.peer_state = PeerState.DISCONNECTED + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p2._message_loop()) + with self.assertLogs('electrum', level='INFO') as logs: + async with OldTaskGroup() as group2: + await group2.spawn(p1.reestablish_channel(chan_AB)) + await group2.spawn(p2.reestablish_channel(chan_BA)) + self.assertTrue(any(("alice->bob" in msg and + "replaying a revoke_and_ack first" in msg) for msg in logs.output)) + self.assertTrue(any(("alice->bob" in msg and + "replayed 2 unacked messages. ['update_add_htlc', 'commitment_signed']" in msg) for msg in logs.output)) + self.assertEqual(chan_AB.peer_state, PeerState.GOOD) + self.assertEqual(chan_BA.peer_state, PeerState.GOOD) + raise SuccessfulTest() + with self.assertRaises(SuccessfulTest): + run(f()) + + def test_reestablish_replay_messages_sig_then_rev(self): + """ + See https://github.com/lightning/bolts/pull/810#issue-728299277 + + Sig then Rev + A B + <---add----- + ----add----> + ----sig----x + <---sig----- + ----rev----x + + A needs to retransmit: + ----add--> + ----sig--> + ----rev--> + """ + chan_AB, chan_BA = create_test_channels() + k1, k2 = keypair(), keypair() + # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. + async def f(): + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p2._message_loop()) + await p1.initialized + await p2.initialized + self._send_fake_htlc(p2, chan_BA) + self._send_fake_htlc(p1, chan_AB) + p2.transport.queue.put_nowait(asyncio.Event()) # break Bob's incoming pipe + self.assertTrue(p1.maybe_send_commitment(chan_AB)) + self.assertTrue(p2.maybe_send_commitment(chan_BA)) + await p1.received_commitsig_event.wait() + await group.cancel_remaining() + # simulating disconnection. recreate transports. + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + for chan in (chan_AB, chan_BA): + chan.peer_state = PeerState.DISCONNECTED + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p2._message_loop()) + with self.assertLogs('electrum', level='INFO') as logs: + async with OldTaskGroup() as group2: + await group2.spawn(p1.reestablish_channel(chan_AB)) + await group2.spawn(p2.reestablish_channel(chan_BA)) + self.assertTrue(any(("alice->bob" in msg and + "replaying a revoke_and_ack last" in msg) for msg in logs.output)) + self.assertTrue(any(("alice->bob" in msg and + "replayed 2 unacked messages. ['update_add_htlc', 'commitment_signed']" in msg) for msg in logs.output)) + self.assertEqual(chan_AB.peer_state, PeerState.GOOD) + self.assertEqual(chan_BA.peer_state, PeerState.GOOD) + raise SuccessfulTest() + with self.assertRaises(SuccessfulTest): + run(f()) + @needs_test_with_all_chacha20_implementations def test_payment(self): """Alice pays Bob a single HTLC via direct channel."""