Browse Source

lnchan: use NamedTuple for logs instead of dict with static keys (adds, locked_in, settles, fails)

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
39fa13b938
  1. 101
      electrum/lnchan.py
  2. 6
      electrum/tests/test_lnchan.py

101
electrum/lnchan.py

@ -26,7 +26,7 @@ from collections import namedtuple, defaultdict
import binascii
import json
from enum import Enum, auto
from typing import Optional, Dict, List, Tuple
from typing import Optional, Dict, List, Tuple, NamedTuple, Set
from copy import deepcopy
from .util import bfh, PrintError, bh2u
@ -121,6 +121,20 @@ def str_bytes_dict_from_save(x):
def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()}
class HtlcChanges(NamedTuple):
# ints are htlc ids
adds: Dict[int, UpdateAddHtlc]
settles: Set[int]
fails: Set[int]
locked_in: Set[int]
@staticmethod
def new():
"""
Since we can't use default arguments for these types (they would be shared among instances)
"""
return HtlcChanges({}, set(), set(), set())
class Channel(PrintError):
def diagnostic_name(self):
if self.name:
@ -158,18 +172,12 @@ class Channel(PrintError):
# any past commitment transaction and use that instead; until then...
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
template = lambda: {
'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
'settles': [], # List[HTLC_ID]
'fails': [], # List[HTLC_ID]
'locked_in': [], # List[HTLC_ID]
}
self.log = {LOCAL: template(), REMOTE: template()}
self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
if strname not in state: continue
for y in state[strname]:
htlc = UpdateAddHtlc(**y)
self.log[subject]['adds'][htlc.htlc_id] = htlc
self.log[subject].adds[htlc.htlc_id] = htlc
self.name = name
@ -185,6 +193,9 @@ class Channel(PrintError):
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
for sub in (LOCAL, REMOTE):
self.log[sub].locked_in.update(self.log[sub].adds.keys())
def set_state(self, state: str):
self._state = state
@ -232,7 +243,7 @@ class Channel(PrintError):
assert type(htlc) is dict
self._check_can_pay(htlc['amount_msat'])
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
self.log[LOCAL]['adds'][htlc.htlc_id] = htlc
self.log[LOCAL].adds[htlc.htlc_id] = htlc
self.print_error("add_htlc")
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@ -251,7 +262,7 @@ class Channel(PrintError):
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
f' Available at remote: {self.available_to_spend(REMOTE)},' +\
f' HTLC amount: {htlc.amount_msat}')
adds = self.log[REMOTE]['adds']
adds = self.log[REMOTE].adds
adds[htlc.htlc_id] = htlc
self.print_error("receive_htlc")
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
@ -309,11 +320,11 @@ class Channel(PrintError):
for sub in (LOCAL, REMOTE):
log = self.log[sub]
yield (sub, deepcopy(log))
for htlc_id in log['fails']:
log['adds'].pop(htlc_id)
log['fails'].clear()
for htlc_id in log.fails:
log.adds.pop(htlc_id)
log.fails.clear()
self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
self.log[subject].locked_in.update(self.log[subject].adds.keys())
def receive_new_commitment(self, sig, htlc_sigs):
"""
@ -474,11 +485,11 @@ class Channel(PrintError):
"""
old_amount = htlcsum(self.htlcs(subject, False))
for htlc_id in self.log[subject]['settles']:
adds = self.log[subject]['adds']
for htlc_id in self.log[subject].settles:
adds = self.log[subject].adds
htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat)
self.log[subject]['settles'].clear()
self.log[subject].settles.clear()
return old_amount - htlcsum(self.htlcs(subject, False))
@ -533,7 +544,7 @@ class Channel(PrintError):
pending outgoing HTLCs, is used in the UI.
"""
return self.balance(subject)\
- htlcsum(self.log[subject]['adds'].values())
- htlcsum(self.log[subject].adds.values())
def available_to_spend(self, subject):
"""
@ -541,7 +552,7 @@ class Channel(PrintError):
not be used in the UI cause it fluctuates (commit fee)
"""
return self.balance_minus_outgoing_htlcs(subject)\
- htlcsum(self.log[subject]['adds'].values())\
- htlcsum(self.log[subject].adds.values())\
- self.config[-subject].reserve_sat * 1000\
- calc_onchain_fees(
# TODO should we include a potential new htlc, when we are called from receive_htlc?
@ -601,10 +612,10 @@ class Channel(PrintError):
"""
update_log = self.log[subject]
res = []
for htlc in update_log['adds'].values():
locked_in = htlc.htlc_id in update_log['locked_in']
settled = htlc.htlc_id in update_log['settles']
failed = htlc.htlc_id in update_log['fails']
for htlc in update_log.adds.values():
locked_in = htlc.htlc_id in update_log.locked_in
settled = htlc.htlc_id in update_log.settles
failed = htlc.htlc_id in update_log.fails
if not locked_in:
continue
if only_pending == (settled or failed):
@ -617,25 +628,33 @@ class Channel(PrintError):
SettleHTLC attempts to settle an existing outstanding received HTLC.
"""
self.print_error("settle_htlc")
htlc = self.log[REMOTE]['adds'][htlc_id]
log = self.log[REMOTE]
htlc = log.adds[htlc_id]
assert htlc.payment_hash == sha256(preimage)
self.log[REMOTE]['settles'].append(htlc_id)
assert htlc_id not in log.settles
log.settles.add(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle")
htlc = self.log[LOCAL]['adds'][htlc_id]
log = self.log[LOCAL]
htlc = log.adds[htlc_id]
assert htlc.payment_hash == sha256(preimage)
self.log[LOCAL]['settles'].append(htlc_id)
assert htlc_id not in log.settles
log.settles.add(htlc_id)
# we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id):
self.print_error("fail_htlc")
self.log[REMOTE]['fails'].append(htlc_id)
log = self.log[REMOTE]
assert htlc_id not in log.fails
log.fails.add(htlc_id)
def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc")
self.log[LOCAL]['fails'].append(htlc_id)
log = self.log[LOCAL]
assert htlc_id not in log.fails
log.fails.add(htlc_id)
@property
def current_height(self):
@ -666,8 +685,8 @@ class Channel(PrintError):
removed = []
htlcs = []
log = self.log[subject]
for htlc_id, i in log['adds'].items():
locked_in = htlc_id in log['locked_in']
for i in log.adds.values():
locked_in = i.htlc_id in log.locked_in
if locked_in:
htlcs.append(i._asdict())
else:
@ -710,18 +729,26 @@ class Channel(PrintError):
def serialize(self):
namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
serialized_channel = {k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in self.to_save().items()}
serialized_channel = {}
to_save_ref = self.to_save()
for k, v in to_save_ref.items():
if isinstance(v, tuple):
serialized_channel[k] = namedtuples_to_dict(v)
else:
serialized_channel[k] = v
dumped = ChannelJsonEncoder().encode(serialized_channel)
roundtripped = json.loads(dumped)
reconstructed = Channel(roundtripped)
if reconstructed.to_save() != self.to_save():
from pprint import pformat
to_save_new = reconstructed.to_save()
if to_save_new != to_save_ref:
from pprint import PrettyPrinter
pp = PrettyPrinter(indent=168)
try:
from deepdiff import DeepDiff
except ImportError:
raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(reconstructed.to_save()) + "\n" + pformat(self.to_save()))
raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new))
else:
raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(DeepDiff(reconstructed.to_save(), self.to_save())))
raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new)))
return roundtripped
def __str__(self):

6
electrum/tests/test_lnchan.py

@ -183,7 +183,7 @@ class TestChannel(unittest.TestCase):
self.bob_pending_remote_balance = after
self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0]
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel
@ -217,6 +217,10 @@ class TestChannel(unittest.TestCase):
# forward since she's sending an outgoing HTLC.
alice_channel.receive_revocation(bobRevocation)
# test serializing with locked_in htlc
self.assertEqual(len(alice_channel.to_save()['local_log']), 1)
alice_channel.serialize()
# Alice then processes bob's signature, and since she just received
# the revocation, she expect this signature to cover everything up to
# the point where she sent her signature, including the HTLC.

Loading…
Cancel
Save