diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 11bf48ef7..790f24d5b 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -60,6 +60,12 @@ if TYPE_CHECKING: LN_P2P_NETWORK_TIMEOUT = 20 +def channel_update(func): + def wrapper(peer, payload): + channel_id = payload["channel_id"] + chan = peer.channels[channel_id] + return func(peer, chan, payload) + return wrapper def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]: funding_txid_bytes = bytes.fromhex(funding_txid)[::-1] @@ -1004,11 +1010,10 @@ class Peer(Logger): return msg_hash, node_signature, bitcoin_signature - def on_update_fail_htlc(self, payload): - channel_id = payload["channel_id"] + @channel_update + def on_update_fail_htlc(self, chan, payload): htlc_id = int.from_bytes(payload["id"], "big") reason = payload["reason"] - chan = self.channels[channel_id] self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") chan.receive_fail_htlc(htlc_id, reason) self.maybe_send_commitment(chan) @@ -1081,23 +1086,23 @@ class Peer(Logger): chan.receive_new_commitment(payload["signature"], htlc_sigs) self.send_revoke_and_ack(chan) - def on_update_fulfill_htlc(self, update_fulfill_htlc_msg): - chan = self.channels[update_fulfill_htlc_msg["channel_id"]] - preimage = update_fulfill_htlc_msg["payment_preimage"] + @channel_update + def on_update_fulfill_htlc(self, chan, payload): + preimage = payload["payment_preimage"] payment_hash = sha256(preimage) - htlc_id = int.from_bytes(update_fulfill_htlc_msg["id"], "big") + htlc_id = int.from_bytes(payload["id"], "big") self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") chan.receive_htlc_settle(preimage, htlc_id) self.lnworker.save_preimage(payment_hash, preimage) self.maybe_send_commitment(chan) - def on_update_fail_malformed_htlc(self, payload): + @channel_update + def on_update_fail_malformed_htlc(self, chan, payload): self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}") - def on_update_add_htlc(self, payload): + @channel_update + def on_update_add_htlc(self, chan, payload): payment_hash = payload["payment_hash"] - channel_id = payload['channel_id'] - chan = self.channels[channel_id] htlc_id = int.from_bytes(payload["id"], 'big') self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") cltv_expiry = int.from_bytes(payload["cltv_expiry"], 'big') @@ -1106,7 +1111,7 @@ class Peer(Logger): if chan.get_state() != channel_states.OPEN: raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()}") if cltv_expiry > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX: - asyncio.ensure_future(self.lnworker.try_force_closing(channel_id)) + asyncio.ensure_future(self.lnworker.try_force_closing(chan.channel_id)) raise RemoteMisbehaving(f"received update_add_htlc with cltv_expiry > BLOCKHEIGHT_MAX. value was {cltv_expiry}") # add htlc htlc = UpdateAddHtlc( @@ -1243,10 +1248,9 @@ class Peer(Logger): self.lnworker.save_channel(chan) self.maybe_send_commitment(chan) - def on_update_fee(self, payload): - channel_id = payload["channel_id"] - feerate =int.from_bytes(payload["feerate_per_kw"], "big") - chan = self.channels[channel_id] + @channel_update + def on_update_fee(self, chan, payload): + feerate = int.from_bytes(payload["feerate_per_kw"], "big") chan.update_fee(feerate, False) async def maybe_update_fee(self, chan: Channel):