diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 8d32372c5..307a1f7ff 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -98,7 +98,7 @@ class Peer(Logger): self.reply_channel_range = asyncio.Queue() # gossip uses a single queue to preserve message order self.gossip_queue = asyncio.Queue() - self.ordered_message_queues = defaultdict(asyncio.Queue) # for messsage that are ordered + self.ordered_message_queues = defaultdict(asyncio.Queue) # for messages that are ordered self.temp_id_to_id = {} # to forward error messages self.funding_created_sent = set() # for channels in PREOPENING self.funding_signed_sent = set() # for channels in PREOPENING @@ -205,7 +205,7 @@ class Peer(Logger): chan_id = payload.get('channel_id') or payload["temporary_channel_id"] self.ordered_message_queues[chan_id].put_nowait((message_type, payload)) else: - if message_type != 'error' and 'channel_id' in payload: + if message_type not in ('error', 'warning') and 'channel_id' in payload: chan = self.get_channel_by_id(payload['channel_id']) if chan is None: raise Exception('Got unknown '+ message_type) @@ -224,12 +224,96 @@ class Peer(Logger): if asyncio.iscoroutinefunction(f): asyncio.ensure_future(self.taskgroup.spawn(execution_result)) + def on_warning(self, payload): + # TODO: we could need some reconnection logic here -> delayed reconnect + self.logger.info(f"remote peer sent warning [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}") + channel_id = payload.get("channel_id") + if channel_id == bytes(32): + for cid in self.channels.keys(): + self.ordered_message_queues[cid].put_nowait((None, {'warning': payload['data']})) + raise GracefulDisconnect + warned_channel_id = None + if channel_id in self.temp_id_to_id: + warned_channel_id = self.temp_id_to_id[channel_id] + elif channel_id in self.channels: + warned_channel_id = channel_id + if warned_channel_id: + # MAY disconnect. + self.ordered_message_queues[warned_channel_id].put_nowait((None, {'warning': payload['data']})) + raise GracefulDisconnect + def on_error(self, payload): self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}") - chan_id = payload.get("channel_id") - if chan_id in self.temp_id_to_id: - chan_id = self.temp_id_to_id[chan_id] - self.ordered_message_queues[chan_id].put_nowait((None, {'error':payload['data']})) + channel_id = payload.get("channel_id") + # if channel_id is all zero: MUST fail all channels with the sending node. + if channel_id == bytes(32): + for cid in self.channels.keys(): + self.schedule_force_closing(cid) + self.ordered_message_queues[cid].put_nowait((None, {'error': payload['data']})) + raise GracefulDisconnect + # otherwise: MUST fail the channel referred to by channel_id, if that channel is with the sending node. + erring_channel_id = None + if channel_id in self.temp_id_to_id: + erring_channel_id = self.temp_id_to_id[channel_id] + elif channel_id in self.channels: + erring_channel_id = channel_id + if erring_channel_id: + self.schedule_force_closing(erring_channel_id) + self.ordered_message_queues[erring_channel_id].put_nowait((None, {'error': payload['data']})) + # disconnect now as there might be no one waiting on the queue... + # OTOH this means if there are waiters, they might not see the error + raise GracefulDisconnect + + async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=True): + """Sends a warning and disconnects if close_connection. + + Note: + * channel_id is the temporary channel id when the channel id is not yet available + + A sending node: + MAY set channel_id to all zero if the warning is not related to a specific channel. + + when failure was caused by an invalid signature check: + * SHOULD include the raw, hex-encoded transaction in reply to a funding_created, + funding_signed, closing_signed, or commitment_signed message. + """ + assert isinstance(channel_id, bytes) + encoded_data = b'' if not message else message.encode('ascii') + self.send_message('warning', channel_id=channel_id, data=encoded_data, len=len(encoded_data)) + if close_connection: + raise GracefulDisconnect + + async def send_error(self, channel_id: bytes, message: str = None, *, force_close_channel=False): + """Sends an error message and force closes the channel. + + Note: + * channel_id is the temporary channel id when the channel id is not yet available + + A sending node: + * SHOULD send error for protocol violations or internal errors that make channels + unusable or that make further communication unusable. + * SHOULD send error with the unknown channel_id in reply to messages of type + 32-255 related to unknown channels. + * MUST fail the channel(s) referred to by the error message. + * MAY set channel_id to all zero to indicate all channels. + + when failure was caused by an invalid signature check: + * SHOULD include the raw, hex-encoded transaction in reply to a funding_created, + funding_signed, closing_signed, or commitment_signed message. + """ + assert isinstance(channel_id, bytes) + encoded_data = b'' if not message else message.encode('ascii') + self.send_message('error', channel_id=channel_id, data=encoded_data, len=len(encoded_data)) + # MUST fail the channel(s) referred to by the error message: + # we may violate this with force_close_channel + if force_close_channel: + # channel_id of zero means that the error refers to all channels + if channel_id == bytes(32): + for channel_id in self.channels: + self.schedule_force_closing(channel_id) + else: + self.schedule_force_closing(channel_id) + raise GracefulDisconnect def on_ping(self, payload): l = payload['num_pong_bytes'] @@ -242,7 +326,9 @@ class Peer(Logger): q = self.ordered_message_queues[channel_id] name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT) if payload.get('error'): - raise Exception('Remote peer reported error [DO NOT TRUST THIS MESSAGE]: ' + repr(payload.get('error'))) + raise GracefulDisconnect(f'Waiting for {expected_name} failed due to an error sent by the peer.') + elif payload.get('warning'): + raise GracefulDisconnect(f'Waiting for {expected_name} failed due to a warning sent by the peer.') if name != expected_name: raise Exception(f"Received unexpected '{name}'") return payload @@ -956,6 +1042,13 @@ class Peer(Logger): your_last_per_commitment_secret=0, my_current_per_commitment_point=latest_point) + def schedule_force_closing(self, channel_id: bytes): + channels_with_peer = list(self.channels.keys()) + channels_with_peer.extend(self.temp_id_to_id.values()) + if channel_id not in channels_with_peer: + raise ValueError(f"channel {channel_id.hex()} does not belong to this peer") + self.lnworker.schedule_force_closing(channel_id) + def on_channel_reestablish(self, chan, msg): their_next_local_ctn = msg["next_commitment_number"] their_oldest_unrevoked_remote_ctn = msg["next_revocation_number"] diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 7cc0716d8..04f3e3f09 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -38,6 +38,7 @@ from electrum.lnonion import OnionFailureCode from electrum.lnutil import derive_payment_secret_from_payment_preimage from electrum.lnutil import LOCAL, REMOTE from electrum.invoices import PR_PAID, PR_UNPAID +from electrum.interface import GracefulDisconnect from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -1129,6 +1130,38 @@ class TestPeer(TestCaseForTestnet): with self.assertRaises(concurrent.futures.CancelledError): run(f()) + @needs_test_with_all_chacha20_implementations + def test_warning(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def action(): + await asyncio.wait_for(p1.initialized, 1) + await asyncio.wait_for(p2.initialized, 1) + await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True) + gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + async def f(): + await gath + with self.assertRaises(GracefulDisconnect): + run(f()) + + @needs_test_with_all_chacha20_implementations + def test_error(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def action(): + await asyncio.wait_for(p1.initialized, 1) + await asyncio.wait_for(p2.initialized, 1) + await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True) + assert alice_channel.is_closed() + gath.cancel() + gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + async def f(): + await gath + with self.assertRaises(GracefulDisconnect): + run(f()) + @needs_test_with_all_chacha20_implementations def test_close_upfront_shutdown_script(self): alice_channel, bob_channel = create_test_channels()