Browse Source

lnchan: use NamedTuple for logs instead of dict with static keys (adds, locked_in, settles, fails)

regtest_lnd
Janus 6 years ago
committed by SomberNight
parent
commit
64f6a0c43f
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 101
      electrum/lnchan.py
  2. 6
      electrum/tests/test_lnchan.py

101
electrum/lnchan.py

@ -26,7 +26,7 @@ from collections import namedtuple, defaultdict
import binascii import binascii
import json import json
from enum import Enum, auto from enum import Enum, auto
from typing import Optional, Dict, List, Tuple from typing import Optional, Dict, List, Tuple, NamedTuple, Set
from copy import deepcopy from copy import deepcopy
from .util import bfh, PrintError, bh2u from .util import bfh, PrintError, bh2u
@ -121,6 +121,20 @@ def str_bytes_dict_from_save(x):
def str_bytes_dict_to_save(x): def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()} return {str(k): bh2u(v) for k, v in x.items()}
class HtlcChanges(NamedTuple):
# ints are htlc ids
adds: Dict[int, UpdateAddHtlc]
settles: Set[int]
fails: Set[int]
locked_in: Set[int]
@staticmethod
def new():
"""
Since we can't use default arguments for these types (they would be shared among instances)
"""
return HtlcChanges({}, set(), set(), set())
class Channel(PrintError): class Channel(PrintError):
def diagnostic_name(self): def diagnostic_name(self):
if self.name: if self.name:
@ -158,18 +172,12 @@ class Channel(PrintError):
# any past commitment transaction and use that instead; until then... # any past commitment transaction and use that instead; until then...
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"]) self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
template = lambda: { self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()}
'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
'settles': [], # List[HTLC_ID]
'fails': [], # List[HTLC_ID]
'locked_in': [], # List[HTLC_ID]
}
self.log = {LOCAL: template(), REMOTE: template()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]: for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
if strname not in state: continue if strname not in state: continue
for y in state[strname]: for y in state[strname]:
htlc = UpdateAddHtlc(**y) htlc = UpdateAddHtlc(**y)
self.log[subject]['adds'][htlc.htlc_id] = htlc self.log[subject].adds[htlc.htlc_id] = htlc
self.name = name self.name = name
@ -185,6 +193,9 @@ class Channel(PrintError):
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])} self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
for sub in (LOCAL, REMOTE):
self.log[sub].locked_in.update(self.log[sub].adds.keys())
def set_state(self, state: str): def set_state(self, state: str):
self._state = state self._state = state
@ -232,7 +243,7 @@ class Channel(PrintError):
assert type(htlc) is dict assert type(htlc) is dict
self._check_can_pay(htlc['amount_msat']) self._check_can_pay(htlc['amount_msat'])
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id) htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
self.log[LOCAL]['adds'][htlc.htlc_id] = htlc self.log[LOCAL].adds[htlc.htlc_id] = htlc
self.print_error("add_htlc") self.print_error("add_htlc")
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1) self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id return htlc.htlc_id
@ -251,7 +262,7 @@ class Channel(PrintError):
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\ raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
f' Available at remote: {self.available_to_spend(REMOTE)},' +\ f' Available at remote: {self.available_to_spend(REMOTE)},' +\
f' HTLC amount: {htlc.amount_msat}') f' HTLC amount: {htlc.amount_msat}')
adds = self.log[REMOTE]['adds'] adds = self.log[REMOTE].adds
adds[htlc.htlc_id] = htlc adds[htlc.htlc_id] = htlc
self.print_error("receive_htlc") self.print_error("receive_htlc")
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1) self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
@ -309,11 +320,11 @@ class Channel(PrintError):
for sub in (LOCAL, REMOTE): for sub in (LOCAL, REMOTE):
log = self.log[sub] log = self.log[sub]
yield (sub, deepcopy(log)) yield (sub, deepcopy(log))
for htlc_id in log['fails']: for htlc_id in log.fails:
log['adds'].pop(htlc_id) log.adds.pop(htlc_id)
log['fails'].clear() log.fails.clear()
self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys() self.log[subject].locked_in.update(self.log[subject].adds.keys())
def receive_new_commitment(self, sig, htlc_sigs): def receive_new_commitment(self, sig, htlc_sigs):
""" """
@ -474,11 +485,11 @@ class Channel(PrintError):
""" """
old_amount = htlcsum(self.htlcs(subject, False)) old_amount = htlcsum(self.htlcs(subject, False))
for htlc_id in self.log[subject]['settles']: for htlc_id in self.log[subject].settles:
adds = self.log[subject]['adds'] adds = self.log[subject].adds
htlc = adds.pop(htlc_id) htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat) self.settled[subject].append(htlc.amount_msat)
self.log[subject]['settles'].clear() self.log[subject].settles.clear()
return old_amount - htlcsum(self.htlcs(subject, False)) return old_amount - htlcsum(self.htlcs(subject, False))
@ -533,7 +544,7 @@ class Channel(PrintError):
pending outgoing HTLCs, is used in the UI. pending outgoing HTLCs, is used in the UI.
""" """
return self.balance(subject)\ return self.balance(subject)\
- htlcsum(self.log[subject]['adds'].values()) - htlcsum(self.log[subject].adds.values())
def available_to_spend(self, subject): def available_to_spend(self, subject):
""" """
@ -541,7 +552,7 @@ class Channel(PrintError):
not be used in the UI cause it fluctuates (commit fee) not be used in the UI cause it fluctuates (commit fee)
""" """
return self.balance_minus_outgoing_htlcs(subject)\ return self.balance_minus_outgoing_htlcs(subject)\
- htlcsum(self.log[subject]['adds'].values())\ - htlcsum(self.log[subject].adds.values())\
- self.config[-subject].reserve_sat * 1000\ - self.config[-subject].reserve_sat * 1000\
- calc_onchain_fees( - calc_onchain_fees(
# TODO should we include a potential new htlc, when we are called from receive_htlc? # TODO should we include a potential new htlc, when we are called from receive_htlc?
@ -601,10 +612,10 @@ class Channel(PrintError):
""" """
update_log = self.log[subject] update_log = self.log[subject]
res = [] res = []
for htlc in update_log['adds'].values(): for htlc in update_log.adds.values():
locked_in = htlc.htlc_id in update_log['locked_in'] locked_in = htlc.htlc_id in update_log.locked_in
settled = htlc.htlc_id in update_log['settles'] settled = htlc.htlc_id in update_log.settles
failed = htlc.htlc_id in update_log['fails'] failed = htlc.htlc_id in update_log.fails
if not locked_in: if not locked_in:
continue continue
if only_pending == (settled or failed): if only_pending == (settled or failed):
@ -617,25 +628,33 @@ class Channel(PrintError):
SettleHTLC attempts to settle an existing outstanding received HTLC. SettleHTLC attempts to settle an existing outstanding received HTLC.
""" """
self.print_error("settle_htlc") self.print_error("settle_htlc")
htlc = self.log[REMOTE]['adds'][htlc_id] log = self.log[REMOTE]
htlc = log.adds[htlc_id]
assert htlc.payment_hash == sha256(preimage) assert htlc.payment_hash == sha256(preimage)
self.log[REMOTE]['settles'].append(htlc_id) assert htlc_id not in log.settles
log.settles.add(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices # not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id): def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle") self.print_error("receive_htlc_settle")
htlc = self.log[LOCAL]['adds'][htlc_id] log = self.log[LOCAL]
htlc = log.adds[htlc_id]
assert htlc.payment_hash == sha256(preimage) assert htlc.payment_hash == sha256(preimage)
self.log[LOCAL]['settles'].append(htlc_id) assert htlc_id not in log.settles
log.settles.add(htlc_id)
# we don't save the preimage because we don't need to forward it anyway # we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id): def fail_htlc(self, htlc_id):
self.print_error("fail_htlc") self.print_error("fail_htlc")
self.log[REMOTE]['fails'].append(htlc_id) log = self.log[REMOTE]
assert htlc_id not in log.fails
log.fails.add(htlc_id)
def receive_fail_htlc(self, htlc_id): def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc") self.print_error("receive_fail_htlc")
self.log[LOCAL]['fails'].append(htlc_id) log = self.log[LOCAL]
assert htlc_id not in log.fails
log.fails.add(htlc_id)
@property @property
def current_height(self): def current_height(self):
@ -666,8 +685,8 @@ class Channel(PrintError):
removed = [] removed = []
htlcs = [] htlcs = []
log = self.log[subject] log = self.log[subject]
for htlc_id, i in log['adds'].items(): for i in log.adds.values():
locked_in = htlc_id in log['locked_in'] locked_in = i.htlc_id in log.locked_in
if locked_in: if locked_in:
htlcs.append(i._asdict()) htlcs.append(i._asdict())
else: else:
@ -710,18 +729,26 @@ class Channel(PrintError):
def serialize(self): def serialize(self):
namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()} namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
serialized_channel = {k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in self.to_save().items()} serialized_channel = {}
to_save_ref = self.to_save()
for k, v in to_save_ref.items():
if isinstance(v, tuple):
serialized_channel[k] = namedtuples_to_dict(v)
else:
serialized_channel[k] = v
dumped = ChannelJsonEncoder().encode(serialized_channel) dumped = ChannelJsonEncoder().encode(serialized_channel)
roundtripped = json.loads(dumped) roundtripped = json.loads(dumped)
reconstructed = Channel(roundtripped) reconstructed = Channel(roundtripped)
if reconstructed.to_save() != self.to_save(): to_save_new = reconstructed.to_save()
from pprint import pformat if to_save_new != to_save_ref:
from pprint import PrettyPrinter
pp = PrettyPrinter(indent=168)
try: try:
from deepdiff import DeepDiff from deepdiff import DeepDiff
except ImportError: except ImportError:
raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(reconstructed.to_save()) + "\n" + pformat(self.to_save())) raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new))
else: else:
raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(DeepDiff(reconstructed.to_save(), self.to_save()))) raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new)))
return roundtripped return roundtripped
def __str__(self): def __str__(self):

6
electrum/tests/test_lnchan.py

@ -183,7 +183,7 @@ class TestChannel(unittest.TestCase):
self.bob_pending_remote_balance = after self.bob_pending_remote_balance = after
self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0] self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
def test_SimpleAddSettleWorkflow(self): def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel alice_channel, bob_channel = self.alice_channel, self.bob_channel
@ -217,6 +217,10 @@ class TestChannel(unittest.TestCase):
# forward since she's sending an outgoing HTLC. # forward since she's sending an outgoing HTLC.
alice_channel.receive_revocation(bobRevocation) alice_channel.receive_revocation(bobRevocation)
# test serializing with locked_in htlc
self.assertEqual(len(alice_channel.to_save()['local_log']), 1)
alice_channel.serialize()
# Alice then processes bob's signature, and since she just received # Alice then processes bob's signature, and since she just received
# the revocation, she expect this signature to cover everything up to # the revocation, she expect this signature to cover everything up to
# the point where she sent her signature, including the HTLC. # the point where she sent her signature, including the HTLC.

Loading…
Cancel
Save