From 1d4c113a3524cf012d5a7c40acdc7a61e16b19ed Mon Sep 17 00:00:00 2001 From: Janus Date: Wed, 10 Oct 2018 19:52:46 +0200 Subject: [PATCH] lnhtlc: remove lookup_htlc, use heterogeneously typed lists --- electrum/lnbase.py | 1 - electrum/lnhtlc.py | 91 +++++++++++------------------------ electrum/tests/test_lnhtlc.py | 2 +- 3 files changed, 28 insertions(+), 66 deletions(-) diff --git a/electrum/lnbase.py b/electrum/lnbase.py index bdd3c8398..53062a16e 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -1178,7 +1178,6 @@ class Peer(PrintError): chan = self.channels[update_fulfill_htlc_msg["channel_id"]] preimage = update_fulfill_htlc_msg["payment_preimage"] htlc_id = int.from_bytes(update_fulfill_htlc_msg["id"], "big") - htlc = chan.lookup_htlc(chan.log[LOCAL], htlc_id) chan.receive_htlc_settle(preimage, htlc_id) await self.receive_commitment(chan) self.revoke(chan) diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index 1c56e755d..e845b1e9f 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -23,8 +23,6 @@ from .transaction import Transaction, TxOutput, construct_witness from .simple_config import SimpleConfig, FEERATE_FALLBACK_STATIC_FEE -FailHtlc = namedtuple("FailHtlc", ["htlc_id"]) -SettleHtlc = namedtuple("SettleHtlc", ["htlc_id"]) RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"]) class FeeUpdateProgress(Enum): @@ -100,14 +98,6 @@ def typeWrap(k, v, local): return v class HTLCStateMachine(PrintError): - def lookup_htlc(self, log, htlc_id): - assert type(htlc_id) is int - for htlc in log: - if type(htlc) is not UpdateAddHtlc: continue - if htlc.htlc_id == htlc_id: - return htlc - assert False, self.diagnostic_name() + ": htlc_id {} not found in {}".format(htlc_id, log) - def diagnostic_name(self): return str(self.name) @@ -146,18 +136,13 @@ class HTLCStateMachine(PrintError): # any past commitment transaction and use that instead; until then... self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"]) - self.log = {LOCAL: [], REMOTE: []} + template = lambda: {'adds': {}, 'settles': []} + self.log = {LOCAL: template(), REMOTE: template()} for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]: if strname not in state: continue - for typ,y in state[strname]: - if typ == "UpdateAddHtlc": - self.log[subject].append(UpdateAddHtlc(*decodeAll(y))) - elif typ == "SettleHtlc": - self.log[subject].append(SettleHtlc(*decodeAll(y))) - elif typ == "FailHtlc": - self.log[subject].append(FailHtlc(*decodeAll(y))) - else: - assert False + for y in state[strname]: + htlc = UpdateAddHtlc(*decodeAll(y)) + self.log[subject]['adds'][htlc.htlc_id] = htlc self.name = name @@ -197,7 +182,7 @@ class HTLCStateMachine(PrintError): """ assert type(htlc) is dict htlc = UpdateAddHtlc(**htlc, htlc_id=self.local_state.next_htlc_id) - self.log[LOCAL].append(htlc) + self.log[LOCAL]['adds'][htlc.htlc_id] = htlc self.print_error("add_htlc") self.local_state=self.local_state._replace(next_htlc_id=htlc.htlc_id + 1) return htlc.htlc_id @@ -210,7 +195,7 @@ class HTLCStateMachine(PrintError): """ assert type(htlc) is dict htlc = UpdateAddHtlc(**htlc, htlc_id = self.remote_state.next_htlc_id) - self.log[REMOTE].append(htlc) + self.log[REMOTE]['adds'][htlc.htlc_id] = htlc self.print_error("receive_htlc") self.remote_state=self.remote_state._replace(next_htlc_id=htlc.htlc_id + 1) return htlc.htlc_id @@ -228,9 +213,8 @@ class HTLCStateMachine(PrintError): any). The HTLC signatures are sorted according to the BIP 69 order of the HTLC's on the commitment transaction. """ - for htlc in self.log[LOCAL]: - if not type(htlc) is UpdateAddHtlc: continue - if htlc.locked_in[LOCAL] is None and FailHtlc(htlc.htlc_id) not in self.log[REMOTE]: + for htlc in self.log[LOCAL]['adds'].values(): + if htlc.locked_in[LOCAL] is None: htlc.locked_in[LOCAL] = self.local_state.ctn self.print_error("sign_next_commitment") @@ -279,9 +263,8 @@ class HTLCStateMachine(PrintError): """ self.print_error("receive_new_commitment") - for htlc in self.log[REMOTE]: - if not type(htlc) is UpdateAddHtlc: continue - if htlc.locked_in[REMOTE] is None and FailHtlc(htlc.htlc_id) not in self.log[LOCAL]: + for htlc in self.log[REMOTE]['adds'].values(): + if htlc.locked_in[REMOTE] is None: htlc.locked_in[REMOTE] = self.remote_state.ctn assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes @@ -422,18 +405,15 @@ class HTLCStateMachine(PrintError): def mark_settled(subject): """ - find settled htlcs for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs + find pending settlements for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs """ old_amount = self.htlcsum(self.gen_htlc_indices(subject, False)) - removed = [] - for x in self.log[-subject]: - if type(x) is not SettleHtlc: continue - htlc = self.lookup_htlc(self.log[subject], x.htlc_id) + for htlc_id in self.log[-subject]['settles']: + adds = self.log[subject]['adds'] + htlc = adds.pop(htlc_id) self.settled[subject].append(htlc.amount_msat) - self.log[subject].remove(htlc) - removed.append(x) - for x in removed: self.log[-subject].remove(x) + self.log[-subject]['settles'].clear() return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False)) @@ -533,12 +513,10 @@ class HTLCStateMachine(PrintError): update_log = self.log[subject] other_log = self.log[-subject] res = [] - for htlc in update_log: - if type(htlc) is not UpdateAddHtlc: - continue + for htlc in update_log['adds'].values(): locked_in = htlc.locked_in[subject] - if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log): + if locked_in is None or only_pending == (htlc.htlc_id in other_log['settles']): continue res.append(htlc) return res @@ -558,23 +536,19 @@ class HTLCStateMachine(PrintError): SettleHTLC attempts to settle an existing outstanding received HTLC. """ self.print_error("settle_htlc") - htlc = self.lookup_htlc(self.log[REMOTE], htlc_id) + htlc = self.log[REMOTE]['adds'][htlc_id] assert htlc.payment_hash == sha256(preimage) - self.log[LOCAL].append(SettleHtlc(htlc_id)) + self.log[LOCAL]['settles'].append(htlc_id) def receive_htlc_settle(self, preimage, htlc_index): self.print_error("receive_htlc_settle") - htlc = self.lookup_htlc(self.log[LOCAL], htlc_index) + htlc = self.log[LOCAL]['adds'][htlc_index] assert htlc.payment_hash == sha256(preimage) - assert len([x for x in self.log[LOCAL] if x.htlc_id == htlc_index and type(x) is UpdateAddHtlc]) == 1, (self.log[LOCAL], htlc_index) - self.log[REMOTE].append(SettleHtlc(htlc_index)) + self.log[REMOTE]['settles'].append(htlc_index) def receive_fail_htlc(self, htlc_id): self.print_error("receive_fail_htlc") - htlc = self.lookup_htlc(self.log[LOCAL], htlc_id) - htlc.locked_in[LOCAL] = None - htlc.locked_in[REMOTE] = None - self.log[REMOTE].append(FailHtlc(htlc_id)) + self.log[LOCAL]['adds'].pop(htlc_id) @property def current_height(self): @@ -604,14 +578,9 @@ class HTLCStateMachine(PrintError): """ removed = [] htlcs = [] - for i in self.log[subject]: - if type(i) is not UpdateAddHtlc: - htlcs.append(i) - continue - settled = SettleHtlc(i.htlc_id) in self.log[-subject] - failed = FailHtlc(i.htlc_id) in self.log[-subject] + for i in self.log[subject]['adds'].values(): locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None - if locked_in or settled or failed: + if locked_in: htlcs.append(i) else: removed.append(i.htlc_id) @@ -634,8 +603,8 @@ class HTLCStateMachine(PrintError): "funding_outpoint": self.funding_outpoint, "node_id": self.node_id, "remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked), - "remote_log": [(type(x).__name__, x) for x in remote_filtered], - "local_log": [(type(x).__name__, x) for x in local_filtered], + "remote_log": remote_filtered, + "local_log": local_filtered, "onion_keys": {str(k): bh2u(v) for k, v in self.onion_keys.items()}, "settled_local": self.settled[LOCAL], "settled_remote": self.settled[REMOTE], @@ -662,12 +631,6 @@ class HTLCStateMachine(PrintError): return binascii.hexlify(o).decode("ascii") if isinstance(o, RevocationStore): return o.serialize() - if isinstance(o, SettleHtlc): - return json.dumps(('SettleHtlc', namedtuples_to_dict(o))) - if isinstance(o, FailHtlc): - return json.dumps(('FailHtlc', namedtuples_to_dict(o))) - if isinstance(o, UpdateAddHtlc): - return json.dumps(('UpdateAddHtlc', namedtuples_to_dict(o))) return super(MyJsonEncoder, self) dumped = MyJsonEncoder().encode(serialized_channel) roundtripped = json.loads(dumped) diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py index 2a45007e7..8c0e7a521 100644 --- a/electrum/tests/test_lnhtlc.py +++ b/electrum/tests/test_lnhtlc.py @@ -147,7 +147,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc) self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc) - self.htlc = self.bob_channel.log[lnutil.REMOTE][0] + self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0] def test_SimpleAddSettleWorkflow(self): alice_channel, bob_channel = self.alice_channel, self.bob_channel