From ce54b5411e78e04f54169c01a3ba94e8bdca876d Mon Sep 17 00:00:00 2001 From: SomberNight Date: Wed, 14 Aug 2019 21:35:37 +0200 Subject: [PATCH] lnhtlc: htlcs_by_direction now returns dict keyed by htlc_id --- electrum/lnchannel.py | 11 ++++++----- electrum/lnhtlc.py | 22 +++++++++++----------- electrum/tests/test_lnchannel.py | 8 ++++---- electrum/tests/test_lnhtlc.py | 2 +- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 4e7148e70..396235016 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -200,7 +200,8 @@ class Channel(Logger): raise PaymentFailure(f'Not enough local balance. Have: {self.available_to_spend(LOCAL)}, Need: {amount_msat}') if len(self.hm.htlcs(LOCAL)) + 1 > self.config[REMOTE].max_accepted_htlcs: raise PaymentFailure('Too many HTLCs already in channel') - current_htlc_sum = htlcsum(self.hm.htlcs_by_direction(LOCAL, SENT)) + htlcsum(self.hm.htlcs_by_direction(LOCAL, RECEIVED)) + current_htlc_sum = (htlcsum(self.hm.htlcs_by_direction(LOCAL, SENT).values()) + + htlcsum(self.hm.htlcs_by_direction(LOCAL, RECEIVED).values())) if current_htlc_sum + amount_msat > self.config[REMOTE].max_htlc_value_in_flight_msat: raise PaymentFailure(f'HTLC value sum (sum of pending htlcs: {current_htlc_sum/1000} sat plus new htlc: {amount_msat/1000} sat) would exceed max allowed: {self.config[REMOTE].max_htlc_value_in_flight_msat/1000} sat') if amount_msat < self.config[REMOTE].htlc_minimum_msat: @@ -451,7 +452,7 @@ class Channel(Logger): assert type(whose) is HTLCOwner ctn = self.get_next_ctn(ctx_owner) return self.balance(whose, ctx_owner=ctx_owner, ctn=ctn)\ - - htlcsum(self.hm.htlcs_by_direction(ctx_owner, SENT, ctn)) + - htlcsum(self.hm.htlcs_by_direction(ctx_owner, SENT, ctn).values()) def available_to_spend(self, subject): """ @@ -484,7 +485,7 @@ class Channel(Logger): weight = HTLC_SUCCESS_WEIGHT else: weight = HTLC_TIMEOUT_WEIGHT - htlcs = self.hm.htlcs_by_direction(subject, direction, ctn=ctn) + htlcs = self.hm.htlcs_by_direction(subject, direction, ctn=ctn).values() fee_for_htlc = lambda htlc: htlc.amount_msat // 1000 - (weight * feerate // 1000) return list(filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs)) @@ -647,8 +648,8 @@ class Channel(Logger): other = REMOTE if LOCAL == subject else LOCAL local_msat = self.balance(subject, ctx_owner=subject, ctn=ctn) remote_msat = self.balance(other, ctx_owner=subject, ctn=ctn) - received_htlcs = self.hm.htlcs_by_direction(subject, SENT if subject == LOCAL else RECEIVED, ctn) - sent_htlcs = self.hm.htlcs_by_direction(subject, RECEIVED if subject == LOCAL else SENT, ctn) + received_htlcs = self.hm.htlcs_by_direction(subject, SENT if subject == LOCAL else RECEIVED, ctn).values() + sent_htlcs = self.hm.htlcs_by_direction(subject, RECEIVED if subject == LOCAL else SENT, ctn).values() if subject != LOCAL: remote_msat -= htlcsum(received_htlcs) local_msat -= htlcsum(sent_htlcs) diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index 2ea4557d6..35562ab66 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -25,7 +25,7 @@ class HTLCManager: log = {(HTLCOwner(int(k)) if k in ("-1", "1") else k): v for k, v in deepcopy(log).items()} for sub in (LOCAL, REMOTE): - log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()} + log[sub]['adds'] = {int(htlc_id): UpdateAddHtlc(*htlc) for htlc_id, htlc in log[sub]['adds'].items()} coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()} # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()} @@ -222,9 +222,9 @@ class HTLCManager: ##### Queries re HTLCs: def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, - ctn: int = None) -> Sequence[UpdateAddHtlc]: - """Return the list of received or sent (depending on direction) HTLCs - in subject's ctx at ctn. + ctn: int = None) -> Dict[int, UpdateAddHtlc]: + """Return the dict of received or sent (depending on direction) HTLCs + in subject's ctx at ctn, keyed by htlc_id. direction is relative to subject! """ @@ -232,19 +232,19 @@ class HTLCManager: assert type(direction) is Direction if ctn is None: ctn = self.ctn_oldest_unrevoked(subject) - l = [] + d = {} # subject's ctx # party is the proposer of the HTLCs party = subject if direction == SENT else subject.inverted() + settles = self.log[party]['settles'] + fails = self.log[party]['fails'] for htlc_id, ctns in self.log[party]['locked_in'].items(): if ctns[subject] is not None and ctns[subject] <= ctn: - settles = self.log[party]['settles'] - fails = self.log[party]['fails'] not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn not_failed = htlc_id not in fails or fails[htlc_id][subject] is None or fails[htlc_id][subject] > ctn if not_settled and not_failed: - l.append(self.log[party]['adds'][htlc_id]) - return l + d[htlc_id] = self.log[party]['adds'][htlc_id] + return d def htlcs(self, subject: HTLCOwner, ctn: int = None) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: """Return the list of HTLCs in subject's ctx at ctn.""" @@ -252,8 +252,8 @@ class HTLCManager: if ctn is None: ctn = self.ctn_oldest_unrevoked(subject) l = [] - l += [(SENT, x) for x in self.htlcs_by_direction(subject, SENT, ctn)] - l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)] + l += [(SENT, x) for x in self.htlcs_by_direction(subject, SENT, ctn).values()] + l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn).values()] return l def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py index 04252c2f9..e24d19a0c 100644 --- a/electrum/tests/test_lnchannel.py +++ b/electrum/tests/test_lnchannel.py @@ -202,7 +202,7 @@ class TestChannel(unittest.TestCase): # update log. Then Alice sends this wire message over to Bob who adds # this htlc to his remote state update log. self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict).htlc_id - self.assertNotEqual(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1), []) + self.assertNotEqual(list(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1).values()), []) before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE) beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL) @@ -414,7 +414,7 @@ class TestChannel(unittest.TestCase): bobSig2, bobHtlcSigs2 = bob_channel.sign_next_commitment() self.assertEqual(len(bobHtlcSigs2), 0) - self.assertEqual(alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED), [htlc]) + self.assertEqual(list(alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED).values()), [htlc]) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, alice_channel.get_oldest_unrevoked_ctn(REMOTE)), [htlc]) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc]) @@ -693,9 +693,9 @@ class TestChanReserve(unittest.TestCase): force_state_transition(self.alice_channel, self.bob_channel) aliceSelfBalance = self.alice_channel.balance(LOCAL)\ - - lnchannel.htlcsum(self.alice_channel.hm.htlcs_by_direction(LOCAL, SENT)) + - lnchannel.htlcsum(self.alice_channel.hm.htlcs_by_direction(LOCAL, SENT).values()) bobBalance = self.bob_channel.balance(REMOTE)\ - - lnchannel.htlcsum(self.alice_channel.hm.htlcs_by_direction(REMOTE, SENT)) + - lnchannel.htlcsum(self.alice_channel.hm.htlcs_by_direction(REMOTE, SENT).values()) self.assertEqual(aliceSelfBalance, one_bitcoin_in_msat*4.5) self.assertEqual(bobBalance, one_bitcoin_in_msat*5) # Now let Bob try to add an HTLC. This should fail, since it will diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py index c36587c5d..f897e4df2 100644 --- a/electrum/tests/test_lnhtlc.py +++ b/electrum/tests/test_lnhtlc.py @@ -82,7 +82,7 @@ class TestHTLCManager(unittest.TestCase): else: B.send_fail(0) A.recv_fail(0) - self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)]) + self.assertEqual(list(A.htlcs_by_direction(REMOTE, RECEIVED).values()), [H('A', 0)]) self.assertNotEqual(A.get_htlcs_in_latest_ctx(LOCAL), []) self.assertNotEqual(B.get_htlcs_in_latest_ctx(REMOTE), [])