From 47917d9e6cee0c5e46f0e89971cb310b8c6dc29e Mon Sep 17 00:00:00 2001 From: ThomasV Date: Wed, 9 Mar 2022 13:40:39 +0100 Subject: [PATCH] lnpeer: factorize on_warning/on_error code --- electrum/lnpeer.py | 53 +++++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index b9432fda8..21a525d45 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -224,44 +224,35 @@ class Peer(Logger): if asyncio.iscoroutinefunction(f): asyncio.ensure_future(self.taskgroup.spawn(execution_result)) + def _get_channel_ids(self, channel_id): + # if channel_id is all zero: MUST fail all channels with the sending node. + # otherwise: MUST fail the channel referred to by channel_id, if that channel is with the sending node. + # if no existing channel is referred to by `channel_id: MUST ignore the message. + if channel_id == bytes(32): + return self.channels.keys() + elif channel_id in self.temp_id_to_id: + return [self.temp_id_to_id[channel_id]] + elif channel_id in self.channels: + return [channel_id] + else: + return [] + def on_warning(self, payload): # TODO: we could need some reconnection logic here -> delayed reconnect self.logger.info(f"remote peer sent warning [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}") - channel_id = payload.get("channel_id") - if channel_id == bytes(32): - for cid in self.channels.keys(): - self.ordered_message_queues[cid].put_nowait((None, {'warning': payload['data']})) - raise GracefulDisconnect - warned_channel_id = None - if channel_id in self.temp_id_to_id: - warned_channel_id = self.temp_id_to_id[channel_id] - elif channel_id in self.channels: - warned_channel_id = channel_id - if warned_channel_id: - # MAY disconnect. - self.ordered_message_queues[warned_channel_id].put_nowait((None, {'warning': payload['data']})) + channel_ids = self._get_channel_ids(payload.get("channel_id")) + for cid in channel_ids: + self.ordered_message_queues[cid].put_nowait((None, {'warning': payload['data']})) + if channel_ids: raise GracefulDisconnect def on_error(self, payload): self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}") - channel_id = payload.get("channel_id") - # if channel_id is all zero: MUST fail all channels with the sending node. - if channel_id == bytes(32): - for cid in self.channels.keys(): - self.schedule_force_closing(cid) - self.ordered_message_queues[cid].put_nowait((None, {'error': payload['data']})) - raise GracefulDisconnect - # otherwise: MUST fail the channel referred to by channel_id, if that channel is with the sending node. - erring_channel_id = None - if channel_id in self.temp_id_to_id: - erring_channel_id = self.temp_id_to_id[channel_id] - elif channel_id in self.channels: - erring_channel_id = channel_id - if erring_channel_id: - self.schedule_force_closing(erring_channel_id) - self.ordered_message_queues[erring_channel_id].put_nowait((None, {'error': payload['data']})) - # disconnect now as there might be no one waiting on the queue... - # OTOH this means if there are waiters, they might not see the error + channel_ids = self._get_channel_ids(payload.get("channel_id")) + for cid in channel_ids: + self.schedule_force_closing(cid) + self.ordered_message_queues[cid].put_nowait((None, {'error': payload['data']})) + if channel_ids: raise GracefulDisconnect async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=True):