diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 9ceed67f1..0dc31170e 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -241,22 +241,24 @@ class Channel(PrintError): script = funding_output_script(self.config[LOCAL], self.config[REMOTE]) return redeem_script_to_address('p2wsh', script) - def add_htlc(self, htlc): + def add_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: """ AddHTLC adds an HTLC to the state machine's local update log. This method should be called when preparing to send an outgoing HTLC. This docstring is from LND. """ - assert type(htlc) is dict - self._check_can_pay(htlc['amount_msat']) - htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id) + if isinstance(htlc, dict): # legacy conversion # FIXME remove + htlc = UpdateAddHtlc(**htlc) + assert isinstance(htlc, UpdateAddHtlc) + self._check_can_pay(htlc.amount_msat) + htlc = htlc._replace(htlc_id=self.config[LOCAL].next_htlc_id) self.hm.send_htlc(htlc) self.print_error("add_htlc") self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1) - return htlc.htlc_id + return htlc - def receive_htlc(self, htlc): + def receive_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: """ ReceiveHTLC adds an HTLC to the state machine's remote update log. This method should be called in response to receiving a new HTLC from the remote @@ -264,8 +266,10 @@ class Channel(PrintError): This docstring is from LND. """ - assert type(htlc) is dict - htlc = UpdateAddHtlc(**htlc, htlc_id = self.config[REMOTE].next_htlc_id) + if isinstance(htlc, dict): # legacy conversion # FIXME remove + htlc = UpdateAddHtlc(**htlc) + assert isinstance(htlc, UpdateAddHtlc) + htlc = htlc._replace(htlc_id=self.config[REMOTE].next_htlc_id) if self.available_to_spend(REMOTE) < htlc.amount_msat: raise RemoteMisbehaving('Remote dipped below channel reserve.' +\ f' Available at remote: {self.available_to_spend(REMOTE)},' +\ @@ -273,7 +277,7 @@ class Channel(PrintError): self.hm.recv_htlc(htlc) self.print_error("receive_htlc") self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1) - return htlc.htlc_id + return htlc def sign_next_commitment(self): """ diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 4803fb881..3c406e165 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -25,7 +25,8 @@ from . import constants from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions from .transaction import Transaction, TxOutput from .lnonion import (new_onion_packet, decode_onion_error, OnionFailureCode, calc_hops_data_for_payment, - process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage) + process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage, + ProcessedOnionPacket) from .lnchannel import Channel, RevokeAndAck, htlcsum from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore, @@ -841,7 +842,7 @@ class Peer(PrintError): await self._local_changed_events[chan.channel_id].wait() async def pay(self, route: List['RouteEdge'], chan: Channel, amount_msat: int, - payment_hash: bytes, min_final_cltv_expiry: int): + payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc: assert chan.get_state() == "OPEN", chan.get_state() assert amount_msat > 0, "amount_msat is not greater zero" # create onion packet @@ -851,22 +852,22 @@ class Peer(PrintError): secret_key = os.urandom(32) onion = new_onion_packet([x.node_id for x in route], secret_key, hops_data, associated_data=payment_hash) # create htlc - htlc = {'amount_msat':amount_msat, 'payment_hash':payment_hash, 'cltv_expiry':cltv} - htlc_id = chan.add_htlc(htlc) + htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv) + htlc = chan.add_htlc(htlc) remote_ctn = chan.get_current_ctn(REMOTE) - chan.onion_keys[htlc_id] = secret_key - self.attempted_route[(chan.channel_id, htlc_id)] = route + chan.onion_keys[htlc.htlc_id] = secret_key + self.attempted_route[(chan.channel_id, htlc.htlc_id)] = route self.print_error(f"starting payment. route: {route}") self.send_message("update_add_htlc", channel_id=chan.channel_id, - id=htlc_id, - cltv_expiry=cltv, - amount_msat=amount_msat, - payment_hash=payment_hash, + id=htlc.htlc_id, + cltv_expiry=htlc.cltv_expiry, + amount_msat=htlc.amount_msat, + payment_hash=htlc.payment_hash, onion_routing_packet=onion.to_bytes()) self.remote_pending_updates[chan] = True await self.await_remote(chan, remote_ctn) - return UpdateAddHtlc(**htlc, htlc_id=htlc_id) + return htlc def send_revoke_and_ack(self, chan: Channel): rev, _ = chan.revoke_current_commitment() @@ -923,18 +924,29 @@ class Peer(PrintError): if cltv_expiry >= 500_000_000: pass # TODO fail the channel # add htlc - htlc = {'amount_msat': amount_msat_htlc, 'payment_hash':payment_hash, 'cltv_expiry':cltv_expiry} - htlc_id = chan.receive_htlc(htlc) + htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc, payment_hash=payment_hash, cltv_expiry=cltv_expiry) + htlc = chan.receive_htlc(htlc) self.local_pending_updates[chan] = True local_ctn = chan.get_current_ctn(LOCAL) remote_ctn = chan.get_current_ctn(REMOTE) if processed_onion.are_we_final: - asyncio.ensure_future(self._maybe_fulfill_htlc(chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion)) + asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan, + htlc=htlc, + local_ctn=local_ctn, + remote_ctn=remote_ctn, + onion_packet=onion_packet, + processed_onion=processed_onion)) else: - asyncio.ensure_future(self._maybe_forward_htlc(chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion)) + asyncio.ensure_future(self._maybe_forward_htlc(chan=chan, + htlc=htlc, + local_ctn=local_ctn, + remote_ctn=remote_ctn, + onion_packet=onion_packet, + processed_onion=processed_onion)) @log_exceptions - async def _maybe_forward_htlc(self, chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion): + async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, + onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket): await self.await_local(chan, local_ctn) await self.await_remote(chan, remote_ctn) # Forward HTLC @@ -945,69 +957,70 @@ class Peer(PrintError): if next_chan is None or next_chan.get_state() != 'OPEN': self.print_error("cannot forward htlc", next_chan.get_state() if next_chan else None) reason = OnionRoutingFailureMessage(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'') - await self.fail_htlc(chan, htlc_id, onion_packet, reason) + await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return self.print_error('forwarding htlc to', next_chan.node_id) next_cltv_expiry = int.from_bytes(dph.outgoing_cltv_value, 'big') next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big') - next_htlc = {'amount_msat':next_amount_msat_htlc, 'payment_hash':payment_hash, 'cltv_expiry':next_cltv_expiry} - next_htlc_id = next_chan.add_htlc(next_htlc) + next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry) + next_htlc = next_chan.add_htlc(next_htlc) next_remote_ctn = next_chan.get_current_ctn(REMOTE) next_peer.send_message( "update_add_htlc", channel_id=next_chan.channel_id, - id=next_htlc_id, + id=next_htlc.htlc_id, cltv_expiry=dph.outgoing_cltv_value, amount_msat=dph.amt_to_forward, - payment_hash=payment_hash, + payment_hash=next_htlc.payment_hash, onion_routing_packet=processed_onion.next_packet.to_bytes() ) next_peer.remote_pending_updates[next_chan] = True await next_peer.await_remote(next_chan, next_remote_ctn) # wait until we get paid - preimage = await next_peer.payment_preimages[payment_hash].get() + preimage = await next_peer.payment_preimages[next_htlc.payment_hash].get() # fulfill the original htlc - await self._fulfill_htlc(chan, htlc_id, preimage) + await self._fulfill_htlc(chan, htlc.htlc_id, preimage) self.print_error("htlc forwarded successfully") @log_exceptions - async def _maybe_fulfill_htlc(self, chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion): + async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, + onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket): await self.await_local(chan, local_ctn) await self.await_remote(chan, remote_ctn) try: - invoice = self.lnworker.get_invoice(payment_hash) - preimage = self.lnworker.get_preimage(payment_hash) + invoice = self.lnworker.get_invoice(htlc.payment_hash) + preimage = self.lnworker.get_preimage(htlc.payment_hash) except UnknownPaymentHash: reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_PAYMENT_HASH, data=b'') - await self.fail_htlc(chan, htlc_id, onion_packet, reason) + await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return expected_received_msat = int(invoice.amount * bitcoin.COIN * 1000) if invoice.amount is not None else None if expected_received_msat is not None and \ - (amount_msat_htlc < expected_received_msat or amount_msat_htlc > 2 * expected_received_msat): + (htlc.amount_msat < expected_received_msat or htlc.amount_msat > 2 * expected_received_msat): reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_PAYMENT_AMOUNT, data=b'') - await self.fail_htlc(chan, htlc_id, onion_packet, reason) + await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return local_height = self.network.get_local_height() - if local_height + MIN_FINAL_CLTV_EXPIRY_ACCEPTED > cltv_expiry: + if local_height + MIN_FINAL_CLTV_EXPIRY_ACCEPTED > htlc.cltv_expiry: reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_EXPIRY_TOO_SOON, data=b'') - await self.fail_htlc(chan, htlc_id, onion_packet, reason) + await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return cltv_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.outgoing_cltv_value, byteorder="big") - if cltv_from_onion != cltv_expiry: + if cltv_from_onion != htlc.cltv_expiry: reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY, - data=cltv_expiry.to_bytes(4, byteorder="big")) - await self.fail_htlc(chan, htlc_id, onion_packet, reason) + data=htlc.cltv_expiry.to_bytes(4, byteorder="big")) + await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return amount_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.amt_to_forward, byteorder="big") - if amount_from_onion > amount_msat_htlc: + if amount_from_onion > htlc.amount_msat: reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, - data=amount_msat_htlc.to_bytes(8, byteorder="big")) - await self.fail_htlc(chan, htlc_id, onion_packet, reason) + data=htlc.amount_msat.to_bytes(8, byteorder="big")) + await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return - self.network.trigger_callback('htlc_added', UpdateAddHtlc(**htlc, htlc_id=htlc_id), invoice, RECEIVED) + self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED) if self.network.config.debug_lightning_do_not_settle: return - await self._fulfill_htlc(chan, htlc_id, preimage) + await self._fulfill_htlc(chan, htlc.htlc_id, preimage) async def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): chan.settle_htlc(preimage, htlc_id) diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 5dec78e46..a7e653335 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -567,8 +567,10 @@ class LnGlobalFeatures(IntFlag): LN_GLOBAL_FEATURES_KNOWN_SET = set(LnGlobalFeatures) -class LNPeerAddr(namedtuple('LNPeerAddr', ['host', 'port', 'pubkey'])): - __slots__ = () +class LNPeerAddr(NamedTuple): + host: str + port: int + pubkey: bytes def __str__(self): return '{}@{}:{}'.format(bh2u(self.pubkey), self.host, self.port) @@ -663,13 +665,14 @@ def format_short_channel_id(short_channel_id: Optional[bytes]): + 'x' + str(int.from_bytes(short_channel_id[3:6], 'big')) \ + 'x' + str(int.from_bytes(short_channel_id[6:], 'big')) + class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])): - """ - This whole class body is so that if you pass a hex-string as payment_hash, - it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings. - """ + # note: typing.NamedTuple cannot be used because we are overriding __new__ + __slots__ = () def __new__(cls, *args, **kwargs): + # if you pass a hex-string as payment_hash, it is decoded to bytes. + # Bytes can't be saved to disk, so we save hex-strings. if len(args) > 0: args = list(args) if type(args[1]) is str: @@ -677,5 +680,7 @@ class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', return super().__new__(cls, *args) if type(kwargs['payment_hash']) is str: kwargs['payment_hash'] = bfh(kwargs['payment_hash']) + if len(args) < 4 and 'htlc_id' not in kwargs: + kwargs['htlc_id'] = None return super().__new__(cls, **kwargs) diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py index 582a853b6..de4364174 100644 --- a/electrum/tests/test_lnchannel.py +++ b/electrum/tests/test_lnchannel.py @@ -207,13 +207,13 @@ class TestChannel(unittest.TestCase): # First Alice adds the outgoing HTLC to her local channel's state # update log. Then Alice sends this wire message over to Bob who adds # this htlc to his remote state update log. - self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict) + self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict).htlc_id self.assertNotEqual(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1), set()) before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE) beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL) - self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict) + self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict).htlc_id self.assertEqual(1, self.bob_channel.hm.log[LOCAL]['ctn'] + 1) self.assertNotEqual(self.bob_channel.hm.htlcs_by_direction(LOCAL, RECEIVED, 1), set()) @@ -230,8 +230,8 @@ class TestChannel(unittest.TestCase): def test_concurrent_reversed_payment(self): self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02') self.htlc_dict['amount_msat'] += 1000 - bob_idx = self.bob_channel.add_htlc(self.htlc_dict) - alice_idx = self.alice_channel.receive_htlc(self.htlc_dict) + self.bob_channel.add_htlc(self.htlc_dict) + self.alice_channel.receive_htlc(self.htlc_dict) self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment()) self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4) @@ -481,8 +481,8 @@ class TestChannel(unittest.TestCase): self.assertNotEqual(tx5, tx6) self.htlc_dict['amount_msat'] *= 5 - bob_index = bob_channel.add_htlc(self.htlc_dict) - alice_index = alice_channel.receive_htlc(self.htlc_dict) + bob_index = bob_channel.add_htlc(self.htlc_dict).htlc_id + alice_index = alice_channel.receive_htlc(self.htlc_dict).htlc_id bob_channel.pending_commitment(REMOTE) alice_channel.pending_commitment(LOCAL) @@ -597,7 +597,7 @@ class TestChannel(unittest.TestCase): def test_sign_commitment_is_pure(self): force_state_transition(self.alice_channel, self.bob_channel) self.htlc_dict['payment_hash'] = bitcoin.sha256(b'\x02' * 32) - aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict) + self.alice_channel.add_htlc(self.htlc_dict) before_signing = self.alice_channel.to_save() self.alice_channel.sign_next_commitment() after_signing = self.alice_channel.to_save() @@ -622,8 +622,8 @@ class TestAvailableToSpend(unittest.TestCase): 'cltv_expiry' : 5, } - alice_idx = alice_channel.add_htlc(htlc_dict) - bob_idx = bob_channel.receive_htlc(htlc_dict) + alice_idx = alice_channel.add_htlc(htlc_dict).htlc_id + bob_idx = bob_channel.receive_htlc(htlc_dict).htlc_id force_state_transition(alice_channel, bob_channel) bob_channel.fail_htlc(bob_idx) alice_channel.receive_fail_htlc(alice_idx) @@ -745,8 +745,8 @@ class TestChanReserve(unittest.TestCase): 'amount_msat' : int(2 * one_bitcoin_in_msat), 'cltv_expiry' : 5, } - alice_idx = self.alice_channel.add_htlc(htlc_dict) - bob_idx = self.bob_channel.receive_htlc(htlc_dict) + alice_idx = self.alice_channel.add_htlc(htlc_dict).htlc_id + bob_idx = self.bob_channel.receive_htlc(htlc_dict).htlc_id force_state_transition(self.alice_channel, self.bob_channel) self.check_bals(one_bitcoin_in_msat*3\ - self.alice_channel.pending_local_fee(), @@ -791,8 +791,8 @@ class TestDust(unittest.TestCase): } old_values = [x.value for x in bob_channel.current_commitment(LOCAL).outputs() ] - aliceHtlcIndex = alice_channel.add_htlc(htlc) - bobHtlcIndex = bob_channel.receive_htlc(htlc) + aliceHtlcIndex = alice_channel.add_htlc(htlc).htlc_id + bobHtlcIndex = bob_channel.receive_htlc(htlc).htlc_id force_state_transition(alice_channel, bob_channel) alice_ctx = alice_channel.current_commitment(LOCAL) bob_ctx = bob_channel.current_commitment(LOCAL)