Browse Source

lnhtlc: remove lookup_htlc, use heterogeneously typed lists

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
1d4c113a35
  1. 1
      electrum/lnbase.py
  2. 91
      electrum/lnhtlc.py
  3. 2
      electrum/tests/test_lnhtlc.py

1
electrum/lnbase.py

@ -1178,7 +1178,6 @@ class Peer(PrintError):
chan = self.channels[update_fulfill_htlc_msg["channel_id"]]
preimage = update_fulfill_htlc_msg["payment_preimage"]
htlc_id = int.from_bytes(update_fulfill_htlc_msg["id"], "big")
htlc = chan.lookup_htlc(chan.log[LOCAL], htlc_id)
chan.receive_htlc_settle(preimage, htlc_id)
await self.receive_commitment(chan)
self.revoke(chan)

91
electrum/lnhtlc.py

@ -23,8 +23,6 @@ from .transaction import Transaction, TxOutput, construct_witness
from .simple_config import SimpleConfig, FEERATE_FALLBACK_STATIC_FEE
FailHtlc = namedtuple("FailHtlc", ["htlc_id"])
SettleHtlc = namedtuple("SettleHtlc", ["htlc_id"])
RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])
class FeeUpdateProgress(Enum):
@ -100,14 +98,6 @@ def typeWrap(k, v, local):
return v
class HTLCStateMachine(PrintError):
def lookup_htlc(self, log, htlc_id):
assert type(htlc_id) is int
for htlc in log:
if type(htlc) is not UpdateAddHtlc: continue
if htlc.htlc_id == htlc_id:
return htlc
assert False, self.diagnostic_name() + ": htlc_id {} not found in {}".format(htlc_id, log)
def diagnostic_name(self):
return str(self.name)
@ -146,18 +136,13 @@ class HTLCStateMachine(PrintError):
# any past commitment transaction and use that instead; until then...
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
self.log = {LOCAL: [], REMOTE: []}
template = lambda: {'adds': {}, 'settles': []}
self.log = {LOCAL: template(), REMOTE: template()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
if strname not in state: continue
for typ,y in state[strname]:
if typ == "UpdateAddHtlc":
self.log[subject].append(UpdateAddHtlc(*decodeAll(y)))
elif typ == "SettleHtlc":
self.log[subject].append(SettleHtlc(*decodeAll(y)))
elif typ == "FailHtlc":
self.log[subject].append(FailHtlc(*decodeAll(y)))
else:
assert False
for y in state[strname]:
htlc = UpdateAddHtlc(*decodeAll(y))
self.log[subject]['adds'][htlc.htlc_id] = htlc
self.name = name
@ -197,7 +182,7 @@ class HTLCStateMachine(PrintError):
"""
assert type(htlc) is dict
htlc = UpdateAddHtlc(**htlc, htlc_id=self.local_state.next_htlc_id)
self.log[LOCAL].append(htlc)
self.log[LOCAL]['adds'][htlc.htlc_id] = htlc
self.print_error("add_htlc")
self.local_state=self.local_state._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@ -210,7 +195,7 @@ class HTLCStateMachine(PrintError):
"""
assert type(htlc) is dict
htlc = UpdateAddHtlc(**htlc, htlc_id = self.remote_state.next_htlc_id)
self.log[REMOTE].append(htlc)
self.log[REMOTE]['adds'][htlc.htlc_id] = htlc
self.print_error("receive_htlc")
self.remote_state=self.remote_state._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@ -228,9 +213,8 @@ class HTLCStateMachine(PrintError):
any). The HTLC signatures are sorted according to the BIP 69 order of the
HTLC's on the commitment transaction.
"""
for htlc in self.log[LOCAL]:
if not type(htlc) is UpdateAddHtlc: continue
if htlc.locked_in[LOCAL] is None and FailHtlc(htlc.htlc_id) not in self.log[REMOTE]:
for htlc in self.log[LOCAL]['adds'].values():
if htlc.locked_in[LOCAL] is None:
htlc.locked_in[LOCAL] = self.local_state.ctn
self.print_error("sign_next_commitment")
@ -279,9 +263,8 @@ class HTLCStateMachine(PrintError):
"""
self.print_error("receive_new_commitment")
for htlc in self.log[REMOTE]:
if not type(htlc) is UpdateAddHtlc: continue
if htlc.locked_in[REMOTE] is None and FailHtlc(htlc.htlc_id) not in self.log[LOCAL]:
for htlc in self.log[REMOTE]['adds'].values():
if htlc.locked_in[REMOTE] is None:
htlc.locked_in[REMOTE] = self.remote_state.ctn
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
@ -422,18 +405,15 @@ class HTLCStateMachine(PrintError):
def mark_settled(subject):
"""
find settled htlcs for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs
find pending settlements for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs
"""
old_amount = self.htlcsum(self.gen_htlc_indices(subject, False))
removed = []
for x in self.log[-subject]:
if type(x) is not SettleHtlc: continue
htlc = self.lookup_htlc(self.log[subject], x.htlc_id)
for htlc_id in self.log[-subject]['settles']:
adds = self.log[subject]['adds']
htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat)
self.log[subject].remove(htlc)
removed.append(x)
for x in removed: self.log[-subject].remove(x)
self.log[-subject]['settles'].clear()
return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False))
@ -533,12 +513,10 @@ class HTLCStateMachine(PrintError):
update_log = self.log[subject]
other_log = self.log[-subject]
res = []
for htlc in update_log:
if type(htlc) is not UpdateAddHtlc:
continue
for htlc in update_log['adds'].values():
locked_in = htlc.locked_in[subject]
if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log):
if locked_in is None or only_pending == (htlc.htlc_id in other_log['settles']):
continue
res.append(htlc)
return res
@ -558,23 +536,19 @@ class HTLCStateMachine(PrintError):
SettleHTLC attempts to settle an existing outstanding received HTLC.
"""
self.print_error("settle_htlc")
htlc = self.lookup_htlc(self.log[REMOTE], htlc_id)
htlc = self.log[REMOTE]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
self.log[LOCAL].append(SettleHtlc(htlc_id))
self.log[LOCAL]['settles'].append(htlc_id)
def receive_htlc_settle(self, preimage, htlc_index):
self.print_error("receive_htlc_settle")
htlc = self.lookup_htlc(self.log[LOCAL], htlc_index)
htlc = self.log[LOCAL]['adds'][htlc_index]
assert htlc.payment_hash == sha256(preimage)
assert len([x for x in self.log[LOCAL] if x.htlc_id == htlc_index and type(x) is UpdateAddHtlc]) == 1, (self.log[LOCAL], htlc_index)
self.log[REMOTE].append(SettleHtlc(htlc_index))
self.log[REMOTE]['settles'].append(htlc_index)
def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc")
htlc = self.lookup_htlc(self.log[LOCAL], htlc_id)
htlc.locked_in[LOCAL] = None
htlc.locked_in[REMOTE] = None
self.log[REMOTE].append(FailHtlc(htlc_id))
self.log[LOCAL]['adds'].pop(htlc_id)
@property
def current_height(self):
@ -604,14 +578,9 @@ class HTLCStateMachine(PrintError):
"""
removed = []
htlcs = []
for i in self.log[subject]:
if type(i) is not UpdateAddHtlc:
htlcs.append(i)
continue
settled = SettleHtlc(i.htlc_id) in self.log[-subject]
failed = FailHtlc(i.htlc_id) in self.log[-subject]
for i in self.log[subject]['adds'].values():
locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None
if locked_in or settled or failed:
if locked_in:
htlcs.append(i)
else:
removed.append(i.htlc_id)
@ -634,8 +603,8 @@ class HTLCStateMachine(PrintError):
"funding_outpoint": self.funding_outpoint,
"node_id": self.node_id,
"remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked),
"remote_log": [(type(x).__name__, x) for x in remote_filtered],
"local_log": [(type(x).__name__, x) for x in local_filtered],
"remote_log": remote_filtered,
"local_log": local_filtered,
"onion_keys": {str(k): bh2u(v) for k, v in self.onion_keys.items()},
"settled_local": self.settled[LOCAL],
"settled_remote": self.settled[REMOTE],
@ -662,12 +631,6 @@ class HTLCStateMachine(PrintError):
return binascii.hexlify(o).decode("ascii")
if isinstance(o, RevocationStore):
return o.serialize()
if isinstance(o, SettleHtlc):
return json.dumps(('SettleHtlc', namedtuples_to_dict(o)))
if isinstance(o, FailHtlc):
return json.dumps(('FailHtlc', namedtuples_to_dict(o)))
if isinstance(o, UpdateAddHtlc):
return json.dumps(('UpdateAddHtlc', namedtuples_to_dict(o)))
return super(MyJsonEncoder, self)
dumped = MyJsonEncoder().encode(serialized_channel)
roundtripped = json.loads(dumped)

2
electrum/tests/test_lnhtlc.py

@ -147,7 +147,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc)
self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc)
self.htlc = self.bob_channel.log[lnutil.REMOTE][0]
self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0]
def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel

Loading…
Cancel
Save