Browse Source

lnhtlc: save settled htlc amounts separately

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
699368b0b7
  1. 42
      electrum/lnhtlc.py
  2. 18
      electrum/tests/test_lnhtlc.py

42
electrum/lnhtlc.py

@ -63,7 +63,7 @@ class FeeUpdate:
return self.rate return self.rate
# implicit return None # implicit return None
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'settled', 'locked_in', 'htlc_id'])): class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])):
__slots__ = () __slots__ = ()
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if len(args) > 0: if len(args) > 0:
@ -71,7 +71,6 @@ class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash',
if type(args[1]) is str: if type(args[1]) is str:
args[1] = bfh(args[1]) args[1] = bfh(args[1])
args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()} args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()}
args[4] = {HTLCOwner(int(x)): y for x,y in args[4].items()}
return super().__new__(cls, *args) return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str: if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash']) kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
@ -79,10 +78,6 @@ class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash',
kwargs['locked_in'] = {LOCAL: None, REMOTE: None} kwargs['locked_in'] = {LOCAL: None, REMOTE: None}
else: else:
kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in']} kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in']}
if 'settled' not in kwargs:
kwargs['settled'] = {LOCAL: None, REMOTE: None}
else:
kwargs['settled'] = {HTLCOwner(int(x)): y for x,y in kwargs['settled']}
return super().__new__(cls, **kwargs) return super().__new__(cls, **kwargs)
is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key") is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key")
@ -176,6 +171,8 @@ class HTLCStateMachine(PrintError):
self.lnwatcher = None self.lnwatcher = None
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
def set_state(self, state: str): def set_state(self, state: str):
self._state = state self._state = state
@ -429,10 +426,14 @@ class HTLCStateMachine(PrintError):
""" """
old_amount = self.htlcsum(self.gen_htlc_indices(subject, False)) old_amount = self.htlcsum(self.gen_htlc_indices(subject, False))
removed = []
for x in self.log[-subject]: for x in self.log[-subject]:
if type(x) is not SettleHtlc: continue if type(x) is not SettleHtlc: continue
htlc = self.lookup_htlc(self.log[subject], x.htlc_id) htlc = self.lookup_htlc(self.log[subject], x.htlc_id)
htlc.settled[subject] = self.current_height[subject] self.settled[subject].append(htlc.amount_msat)
self.log[subject].remove(htlc)
removed.append(x)
for x in removed: self.log[-subject].remove(x)
return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False)) return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False))
@ -465,15 +466,8 @@ class HTLCStateMachine(PrintError):
def balance(self, subject): def balance(self, subject):
initial = self.local_config.initial_msat if subject == LOCAL else self.remote_config.initial_msat initial = self.local_config.initial_msat if subject == LOCAL else self.remote_config.initial_msat
for direction in (SENT, RECEIVED): initial -= sum(self.settled[subject])
for x in self.log[-direction]: initial += sum(self.settled[-subject])
if type(x) is not SettleHtlc: continue
htlc = self.lookup_htlc(self.log[direction], x.htlc_id)
htlc_height = htlc.settled[direction]
if htlc_height is not None and htlc_height <= self.current_height[direction]:
# so we will subtract when direction == subject.
# example subject=LOCAL, direction=SENT: we subtract
initial -= htlc.amount_msat * subject * direction
assert initial == (self.local_state.amount_msat if subject == LOCAL else self.remote_state.amount_msat) assert initial == (self.local_state.amount_msat if subject == LOCAL else self.remote_state.amount_msat)
return initial return initial
@ -528,11 +522,10 @@ class HTLCStateMachine(PrintError):
_, this_point, _ = self.points _, this_point, _ = self.points
return self.make_commitment(LOCAL, this_point) return self.make_commitment(LOCAL, this_point)
@property def total_msat(self, sub):
def total_msat(self): return sum(self.settled[sub])
return {LOCAL: self.htlcsum(self.gen_htlc_indices(LOCAL, False, True)), REMOTE: self.htlcsum(self.gen_htlc_indices(REMOTE, False, True))}
def gen_htlc_indices(self, subject, only_pending, include_settled=False): def gen_htlc_indices(self, subject, only_pending):
""" """
only_pending: require the htlc's settlement to be pending (needs additional signatures/acks) only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
include_settled: include settled (totally done with) htlcs include_settled: include settled (totally done with) htlcs
@ -543,17 +536,10 @@ class HTLCStateMachine(PrintError):
for htlc in update_log: for htlc in update_log:
if type(htlc) is not UpdateAddHtlc: if type(htlc) is not UpdateAddHtlc:
continue continue
height = self.current_height[-subject]
locked_in = htlc.locked_in[subject] locked_in = htlc.locked_in[subject]
if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log): if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log):
continue continue
settled_cutoff = self.local_state.ctn if subject == LOCAL else self.remote_state.ctn
if not include_settled and htlc.settled[subject] is not None and settled_cutoff >= htlc.settled[subject]:
continue
res.append(htlc) res.append(htlc)
return res return res
@ -651,6 +637,8 @@ class HTLCStateMachine(PrintError):
"remote_log": [(type(x).__name__, x) for x in remote_filtered], "remote_log": [(type(x).__name__, x) for x in remote_filtered],
"local_log": [(type(x).__name__, x) for x in local_filtered], "local_log": [(type(x).__name__, x) for x in local_filtered],
"onion_keys": {str(k): bh2u(v) for k, v in self.onion_keys.items()}, "onion_keys": {str(k): bh2u(v) for k, v in self.onion_keys.items()},
"settled_local": self.settled[LOCAL],
"settled_remote": self.settled[REMOTE],
} }
# htlcs number must be monotonically increasing, # htlcs number must be monotonically increasing,

18
electrum/tests/test_lnhtlc.py

@ -201,10 +201,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
aliceSent = 0 aliceSent = 0
bobSent = 0 bobSent = 0
self.assertEqual(alice_channel.total_msat[SENT], aliceSent, "alice has incorrect milli-satoshis sent") self.assertEqual(alice_channel.total_msat(SENT), aliceSent, "alice has incorrect milli-satoshis sent")
self.assertEqual(alice_channel.total_msat[RECEIVED], bobSent, "alice has incorrect milli-satoshis received") self.assertEqual(alice_channel.total_msat(RECEIVED), bobSent, "alice has incorrect milli-satoshis received")
self.assertEqual(bob_channel.total_msat[SENT], bobSent, "bob has incorrect milli-satoshis sent") self.assertEqual(bob_channel.total_msat(SENT), bobSent, "bob has incorrect milli-satoshis sent")
self.assertEqual(bob_channel.total_msat[RECEIVED], aliceSent, "bob has incorrect milli-satoshis received") self.assertEqual(bob_channel.total_msat(RECEIVED), aliceSent, "bob has incorrect milli-satoshis received")
self.assertEqual(bob_channel.local_state.ctn, 1, "bob has incorrect commitment height") self.assertEqual(bob_channel.local_state.ctn, 1, "bob has incorrect commitment height")
self.assertEqual(alice_channel.local_state.ctn, 1, "alice has incorrect commitment height") self.assertEqual(alice_channel.local_state.ctn, 1, "alice has incorrect commitment height")
@ -242,10 +242,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
# should show 1 BTC received. They should also be at commitment height # should show 1 BTC received. They should also be at commitment height
# two, with the revocation window extended by 1 (5). # two, with the revocation window extended by 1 (5).
mSatTransferred = one_bitcoin_in_msat mSatTransferred = one_bitcoin_in_msat
self.assertEqual(alice_channel.total_msat[SENT], mSatTransferred, "alice satoshis sent incorrect") self.assertEqual(alice_channel.total_msat(SENT), mSatTransferred, "alice satoshis sent incorrect")
self.assertEqual(alice_channel.total_msat[RECEIVED], 0, "alice satoshis received incorrect") self.assertEqual(alice_channel.total_msat(RECEIVED), 0, "alice satoshis received incorrect")
self.assertEqual(bob_channel.total_msat[RECEIVED], mSatTransferred, "bob satoshis received incorrect") self.assertEqual(bob_channel.total_msat(RECEIVED), mSatTransferred, "bob satoshis received incorrect")
self.assertEqual(bob_channel.total_msat[SENT], 0, "bob satoshis sent incorrect") self.assertEqual(bob_channel.total_msat(SENT), 0, "bob satoshis sent incorrect")
self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height") self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height")
self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height") self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height")
@ -348,7 +348,7 @@ class TestLNHTLCDust(unittest.TestCase):
alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex) alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex)
force_state_transition(bob_channel, alice_channel) force_state_transition(bob_channel, alice_channel)
self.assertEqual(len(alice_channel.local_commitment.outputs()), 2) self.assertEqual(len(alice_channel.local_commitment.outputs()), 2)
self.assertEqual(alice_channel.total_msat[SENT] // 1000, htlcAmt) self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt)
def force_state_transition(chanA, chanB): def force_state_transition(chanA, chanB):
chanB.receive_new_commitment(*chanA.sign_next_commitment()) chanB.receive_new_commitment(*chanA.sign_next_commitment())

Loading…
Cancel
Save