From 699368b0b783dd5ade7740a0bffd11ba0fd056f0 Mon Sep 17 00:00:00 2001 From: Janus Date: Wed, 10 Oct 2018 16:58:31 +0200 Subject: [PATCH] lnhtlc: save settled htlc amounts separately --- electrum/lnhtlc.py | 42 +++++++++++++---------------------- electrum/tests/test_lnhtlc.py | 18 +++++++-------- 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index eeca11d02..1c56e755d 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -63,7 +63,7 @@ class FeeUpdate: return self.rate # implicit return None -class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'settled', 'locked_in', 'htlc_id'])): +class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])): __slots__ = () def __new__(cls, *args, **kwargs): if len(args) > 0: @@ -71,7 +71,6 @@ class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', if type(args[1]) is str: args[1] = bfh(args[1]) args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()} - args[4] = {HTLCOwner(int(x)): y for x,y in args[4].items()} return super().__new__(cls, *args) if type(kwargs['payment_hash']) is str: kwargs['payment_hash'] = bfh(kwargs['payment_hash']) @@ -79,10 +78,6 @@ class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', kwargs['locked_in'] = {LOCAL: None, REMOTE: None} else: kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in']} - if 'settled' not in kwargs: - kwargs['settled'] = {LOCAL: None, REMOTE: None} - else: - kwargs['settled'] = {HTLCOwner(int(x)): y for x,y in kwargs['settled']} return super().__new__(cls, **kwargs) is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key") @@ -176,6 +171,8 @@ class HTLCStateMachine(PrintError): self.lnwatcher = None + self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])} + def set_state(self, state: str): self._state = state @@ -429,10 +426,14 @@ class HTLCStateMachine(PrintError): """ old_amount = self.htlcsum(self.gen_htlc_indices(subject, False)) + removed = [] for x in self.log[-subject]: if type(x) is not SettleHtlc: continue htlc = self.lookup_htlc(self.log[subject], x.htlc_id) - htlc.settled[subject] = self.current_height[subject] + self.settled[subject].append(htlc.amount_msat) + self.log[subject].remove(htlc) + removed.append(x) + for x in removed: self.log[-subject].remove(x) return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False)) @@ -465,15 +466,8 @@ class HTLCStateMachine(PrintError): def balance(self, subject): initial = self.local_config.initial_msat if subject == LOCAL else self.remote_config.initial_msat - for direction in (SENT, RECEIVED): - for x in self.log[-direction]: - if type(x) is not SettleHtlc: continue - htlc = self.lookup_htlc(self.log[direction], x.htlc_id) - htlc_height = htlc.settled[direction] - if htlc_height is not None and htlc_height <= self.current_height[direction]: - # so we will subtract when direction == subject. - # example subject=LOCAL, direction=SENT: we subtract - initial -= htlc.amount_msat * subject * direction + initial -= sum(self.settled[subject]) + initial += sum(self.settled[-subject]) assert initial == (self.local_state.amount_msat if subject == LOCAL else self.remote_state.amount_msat) return initial @@ -528,11 +522,10 @@ class HTLCStateMachine(PrintError): _, this_point, _ = self.points return self.make_commitment(LOCAL, this_point) - @property - def total_msat(self): - return {LOCAL: self.htlcsum(self.gen_htlc_indices(LOCAL, False, True)), REMOTE: self.htlcsum(self.gen_htlc_indices(REMOTE, False, True))} + def total_msat(self, sub): + return sum(self.settled[sub]) - def gen_htlc_indices(self, subject, only_pending, include_settled=False): + def gen_htlc_indices(self, subject, only_pending): """ only_pending: require the htlc's settlement to be pending (needs additional signatures/acks) include_settled: include settled (totally done with) htlcs @@ -543,17 +536,10 @@ class HTLCStateMachine(PrintError): for htlc in update_log: if type(htlc) is not UpdateAddHtlc: continue - height = self.current_height[-subject] locked_in = htlc.locked_in[subject] if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log): continue - - settled_cutoff = self.local_state.ctn if subject == LOCAL else self.remote_state.ctn - - if not include_settled and htlc.settled[subject] is not None and settled_cutoff >= htlc.settled[subject]: - continue - res.append(htlc) return res @@ -651,6 +637,8 @@ class HTLCStateMachine(PrintError): "remote_log": [(type(x).__name__, x) for x in remote_filtered], "local_log": [(type(x).__name__, x) for x in local_filtered], "onion_keys": {str(k): bh2u(v) for k, v in self.onion_keys.items()}, + "settled_local": self.settled[LOCAL], + "settled_remote": self.settled[REMOTE], } # htlcs number must be monotonically increasing, diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py index 91019ebbf..2a45007e7 100644 --- a/electrum/tests/test_lnhtlc.py +++ b/electrum/tests/test_lnhtlc.py @@ -201,10 +201,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): aliceSent = 0 bobSent = 0 - self.assertEqual(alice_channel.total_msat[SENT], aliceSent, "alice has incorrect milli-satoshis sent") - self.assertEqual(alice_channel.total_msat[RECEIVED], bobSent, "alice has incorrect milli-satoshis received") - self.assertEqual(bob_channel.total_msat[SENT], bobSent, "bob has incorrect milli-satoshis sent") - self.assertEqual(bob_channel.total_msat[RECEIVED], aliceSent, "bob has incorrect milli-satoshis received") + self.assertEqual(alice_channel.total_msat(SENT), aliceSent, "alice has incorrect milli-satoshis sent") + self.assertEqual(alice_channel.total_msat(RECEIVED), bobSent, "alice has incorrect milli-satoshis received") + self.assertEqual(bob_channel.total_msat(SENT), bobSent, "bob has incorrect milli-satoshis sent") + self.assertEqual(bob_channel.total_msat(RECEIVED), aliceSent, "bob has incorrect milli-satoshis received") self.assertEqual(bob_channel.local_state.ctn, 1, "bob has incorrect commitment height") self.assertEqual(alice_channel.local_state.ctn, 1, "alice has incorrect commitment height") @@ -242,10 +242,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): # should show 1 BTC received. They should also be at commitment height # two, with the revocation window extended by 1 (5). mSatTransferred = one_bitcoin_in_msat - self.assertEqual(alice_channel.total_msat[SENT], mSatTransferred, "alice satoshis sent incorrect") - self.assertEqual(alice_channel.total_msat[RECEIVED], 0, "alice satoshis received incorrect") - self.assertEqual(bob_channel.total_msat[RECEIVED], mSatTransferred, "bob satoshis received incorrect") - self.assertEqual(bob_channel.total_msat[SENT], 0, "bob satoshis sent incorrect") + self.assertEqual(alice_channel.total_msat(SENT), mSatTransferred, "alice satoshis sent incorrect") + self.assertEqual(alice_channel.total_msat(RECEIVED), 0, "alice satoshis received incorrect") + self.assertEqual(bob_channel.total_msat(RECEIVED), mSatTransferred, "bob satoshis received incorrect") + self.assertEqual(bob_channel.total_msat(SENT), 0, "bob satoshis sent incorrect") self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height") self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height") @@ -348,7 +348,7 @@ class TestLNHTLCDust(unittest.TestCase): alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex) force_state_transition(bob_channel, alice_channel) self.assertEqual(len(alice_channel.local_commitment.outputs()), 2) - self.assertEqual(alice_channel.total_msat[SENT] // 1000, htlcAmt) + self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt) def force_state_transition(chanA, chanB): chanB.receive_new_commitment(*chanA.sign_next_commitment())