diff --git a/lib/lnbase.py b/lib/lnbase.py index 5f8103c31..a9d00b354 100644 --- a/lib/lnbase.py +++ b/lib/lnbase.py @@ -37,6 +37,7 @@ from .util import PrintError, bh2u, print_error, bfh, profiler, xor_bytes from .transaction import opcodes, Transaction from .lnrouter import new_onion_packet, OnionHopsDataSingle, OnionPerHop from .lightning_payencode.lnaddr import lndecode +from .lnhtlc import UpdateAddHtlc, HTLCStateMachine from collections import namedtuple, defaultdict @@ -1035,27 +1036,16 @@ class Peer(PrintError): onion = new_onion_packet(self.node_keys, self.secret_key, hops_data, associated_data) msat_local = chan.local_state.amount_msat - (amount_msat + total_fee) msat_remote = chan.remote_state.amount_msat + (amount_msat + total_fee) + htlc = UpdateAddHtlc(amount_msat, payment_hash, final_cltv_expiry_with_deltas, total_fee) amount_msat += total_fee + self.send_message(gen_msg("update_add_htlc", channel_id=chan.channel_id, id=chan.local_state.next_htlc_id, cltv_expiry=final_cltv_expiry_with_deltas, amount_msat=amount_msat, payment_hash=payment_hash, onion_routing_packet=onion.to_bytes())) - their_local_htlc_pubkey = derive_pubkey(chan.remote_config.htlc_basepoint.pubkey, chan.remote_state.next_per_commitment_point) - their_remote_htlc_pubkey = derive_pubkey(chan.local_config.htlc_basepoint.pubkey, chan.remote_state.next_per_commitment_point) - their_remote_htlc_privkey_number = derive_privkey( - int.from_bytes(chan.local_config.htlc_basepoint.privkey, 'big'), - chan.remote_state.next_per_commitment_point) - their_remote_htlc_privkey = their_remote_htlc_privkey_number.to_bytes(32, 'big') - # TODO check payment_hash - revocation_pubkey = derive_blinded_pubkey(chan.local_config.revocation_basepoint.pubkey, chan.remote_state.next_per_commitment_point) - htlcs_in_remote = [(make_received_htlc(revocation_pubkey, their_remote_htlc_pubkey, their_local_htlc_pubkey, payment_hash, final_cltv_expiry_with_deltas), amount_msat)] - remote_ctx = make_commitment_using_open_channel(chan, chan.remote_state.ctn + 1, False, chan.remote_state.next_per_commitment_point, - chan.remote_state.amount_msat, msat_local, htlcs_in_remote) - sig_64 = sign_and_get_sig_string(remote_ctx, chan.local_config, chan.remote_config) + m = HTLCStateMachine(chan) + m.add_htlc(htlc) - htlc_tx = make_htlc_tx_with_open_channel(chan, chan.remote_state.next_per_commitment_point, False, False, amount_msat, final_cltv_expiry_with_deltas, payment_hash, remote_ctx, 0) - # htlc_sig signs the HTLC transaction that spends from THEIR commitment transaction's received_htlc output - sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey)) - r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order()) - htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order()) + sig_64, htlc_sigs = m.sign_next_commitment() + htlc_sig = htlc_sigs[0][1] self.send_message(gen_msg("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=1, htlc_signature=htlc_sig)) diff --git a/lib/lnhtlc.py b/lib/lnhtlc.py index 426584c48..050d816d4 100644 --- a/lib/lnhtlc.py +++ b/lib/lnhtlc.py @@ -1,25 +1,28 @@ # ported from lnd 42de4400bff5105352d0552155f73589166d162b -from .lnbase import * +from ecdsa.util import sigencode_string_canonize, sigdecode_der +from .util import bfh, PrintError +from collections import namedtuple +from ecdsa.curves import SECP256k1 +from .crypto import sha256 SettleHtlc = namedtuple("SettleHtlc", ["htlc_id"]) RevokeAndAck = namedtuple("RevokeAndAck", ["height", "per_commitment_secret", "next_per_commitment_point"]) class UpdateAddHtlc: - def __init__(self, amount_msat, payment_hash, cltv_expiry, final_cltv_expiry_with_deltas): + def __init__(self, amount_msat, payment_hash, cltv_expiry, total_fee): self.amount_msat = amount_msat self.payment_hash = payment_hash self.cltv_expiry = cltv_expiry + self.total_fee = total_fee # the height the htlc was locked in at, or None - self.locked_in = None - - # this field is not in update_add_htlc but we need to to make the right htlcs - self.final_cltv_expiry_with_deltas = final_cltv_expiry_with_deltas + self.r_locked_in = None + self.l_locked_in = None self.htlc_id = None def as_tuple(self): - return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.locked_in, self.final_cltv_expiry_with_deltas) + return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.r_locked_in, self.l_locked_in, self.total_fee) def __hash__(self): return hash(self.as_tuple()) @@ -43,14 +46,15 @@ class HTLCStateMachine(PrintError): def diagnostic_name(self): return str(self.name) - def __init__(self, state: OpenChannel, name = None): + def __init__(self, state, name = None): self.state = state self.local_update_log = [] self.remote_update_log = [] self.name = name - self.current_height = 0 + self.l_current_height = 0 + self.r_current_height = 0 self.total_msat_sent = 0 self.total_msat_received = 0 @@ -92,9 +96,14 @@ class HTLCStateMachine(PrintError): any). The HTLC signatures are sorted according to the BIP 69 order of the HTLC's on the commitment transaction. """ + from .lnbase import sign_and_get_sig_string, derive_privkey, make_htlc_tx_with_open_channel + self.l_current_height += 1 + for htlc in self.local_update_log: + if not type(htlc) is UpdateAddHtlc: continue + if htlc.l_locked_in is None: htlc.l_locked_in = self.l_current_height for htlc in self.remote_update_log: if not type(htlc) is UpdateAddHtlc: continue - if htlc.locked_in is None: htlc.locked_in = self.current_height + if htlc.r_locked_in is None: htlc.r_locked_in = self.r_current_height self.print_error("sign_next_commitment") sig_64 = sign_and_get_sig_string(self.remote_commitment, self.state.local_config, self.state.remote_config) @@ -106,17 +115,18 @@ class HTLCStateMachine(PrintError): for_us = False - htlcs = self.htlcs_in_remote # TODO also htlcs_in_local - assert len(htlcs) <= 1 htlcsigs = [] - for htlc in htlcs: - original_htlc_output_index = 0 - we_receive = True # when we do htlcs_in_local, we need to flip this flag - htlc_tx = make_htlc_tx_with_open_channel(self.state, self.state.remote_state.next_per_commitment_point, for_us, we_receive, htlc.amount_msat, htlc.final_cltv_expiry_with_deltas, htlc.payment_hash, self.remote_commitment, original_htlc_output_index) - sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey)) - r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order()) - htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order()) - htlcsigs.append(htlc_sig) + for we_receive, htlcs in zip([True, False], [self.htlcs_in_remote, self.htlcs_in_local]): + assert len(htlcs) <= 1 + for htlc in htlcs: + original_htlc_output_index = 0 + args = [self.state.remote_state.next_per_commitment_point, for_us, we_receive, htlc.amount_msat + htlc.total_fee, htlc.cltv_expiry, htlc.payment_hash, self.remote_commitment, original_htlc_output_index] + print("args", args) + htlc_tx = make_htlc_tx_with_open_channel(self.state, *args) + sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey)) + r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order()) + htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order()) + htlcsigs.append((htlc_tx, htlc_sig)) return sig_64, htlcsigs @@ -145,7 +155,8 @@ class HTLCStateMachine(PrintError): transaction. This return value allows callers to act once an HTLC has been locked into our commitment transaction. """ - self.current_height += 1 + from .lnbase import get_per_commitment_secret_from_seed, secret_to_pubkey + self.r_current_height += 1 self.print_error("revoke_current_commitment") chan = self.state @@ -163,7 +174,7 @@ class HTLCStateMachine(PrintError): ) ) - return RevokeAndAck(self.current_height, last_secret, next_point), "current htlcs" + return RevokeAndAck(self.r_current_height, last_secret, next_point), "current htlcs" def receive_revocation(self, revocation): """ @@ -186,7 +197,7 @@ class HTLCStateMachine(PrintError): settle_fails2.append(x) if revocation.height is not None: - adds2 = list(x for x in self.htlcs_in_remote if x.locked_in == revocation.height) + adds2 = list(x for x in self.htlcs_in_remote if x.r_locked_in == revocation.height) class FwdPkg: adds = adds2 @@ -196,10 +207,11 @@ class HTLCStateMachine(PrintError): self.total_msat_sent += self.lookup_htlc(self.local_update_log, x.htlc_id).amount_msat # increase received_msat counter for htlc's that have been settled - adds2 = self.gen_htlc_indices(self.remote_update_log) - for x in adds2: - if SettleHtlc(x) in self.local_update_log: - self.total_msat_received += self.lookup_htlc(self.remote_update_log, x).amount_msat + adds2 = self.gen_htlc_indices("remote") + for htlc in adds2: + htlc_id = htlc.htlc_id + if SettleHtlc(htlc_id) in self.local_update_log: + self.total_msat_received += self.lookup_htlc(self.remote_update_log, htlc_id).amount_msat # log compaction (remove entries relating to htlc's that have been settled) @@ -241,14 +253,17 @@ class HTLCStateMachine(PrintError): @staticmethod def htlcsum(htlcs): - return sum(x.amount_msat for x in htlcs) + return sum(x.amount_msat for x in htlcs), sum(x.total_fee for x in htlcs) @property def remote_commitment(self): + from .lnbase import make_commitment_using_open_channel, make_received_htlc, make_offered_htlc, derive_pubkey, derive_blinded_pubkey + htlc_value_local, total_fee_local = self.htlcsum(self.htlcs_in_local) + htlc_value_remote, total_fee_remote = self.htlcsum(self.htlcs_in_remote) local_msat = self.state.local_state.amount_msat -\ - self.htlcsum(self.htlcs_in_local) + htlc_value_local remote_msat = self.state.remote_state.amount_msat -\ - self.htlcsum(self.htlcs_in_remote) + htlc_value_remote assert local_msat > 0 assert remote_msat > 0 @@ -262,25 +277,30 @@ class HTLCStateMachine(PrintError): htlcs_in_local = [] for htlc in self.htlcs_in_local: htlcs_in_local.append( - ( make_received_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat)) + ( make_received_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat + total_fee_local)) htlcs_in_remote = [] for htlc in self.htlcs_in_remote: htlcs_in_remote.append( - ( make_offered_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash), htlc.amount_msat)) + ( make_offered_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash), htlc.amount_msat + total_fee_remote)) commit = make_commitment_using_open_channel(self.state, self.state.remote_state.ctn + 1, - True, this_point, - remote_msat, local_msat, htlcs_in_local + htlcs_in_remote) + False, this_point, + remote_msat - total_fee_remote, local_msat - total_fee_local, htlcs_in_local + htlcs_in_remote) assert len(commit.outputs()) == 2 + len(htlcs_in_local) + len(htlcs_in_remote) return commit @property def local_commitment(self): + from .lnbase import make_commitment_using_open_channel, make_received_htlc, make_offered_htlc, derive_pubkey, derive_blinded_pubkey, get_per_commitment_secret_from_seed, secret_to_pubkey + htlc_value_local, total_fee_local = self.htlcsum(self.htlcs_in_local) + htlc_value_remote, total_fee_remote = self.htlcsum(self.htlcs_in_remote) + print("htlc_value_local, total_fee_local", htlc_value_local, total_fee_local) local_msat = self.state.local_state.amount_msat -\ - self.htlcsum(self.htlcs_in_local) + htlc_value_local + print("htlc_value_remote, total_fee_remote", htlc_value_remote, total_fee_remote) remote_msat = self.state.remote_state.amount_msat -\ - self.htlcsum(self.htlcs_in_remote) + htlc_value_remote assert local_msat > 0 assert remote_msat > 0 @@ -295,35 +315,44 @@ class HTLCStateMachine(PrintError): htlcs_in_local = [] for htlc in self.htlcs_in_local: + print("adding local htlc", htlc) htlcs_in_local.append( - ( make_offered_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash), htlc.amount_msat)) + ( make_offered_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash), htlc.amount_msat + total_fee_local)) htlcs_in_remote = [] for htlc in self.htlcs_in_remote: + print("adding remote htlc", htlc) htlcs_in_remote.append( - ( make_received_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat)) + ( make_received_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat + total_fee_remote)) commit = make_commitment_using_open_channel(self.state, self.state.local_state.ctn + 1, True, this_point, - local_msat, remote_msat, htlcs_in_local + htlcs_in_remote) + local_msat - total_fee_local, remote_msat - total_fee_remote, htlcs_in_local + htlcs_in_remote) assert len(commit.outputs()) == 2 + len(htlcs_in_local) + len(htlcs_in_remote) return commit - def gen_htlc_indices(self, update_log): - for num, htlc in enumerate(update_log): + def gen_htlc_indices(self, subject): + assert subject in ["local", "remote"] + update_log = (self.remote_update_log if subject == "remote" else self.local_update_log) + res = [] + for htlc in update_log: if type(htlc) is not UpdateAddHtlc: continue - if htlc.locked_in is None or htlc.locked_in < self.current_height: + height = (self.r_current_height if subject == "remote" else self.l_current_height) + locked_in = (htlc.r_locked_in if subject == "remote" else htlc.l_locked_in) + + if locked_in is None or locked_in < height: continue - yield num + res.append(htlc) + return res @property def htlcs_in_local(self): - return [self.local_update_log[x] for x in self.gen_htlc_indices(self.local_update_log)] + return self.gen_htlc_indices("local") @property def htlcs_in_remote(self): - return [self.remote_update_log[x] for x in self.gen_htlc_indices(self.remote_update_log)] + return self.gen_htlc_indices("remote") def settle_htlc(self, preimage, htlc_id, source_ref, dest_ref, close_key): """ diff --git a/lib/lnworker.py b/lib/lnworker.py index 9e9d85d44..2be1bf956 100644 --- a/lib/lnworker.py +++ b/lib/lnworker.py @@ -189,7 +189,7 @@ class LNWorker(PrintError): amount_msat = int(addr.amount * COIN * 1000) path = self.path_finder.find_path_for_payment(self.pubkey, invoice_pubkey, amount_msat) if path is None: - return "No path found" + raise Exception("No path found") node_id, short_channel_id = path[0] peer = self.peers[node_id] for chan in self.channels.values(): diff --git a/lib/tests/test_lnhtlc.py b/lib/tests/test_lnhtlc.py index 68b41c085..3b767eb14 100644 --- a/lib/tests/test_lnhtlc.py +++ b/lib/tests/test_lnhtlc.py @@ -38,7 +38,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate return lnbase.OpenChannel( channel_id=channel_id, - short_channel_id=channel_id.to_bytes(32, "big")[:8], + short_channel_id=channel_id[:8], funding_outpoint=lnbase.Outpoint(funding_txid, funding_index), local_config=local_config, remote_config=remote_config, @@ -118,7 +118,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): payment_hash = paymentHash, amount_msat = htlcAmt, cltv_expiry = 5, - final_cltv_expiry_with_deltas = 5 + total_fee = 0 ) # First Alice adds the outgoing HTLC to her local channel's state @@ -134,6 +134,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): # cover the HTLC. aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment() + self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature") + # Bob receives this signature message, and checks that this covers the # state he has in his remote log. This includes the HTLC just sent # from Alice. @@ -153,8 +155,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): # her prior commitment transaction. Alice shouldn't have any HTLCs to # forward since she's sending an outgoing HTLC. fwdPkg = alice_channel.receive_revocation(bobRevocation) - self.assertEqual(len(fwdPkg.adds), 0, "alice forwards %s add htlcs, should forward none"% len(fwdPkg.adds)) - self.assertEqual(len(fwdPkg.settle_fails), 0, "alice forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails)) + self.assertEqual(fwdPkg.adds, [], "alice forwards %s add htlcs, should forward none"% len(fwdPkg.adds)) + self.assertEqual(fwdPkg.settle_fails, [], "alice forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails)) # Alice then processes bob's signature, and since she just received # the revocation, she expect this signature to cover everything up to @@ -170,7 +172,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): # into both commitment transactions. fwdPkg = bob_channel.receive_revocation(aliceRevocation) self.assertEqual(len(fwdPkg.adds), 1, "bob forwards %s add htlcs, should only forward one"% len(fwdPkg.adds)) - self.assertEqual(len(fwdPkg.settle_fails), 0, "bob forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails)) + self.assertEqual(fwdPkg.settle_fails, [], "bob forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails)) # At this point, both sides should have the proper number of satoshis # sent, and commitment height updated within their local channel @@ -182,8 +184,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): self.assertEqual(alice_channel.total_msat_received, bobSent, "alice has incorrect milli-satoshis received %s vs %s"% (alice_channel.total_msat_received, bobSent)) self.assertEqual(bob_channel.total_msat_sent, bobSent, "bob has incorrect milli-satoshis sent %s vs %s"% (bob_channel.total_msat_sent, bobSent)) self.assertEqual(bob_channel.total_msat_received, aliceSent, "bob has incorrect milli-satoshis received %s vs %s"% (bob_channel.total_msat_received, aliceSent)) - self.assertEqual(bob_channel.current_height, 1, "bob has incorrect commitment height, %s vs %s"% (bob_channel.current_height, 1)) - self.assertEqual(alice_channel.current_height, 1, "alice has incorrect commitment height, %s vs %s"% (alice_channel.current_height, 1)) + self.assertEqual(bob_channel.l_current_height, 1, "bob has incorrect commitment height, %s vs %s"% (bob_channel.l_current_height, 1)) + self.assertEqual(alice_channel.l_current_height, 1, "alice has incorrect commitment height, %s vs %s"% (alice_channel.l_current_height, 1)) # Both commitment transactions should have three outputs, and one of # them should be exactly the amount of the HTLC. @@ -227,8 +229,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): self.assertEqual(alice_channel.total_msat_received, 0, "alice satoshis received incorrect %s vs %s expected"% (alice_channel.total_msat_received, 0)) self.assertEqual(bob_channel.total_msat_received, mSatTransferred, "bob satoshis received incorrect %s vs %s expected"% (bob_channel.total_msat_received, mSatTransferred)) self.assertEqual(bob_channel.total_msat_sent, 0, "bob satoshis sent incorrect %s vs %s expected"% (bob_channel.total_msat_sent, 0)) - self.assertEqual(bob_channel.current_height, 2, "bob has incorrect commitment height, %s vs %s"% (bob_channel.current_height, 2)) - self.assertEqual(alice_channel.current_height, 2, "alice has incorrect commitment height, %s vs %s"% (alice_channel.current_height, 2)) + self.assertEqual(bob_channel.l_current_height, 2, "bob has incorrect commitment height, %s vs %s"% (bob_channel.l_current_height, 2)) + self.assertEqual(alice_channel.l_current_height, 2, "alice has incorrect commitment height, %s vs %s"% (alice_channel.l_current_height, 2)) # The logs of both sides should now be cleared since the entry adding # the HTLC should have been removed once both sides receive the