Browse Source

ln: integrate lnhtlc in lnbase, fix multiple lnhtlc bugs

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 7 years ago
committed by ThomasV
parent
commit
4d25933898
  1. 24
      lib/lnbase.py
  2. 119
      lib/lnhtlc.py
  3. 2
      lib/lnworker.py
  4. 20
      lib/tests/test_lnhtlc.py

24
lib/lnbase.py

@ -37,6 +37,7 @@ from .util import PrintError, bh2u, print_error, bfh, profiler, xor_bytes
from .transaction import opcodes, Transaction
from .lnrouter import new_onion_packet, OnionHopsDataSingle, OnionPerHop
from .lightning_payencode.lnaddr import lndecode
from .lnhtlc import UpdateAddHtlc, HTLCStateMachine
from collections import namedtuple, defaultdict
@ -1035,27 +1036,16 @@ class Peer(PrintError):
onion = new_onion_packet(self.node_keys, self.secret_key, hops_data, associated_data)
msat_local = chan.local_state.amount_msat - (amount_msat + total_fee)
msat_remote = chan.remote_state.amount_msat + (amount_msat + total_fee)
htlc = UpdateAddHtlc(amount_msat, payment_hash, final_cltv_expiry_with_deltas, total_fee)
amount_msat += total_fee
self.send_message(gen_msg("update_add_htlc", channel_id=chan.channel_id, id=chan.local_state.next_htlc_id, cltv_expiry=final_cltv_expiry_with_deltas, amount_msat=amount_msat, payment_hash=payment_hash, onion_routing_packet=onion.to_bytes()))
their_local_htlc_pubkey = derive_pubkey(chan.remote_config.htlc_basepoint.pubkey, chan.remote_state.next_per_commitment_point)
their_remote_htlc_pubkey = derive_pubkey(chan.local_config.htlc_basepoint.pubkey, chan.remote_state.next_per_commitment_point)
their_remote_htlc_privkey_number = derive_privkey(
int.from_bytes(chan.local_config.htlc_basepoint.privkey, 'big'),
chan.remote_state.next_per_commitment_point)
their_remote_htlc_privkey = their_remote_htlc_privkey_number.to_bytes(32, 'big')
# TODO check payment_hash
revocation_pubkey = derive_blinded_pubkey(chan.local_config.revocation_basepoint.pubkey, chan.remote_state.next_per_commitment_point)
htlcs_in_remote = [(make_received_htlc(revocation_pubkey, their_remote_htlc_pubkey, their_local_htlc_pubkey, payment_hash, final_cltv_expiry_with_deltas), amount_msat)]
remote_ctx = make_commitment_using_open_channel(chan, chan.remote_state.ctn + 1, False, chan.remote_state.next_per_commitment_point,
chan.remote_state.amount_msat, msat_local, htlcs_in_remote)
sig_64 = sign_and_get_sig_string(remote_ctx, chan.local_config, chan.remote_config)
m = HTLCStateMachine(chan)
m.add_htlc(htlc)
htlc_tx = make_htlc_tx_with_open_channel(chan, chan.remote_state.next_per_commitment_point, False, False, amount_msat, final_cltv_expiry_with_deltas, payment_hash, remote_ctx, 0)
# htlc_sig signs the HTLC transaction that spends from THEIR commitment transaction's received_htlc output
sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey))
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order())
sig_64, htlc_sigs = m.sign_next_commitment()
htlc_sig = htlc_sigs[0][1]
self.send_message(gen_msg("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=1, htlc_signature=htlc_sig))

119
lib/lnhtlc.py

@ -1,25 +1,28 @@
# ported from lnd 42de4400bff5105352d0552155f73589166d162b
from .lnbase import *
from ecdsa.util import sigencode_string_canonize, sigdecode_der
from .util import bfh, PrintError
from collections import namedtuple
from ecdsa.curves import SECP256k1
from .crypto import sha256
SettleHtlc = namedtuple("SettleHtlc", ["htlc_id"])
RevokeAndAck = namedtuple("RevokeAndAck", ["height", "per_commitment_secret", "next_per_commitment_point"])
class UpdateAddHtlc:
def __init__(self, amount_msat, payment_hash, cltv_expiry, final_cltv_expiry_with_deltas):
def __init__(self, amount_msat, payment_hash, cltv_expiry, total_fee):
self.amount_msat = amount_msat
self.payment_hash = payment_hash
self.cltv_expiry = cltv_expiry
self.total_fee = total_fee
# the height the htlc was locked in at, or None
self.locked_in = None
# this field is not in update_add_htlc but we need to to make the right htlcs
self.final_cltv_expiry_with_deltas = final_cltv_expiry_with_deltas
self.r_locked_in = None
self.l_locked_in = None
self.htlc_id = None
def as_tuple(self):
return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.locked_in, self.final_cltv_expiry_with_deltas)
return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.r_locked_in, self.l_locked_in, self.total_fee)
def __hash__(self):
return hash(self.as_tuple())
@ -43,14 +46,15 @@ class HTLCStateMachine(PrintError):
def diagnostic_name(self):
return str(self.name)
def __init__(self, state: OpenChannel, name = None):
def __init__(self, state, name = None):
self.state = state
self.local_update_log = []
self.remote_update_log = []
self.name = name
self.current_height = 0
self.l_current_height = 0
self.r_current_height = 0
self.total_msat_sent = 0
self.total_msat_received = 0
@ -92,9 +96,14 @@ class HTLCStateMachine(PrintError):
any). The HTLC signatures are sorted according to the BIP 69 order of the
HTLC's on the commitment transaction.
"""
from .lnbase import sign_and_get_sig_string, derive_privkey, make_htlc_tx_with_open_channel
self.l_current_height += 1
for htlc in self.local_update_log:
if not type(htlc) is UpdateAddHtlc: continue
if htlc.l_locked_in is None: htlc.l_locked_in = self.l_current_height
for htlc in self.remote_update_log:
if not type(htlc) is UpdateAddHtlc: continue
if htlc.locked_in is None: htlc.locked_in = self.current_height
if htlc.r_locked_in is None: htlc.r_locked_in = self.r_current_height
self.print_error("sign_next_commitment")
sig_64 = sign_and_get_sig_string(self.remote_commitment, self.state.local_config, self.state.remote_config)
@ -106,17 +115,18 @@ class HTLCStateMachine(PrintError):
for_us = False
htlcs = self.htlcs_in_remote # TODO also htlcs_in_local
assert len(htlcs) <= 1
htlcsigs = []
for htlc in htlcs:
original_htlc_output_index = 0
we_receive = True # when we do htlcs_in_local, we need to flip this flag
htlc_tx = make_htlc_tx_with_open_channel(self.state, self.state.remote_state.next_per_commitment_point, for_us, we_receive, htlc.amount_msat, htlc.final_cltv_expiry_with_deltas, htlc.payment_hash, self.remote_commitment, original_htlc_output_index)
sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey))
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order())
htlcsigs.append(htlc_sig)
for we_receive, htlcs in zip([True, False], [self.htlcs_in_remote, self.htlcs_in_local]):
assert len(htlcs) <= 1
for htlc in htlcs:
original_htlc_output_index = 0
args = [self.state.remote_state.next_per_commitment_point, for_us, we_receive, htlc.amount_msat + htlc.total_fee, htlc.cltv_expiry, htlc.payment_hash, self.remote_commitment, original_htlc_output_index]
print("args", args)
htlc_tx = make_htlc_tx_with_open_channel(self.state, *args)
sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey))
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order())
htlcsigs.append((htlc_tx, htlc_sig))
return sig_64, htlcsigs
@ -145,7 +155,8 @@ class HTLCStateMachine(PrintError):
transaction. This return value allows callers to act once an HTLC has been
locked into our commitment transaction.
"""
self.current_height += 1
from .lnbase import get_per_commitment_secret_from_seed, secret_to_pubkey
self.r_current_height += 1
self.print_error("revoke_current_commitment")
chan = self.state
@ -163,7 +174,7 @@ class HTLCStateMachine(PrintError):
)
)
return RevokeAndAck(self.current_height, last_secret, next_point), "current htlcs"
return RevokeAndAck(self.r_current_height, last_secret, next_point), "current htlcs"
def receive_revocation(self, revocation):
"""
@ -186,7 +197,7 @@ class HTLCStateMachine(PrintError):
settle_fails2.append(x)
if revocation.height is not None:
adds2 = list(x for x in self.htlcs_in_remote if x.locked_in == revocation.height)
adds2 = list(x for x in self.htlcs_in_remote if x.r_locked_in == revocation.height)
class FwdPkg:
adds = adds2
@ -196,10 +207,11 @@ class HTLCStateMachine(PrintError):
self.total_msat_sent += self.lookup_htlc(self.local_update_log, x.htlc_id).amount_msat
# increase received_msat counter for htlc's that have been settled
adds2 = self.gen_htlc_indices(self.remote_update_log)
for x in adds2:
if SettleHtlc(x) in self.local_update_log:
self.total_msat_received += self.lookup_htlc(self.remote_update_log, x).amount_msat
adds2 = self.gen_htlc_indices("remote")
for htlc in adds2:
htlc_id = htlc.htlc_id
if SettleHtlc(htlc_id) in self.local_update_log:
self.total_msat_received += self.lookup_htlc(self.remote_update_log, htlc_id).amount_msat
# log compaction (remove entries relating to htlc's that have been settled)
@ -241,14 +253,17 @@ class HTLCStateMachine(PrintError):
@staticmethod
def htlcsum(htlcs):
return sum(x.amount_msat for x in htlcs)
return sum(x.amount_msat for x in htlcs), sum(x.total_fee for x in htlcs)
@property
def remote_commitment(self):
from .lnbase import make_commitment_using_open_channel, make_received_htlc, make_offered_htlc, derive_pubkey, derive_blinded_pubkey
htlc_value_local, total_fee_local = self.htlcsum(self.htlcs_in_local)
htlc_value_remote, total_fee_remote = self.htlcsum(self.htlcs_in_remote)
local_msat = self.state.local_state.amount_msat -\
self.htlcsum(self.htlcs_in_local)
htlc_value_local
remote_msat = self.state.remote_state.amount_msat -\
self.htlcsum(self.htlcs_in_remote)
htlc_value_remote
assert local_msat > 0
assert remote_msat > 0
@ -262,25 +277,30 @@ class HTLCStateMachine(PrintError):
htlcs_in_local = []
for htlc in self.htlcs_in_local:
htlcs_in_local.append(
( make_received_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat))
( make_received_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat + total_fee_local))
htlcs_in_remote = []
for htlc in self.htlcs_in_remote:
htlcs_in_remote.append(
( make_offered_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash), htlc.amount_msat))
( make_offered_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash), htlc.amount_msat + total_fee_remote))
commit = make_commitment_using_open_channel(self.state, self.state.remote_state.ctn + 1,
True, this_point,
remote_msat, local_msat, htlcs_in_local + htlcs_in_remote)
False, this_point,
remote_msat - total_fee_remote, local_msat - total_fee_local, htlcs_in_local + htlcs_in_remote)
assert len(commit.outputs()) == 2 + len(htlcs_in_local) + len(htlcs_in_remote)
return commit
@property
def local_commitment(self):
from .lnbase import make_commitment_using_open_channel, make_received_htlc, make_offered_htlc, derive_pubkey, derive_blinded_pubkey, get_per_commitment_secret_from_seed, secret_to_pubkey
htlc_value_local, total_fee_local = self.htlcsum(self.htlcs_in_local)
htlc_value_remote, total_fee_remote = self.htlcsum(self.htlcs_in_remote)
print("htlc_value_local, total_fee_local", htlc_value_local, total_fee_local)
local_msat = self.state.local_state.amount_msat -\
self.htlcsum(self.htlcs_in_local)
htlc_value_local
print("htlc_value_remote, total_fee_remote", htlc_value_remote, total_fee_remote)
remote_msat = self.state.remote_state.amount_msat -\
self.htlcsum(self.htlcs_in_remote)
htlc_value_remote
assert local_msat > 0
assert remote_msat > 0
@ -295,35 +315,44 @@ class HTLCStateMachine(PrintError):
htlcs_in_local = []
for htlc in self.htlcs_in_local:
print("adding local htlc", htlc)
htlcs_in_local.append(
( make_offered_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash), htlc.amount_msat))
( make_offered_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash), htlc.amount_msat + total_fee_local))
htlcs_in_remote = []
for htlc in self.htlcs_in_remote:
print("adding remote htlc", htlc)
htlcs_in_remote.append(
( make_received_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat))
( make_received_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat + total_fee_remote))
commit = make_commitment_using_open_channel(self.state, self.state.local_state.ctn + 1,
True, this_point,
local_msat, remote_msat, htlcs_in_local + htlcs_in_remote)
local_msat - total_fee_local, remote_msat - total_fee_remote, htlcs_in_local + htlcs_in_remote)
assert len(commit.outputs()) == 2 + len(htlcs_in_local) + len(htlcs_in_remote)
return commit
def gen_htlc_indices(self, update_log):
for num, htlc in enumerate(update_log):
def gen_htlc_indices(self, subject):
assert subject in ["local", "remote"]
update_log = (self.remote_update_log if subject == "remote" else self.local_update_log)
res = []
for htlc in update_log:
if type(htlc) is not UpdateAddHtlc:
continue
if htlc.locked_in is None or htlc.locked_in < self.current_height:
height = (self.r_current_height if subject == "remote" else self.l_current_height)
locked_in = (htlc.r_locked_in if subject == "remote" else htlc.l_locked_in)
if locked_in is None or locked_in < height:
continue
yield num
res.append(htlc)
return res
@property
def htlcs_in_local(self):
return [self.local_update_log[x] for x in self.gen_htlc_indices(self.local_update_log)]
return self.gen_htlc_indices("local")
@property
def htlcs_in_remote(self):
return [self.remote_update_log[x] for x in self.gen_htlc_indices(self.remote_update_log)]
return self.gen_htlc_indices("remote")
def settle_htlc(self, preimage, htlc_id, source_ref, dest_ref, close_key):
"""

2
lib/lnworker.py

@ -189,7 +189,7 @@ class LNWorker(PrintError):
amount_msat = int(addr.amount * COIN * 1000)
path = self.path_finder.find_path_for_payment(self.pubkey, invoice_pubkey, amount_msat)
if path is None:
return "No path found"
raise Exception("No path found")
node_id, short_channel_id = path[0]
peer = self.peers[node_id]
for chan in self.channels.values():

20
lib/tests/test_lnhtlc.py

@ -38,7 +38,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
return lnbase.OpenChannel(
channel_id=channel_id,
short_channel_id=channel_id.to_bytes(32, "big")[:8],
short_channel_id=channel_id[:8],
funding_outpoint=lnbase.Outpoint(funding_txid, funding_index),
local_config=local_config,
remote_config=remote_config,
@ -118,7 +118,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
payment_hash = paymentHash,
amount_msat = htlcAmt,
cltv_expiry = 5,
final_cltv_expiry_with_deltas = 5
total_fee = 0
)
# First Alice adds the outgoing HTLC to her local channel's state
@ -134,6 +134,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
# cover the HTLC.
aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment()
self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature")
# Bob receives this signature message, and checks that this covers the
# state he has in his remote log. This includes the HTLC just sent
# from Alice.
@ -153,8 +155,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
# her prior commitment transaction. Alice shouldn't have any HTLCs to
# forward since she's sending an outgoing HTLC.
fwdPkg = alice_channel.receive_revocation(bobRevocation)
self.assertEqual(len(fwdPkg.adds), 0, "alice forwards %s add htlcs, should forward none"% len(fwdPkg.adds))
self.assertEqual(len(fwdPkg.settle_fails), 0, "alice forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails))
self.assertEqual(fwdPkg.adds, [], "alice forwards %s add htlcs, should forward none"% len(fwdPkg.adds))
self.assertEqual(fwdPkg.settle_fails, [], "alice forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails))
# Alice then processes bob's signature, and since she just received
# the revocation, she expect this signature to cover everything up to
@ -170,7 +172,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
# into both commitment transactions.
fwdPkg = bob_channel.receive_revocation(aliceRevocation)
self.assertEqual(len(fwdPkg.adds), 1, "bob forwards %s add htlcs, should only forward one"% len(fwdPkg.adds))
self.assertEqual(len(fwdPkg.settle_fails), 0, "bob forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails))
self.assertEqual(fwdPkg.settle_fails, [], "bob forwards %s settle/fail htlcs, should forward none"% len(fwdPkg.settle_fails))
# At this point, both sides should have the proper number of satoshis
# sent, and commitment height updated within their local channel
@ -182,8 +184,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
self.assertEqual(alice_channel.total_msat_received, bobSent, "alice has incorrect milli-satoshis received %s vs %s"% (alice_channel.total_msat_received, bobSent))
self.assertEqual(bob_channel.total_msat_sent, bobSent, "bob has incorrect milli-satoshis sent %s vs %s"% (bob_channel.total_msat_sent, bobSent))
self.assertEqual(bob_channel.total_msat_received, aliceSent, "bob has incorrect milli-satoshis received %s vs %s"% (bob_channel.total_msat_received, aliceSent))
self.assertEqual(bob_channel.current_height, 1, "bob has incorrect commitment height, %s vs %s"% (bob_channel.current_height, 1))
self.assertEqual(alice_channel.current_height, 1, "alice has incorrect commitment height, %s vs %s"% (alice_channel.current_height, 1))
self.assertEqual(bob_channel.l_current_height, 1, "bob has incorrect commitment height, %s vs %s"% (bob_channel.l_current_height, 1))
self.assertEqual(alice_channel.l_current_height, 1, "alice has incorrect commitment height, %s vs %s"% (alice_channel.l_current_height, 1))
# Both commitment transactions should have three outputs, and one of
# them should be exactly the amount of the HTLC.
@ -227,8 +229,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
self.assertEqual(alice_channel.total_msat_received, 0, "alice satoshis received incorrect %s vs %s expected"% (alice_channel.total_msat_received, 0))
self.assertEqual(bob_channel.total_msat_received, mSatTransferred, "bob satoshis received incorrect %s vs %s expected"% (bob_channel.total_msat_received, mSatTransferred))
self.assertEqual(bob_channel.total_msat_sent, 0, "bob satoshis sent incorrect %s vs %s expected"% (bob_channel.total_msat_sent, 0))
self.assertEqual(bob_channel.current_height, 2, "bob has incorrect commitment height, %s vs %s"% (bob_channel.current_height, 2))
self.assertEqual(alice_channel.current_height, 2, "alice has incorrect commitment height, %s vs %s"% (alice_channel.current_height, 2))
self.assertEqual(bob_channel.l_current_height, 2, "bob has incorrect commitment height, %s vs %s"% (bob_channel.l_current_height, 2))
self.assertEqual(alice_channel.l_current_height, 2, "alice has incorrect commitment height, %s vs %s"% (alice_channel.l_current_height, 2))
# The logs of both sides should now be cleared since the entry adding
# the HTLC should have been removed once both sides receive the

Loading…
Cancel
Save