Browse Source

lnhtlc: (fix) was locking in too many updates during commit/revoke

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
SomberNight 5 years ago
committed by ThomasV
parent
commit
7431aac5cd
  1. 10
      electrum/lnchannel.py
  2. 106
      electrum/lnhtlc.py
  3. 10
      electrum/lnpeer.py
  4. 17
      electrum/tests/test_lnchannel.py
  5. 153
      electrum/tests/test_lnhtlc.py

10
electrum/lnchannel.py

@ -113,6 +113,10 @@ def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()} return {str(k): bh2u(v) for k, v in x.items()}
class Channel(Logger): class Channel(Logger):
# note: try to avoid naming ctns/ctxs/etc as "current" and "pending".
# they are ambiguous. Use "oldest_unrevoked" or "latest" or "next".
# TODO enforce this ^
def diagnostic_name(self): def diagnostic_name(self):
if self.name: if self.name:
return str(self.name) return str(self.name)
@ -154,7 +158,9 @@ class Channel(Logger):
self.remote_commitment_to_be_revoked.deserialize(True) self.remote_commitment_to_be_revoked.deserialize(True)
log = state.get('log') log = state.get('log')
self.hm = HTLCManager(self.config[LOCAL].ctn, self.config[REMOTE].ctn, log) self.hm = HTLCManager(local_ctn=self.config[LOCAL].ctn,
remote_ctn=self.config[REMOTE].ctn,
log=log)
self.name = name self.name = name
Logger.__init__(self) Logger.__init__(self)
@ -209,6 +215,7 @@ class Channel(Logger):
return self.force_closed or self.get_state() in ['CLOSED', 'CLOSING'] return self.force_closed or self.get_state() in ['CLOSED', 'CLOSING']
def _check_can_pay(self, amount_msat: int) -> None: def _check_can_pay(self, amount_msat: int) -> None:
# TODO check if this method uses correct ctns (should use "latest" + 1)
if self.is_closed(): if self.is_closed():
raise PaymentFailure('Channel closed') raise PaymentFailure('Channel closed')
if self.get_state() != 'OPEN': if self.get_state() != 'OPEN':
@ -525,6 +532,7 @@ class Channel(Logger):
not be used in the UI cause it fluctuates (commit fee) not be used in the UI cause it fluctuates (commit fee)
""" """
# FIXME whose balance? whose ctx? # FIXME whose balance? whose ctx?
# FIXME confusing/mixing ctns (should probably use latest_ctn + 1; not oldest_unrevoked + 1)
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
return self.balance_minus_outgoing_htlcs(subject, ctx_owner=subject)\ return self.balance_minus_outgoing_htlcs(subject, ctx_owner=subject)\
- self.config[-subject].reserve_sat * 1000\ - self.config[-subject].reserve_sat * 1000\

106
electrum/lnhtlc.py

@ -1,46 +1,45 @@
from copy import deepcopy from copy import deepcopy
from typing import Optional, Sequence, Tuple from typing import Optional, Sequence, Tuple, List
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u from .util import bh2u
class HTLCManager: class HTLCManager:
def __init__(self, local_ctn=0, remote_ctn=0, log=None): def __init__(self, *, local_ctn=0, remote_ctn=0, log=None):
# self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub # self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub
self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn} self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn}
# ctx_pending[sub] is True iff sub has sent commitment_signed but did not receive revoke_and_ack # ctx_pending[sub] is True iff sub has received commitment_signed but did not send revoke_and_ack (sub has multiple unrevoked ctxs)
self.ctx_pending = {LOCAL:False, REMOTE: False} # FIXME does this need to be persisted? self.ctx_pending = {LOCAL:False, REMOTE: False} # FIXME does this need to be persisted?
# expect_sig[SENT/RECEIVED] is True iff HTLCs have been sent/received but the corresponding commitment_signed has not been received/sent
self.expect_sig = {SENT: False, RECEIVED: False}
if log is None: if log is None:
initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}} initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}}
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)} log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
else: else:
assert type(log) is dict assert type(log) is dict
log = {HTLCOwner(int(x)): y for x, y in deepcopy(log).items()} log = {HTLCOwner(int(sub)): action for sub, action in deepcopy(log).items()}
for sub in (LOCAL, REMOTE): for sub in (LOCAL, REMOTE):
log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()} log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()}
coerceHtlcOwner2IntMap = lambda x: {HTLCOwner(int(y)): z for y, z in x.items()} coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()}
# "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
log[sub]['locked_in'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['locked_in'].items()} log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
log[sub]['settles'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['settles'].items()} log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
log[sub]['fails'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['fails'].items()} log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
self.log = log self.log = log
def ctn_latest(self, sub): def ctn_latest(self, sub: HTLCOwner) -> int:
"""Return the ctn for the latest (newest that has a valid sig) ctx of sub""" """Return the ctn for the latest (newest that has a valid sig) ctx of sub"""
return self.ctn[sub] + int(self.ctx_pending[sub]) return self.ctn[sub] + int(self.ctx_pending[sub])
def to_save(self): def to_save(self):
x = deepcopy(self.log) log = deepcopy(self.log)
for sub in (LOCAL, REMOTE): for sub in (LOCAL, REMOTE):
# adds
d = {} d = {}
for htlc_id, htlc in x[sub]['adds'].items(): for htlc_id, htlc in log[sub]['adds'].items():
d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:] d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
x[sub]['adds'] = d log[sub]['adds'] = d
return x return log
def channel_open_finished(self): def channel_open_finished(self):
self.ctn = {LOCAL: 0, REMOTE: 0} self.ctn = {LOCAL: 0, REMOTE: 0}
@ -48,53 +47,55 @@ class HTLCManager:
def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
htlc_id = htlc.htlc_id htlc_id = htlc.htlc_id
adds = self.log[LOCAL]['adds'] self.log[LOCAL]['adds'][htlc_id] = htlc
assert type(adds) is not str
adds[htlc_id] = htlc
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1} self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1}
self.expect_sig[SENT] = True
return htlc return htlc
def recv_htlc(self, htlc: UpdateAddHtlc) -> None: def recv_htlc(self, htlc: UpdateAddHtlc) -> None:
htlc_id = htlc.htlc_id htlc_id = htlc.htlc_id
self.log[REMOTE]['adds'][htlc_id] = htlc self.log[REMOTE]['adds'][htlc_id] = htlc
l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None} self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None}
self.expect_sig[RECEIVED] = True
def send_settle(self, htlc_id: int) -> None:
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
def recv_settle(self, htlc_id: int) -> None:
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def send_fail(self, htlc_id: int) -> None:
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
def recv_fail(self, htlc_id: int) -> None:
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def send_ctx(self) -> None: def send_ctx(self) -> None:
assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE]) assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE])
self.ctx_pending[REMOTE] = True self.ctx_pending[REMOTE] = True
for locked_in in self.log[REMOTE]['locked_in'].values():
if locked_in[REMOTE] is None:
locked_in[REMOTE] = self.ctn_latest(REMOTE)
self.expect_sig[SENT] = False
def recv_ctx(self) -> None: def recv_ctx(self) -> None:
assert self.ctn_latest(LOCAL) == self.ctn[LOCAL], (self.ctn_latest(LOCAL), self.ctn[LOCAL]) assert self.ctn_latest(LOCAL) == self.ctn[LOCAL], (self.ctn_latest(LOCAL), self.ctn[LOCAL])
self.ctx_pending[LOCAL] = True self.ctx_pending[LOCAL] = True
for locked_in in self.log[LOCAL]['locked_in'].values():
if locked_in[LOCAL] is None:
locked_in[LOCAL] = self.ctn_latest(LOCAL)
self.expect_sig[RECEIVED] = False
def send_rev(self) -> None: def send_rev(self) -> None:
self.ctn[LOCAL] += 1 self.ctn[LOCAL] += 1
self.ctx_pending[LOCAL] = False self.ctx_pending[LOCAL] = False
for ctns in self.log[REMOTE]['locked_in'].values():
if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
for log_action in ('settles', 'fails'): for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[LOCAL][log_action].items(): for ctns in self.log[LOCAL][log_action].values():
if ctns[REMOTE] is None: if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
def recv_rev(self) -> None: def recv_rev(self) -> None:
self.ctn[REMOTE] += 1 self.ctn[REMOTE] += 1
self.ctx_pending[REMOTE] = False self.ctx_pending[REMOTE] = False
for htlc_id, ctns in self.log[LOCAL]['locked_in'].items(): for ctns in self.log[LOCAL]['locked_in'].values():
if ctns[LOCAL] is None: if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
#assert ctns[REMOTE] == self.ctn[REMOTE] # FIXME I don't think this assert is correct
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
for log_action in ('settles', 'fails'): for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[REMOTE][log_action].items(): for ctns in self.log[REMOTE][log_action].values():
if ctns[LOCAL] is None: if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
@ -113,13 +114,7 @@ class HTLCManager:
# party is the proposer of the HTLCs # party is the proposer of the HTLCs
party = subject if direction == SENT else subject.inverted() party = subject if direction == SENT else subject.inverted()
for htlc_id, ctns in self.log[party]['locked_in'].items(): for htlc_id, ctns in self.log[party]['locked_in'].items():
htlc_height = ctns[subject] if ctns[subject] is not None and ctns[subject] <= ctn:
if htlc_height is None:
expect_sig = self.expect_sig[RECEIVED if party != LOCAL else SENT]
include = not expect_sig and ctns[-subject] <= ctn
else:
include = htlc_height <= ctn
if include:
settles = self.log[party]['settles'] settles = self.log[party]['settles']
fails = self.log[party]['fails'] fails = self.log[party]['fails']
not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn
@ -138,23 +133,20 @@ class HTLCManager:
l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)] l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)]
return l return l
def current_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
"""Return the list of HTLCs in subject's oldest unrevoked ctx."""
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
ctn = self.ctn[subject] ctn = self.ctn[subject]
return self.htlcs(subject, ctn) return self.htlcs(subject, ctn)
def pending_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
"""Return the list of HTLCs in subject's next ctx (one after oldest unrevoked)."""
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
ctn = self.ctn[subject] + 1 ctn = self.ctn_latest(subject)
return self.htlcs(subject, ctn) return self.htlcs(subject, ctn)
def send_settle(self, htlc_id: int) -> None: def get_htlcs_in_next_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} assert type(subject) is HTLCOwner
ctn = self.ctn_latest(subject) + 1
def recv_settle(self, htlc_id: int) -> None: return self.htlcs(subject, ctn)
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction, def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]: ctn: int = None) -> Sequence[UpdateAddHtlc]:
@ -194,9 +186,3 @@ class HTLCManager:
return [self.log[LOCAL]['adds'][htlc_id] return [self.log[LOCAL]['adds'][htlc_id]
for htlc_id, ctns in self.log[LOCAL]['settles'].items() for htlc_id, ctns in self.log[LOCAL]['settles'].items()
if ctns[LOCAL] == ctn] if ctns[LOCAL] == ctn]
def send_fail(self, htlc_id: int) -> None:
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
def recv_fail(self, htlc_id: int) -> None:
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}

10
electrum/lnpeer.py

@ -1024,12 +1024,12 @@ class Peer(Logger):
def maybe_send_commitment(self, chan: Channel): def maybe_send_commitment(self, chan: Channel):
ctn_to_sign = chan.get_current_ctn(REMOTE) + 1 ctn_to_sign = chan.get_current_ctn(REMOTE) + 1
# if there are no changes, we will not (and must not) send a new commitment # if there are no changes, we will not (and must not) send a new commitment
pending, current = chan.hm.pending_htlcs(REMOTE), chan.hm.current_htlcs(REMOTE) next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE)
if (pending == current if (next_htlcs == latest_htlcs
and chan.pending_feerate(REMOTE) == chan.constraints.feerate) \ and chan.pending_feerate(REMOTE) == chan.constraints.feerate) \
or ctn_to_sign == self.sent_commitment_for_ctn_last[chan]: or ctn_to_sign == self.sent_commitment_for_ctn_last[chan]:
return return
self.logger.info(f'send_commitment. old number htlcs: {len(current)}, new number htlcs: {len(pending)}') self.logger.info(f'send_commitment. old number htlcs: {len(latest_htlcs)}, new number htlcs: {len(next_htlcs)}')
sig_64, htlc_sigs = chan.sign_next_commitment() sig_64, htlc_sigs = chan.sign_next_commitment()
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
self.sent_commitment_for_ctn_last[chan] = ctn_to_sign self.sent_commitment_for_ctn_last[chan] = ctn_to_sign
@ -1087,8 +1087,8 @@ class Peer(Logger):
channel_id = payload['channel_id'] channel_id = payload['channel_id']
chan = self.channels[channel_id] chan = self.channels[channel_id]
# make sure there were changes to the ctx, otherwise the remote peer is misbehaving # make sure there were changes to the ctx, otherwise the remote peer is misbehaving
if (chan.hm.pending_htlcs(LOCAL) == chan.hm.current_htlcs(LOCAL) if (chan.hm.get_htlcs_in_next_ctx(LOCAL) == chan.hm.get_htlcs_in_latest_ctx(LOCAL)
and chan.pending_feerate(LOCAL) == chan.constraints.feerate): and chan.pending_feerate(LOCAL) == chan.constraints.feerate):
raise RemoteMisbehaving('received commitment_signed without pending changes') raise RemoteMisbehaving('received commitment_signed without pending changes')
# make sure ctn is new # make sure ctn is new
ctn_to_recv = chan.get_current_ctn(LOCAL) + 1 ctn_to_recv = chan.get_current_ctn(LOCAL) + 1

17
electrum/tests/test_lnchannel.py

@ -226,6 +226,8 @@ class TestChannel(unittest.TestCase):
self.bob_channel.add_htlc(self.htlc_dict) self.bob_channel.add_htlc(self.htlc_dict)
self.alice_channel.receive_htlc(self.htlc_dict) self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment()) self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3)
self.alice_channel.revoke_current_commitment()
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4) self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4)
def test_SimpleAddSettleWorkflow(self): def test_SimpleAddSettleWorkflow(self):
@ -279,8 +281,8 @@ class TestChannel(unittest.TestCase):
self.assertTrue(alice_channel.signature_fits(com())) self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com())) self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(next(iter(alice_channel.hm.pending_htlcs(REMOTE)))[0], RECEIVED) self.assertEqual(next(iter(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE)))[0], RECEIVED)
self.assertEqual(alice_channel.hm.pending_htlcs(REMOTE), bob_channel.hm.pending_htlcs(LOCAL)) self.assertEqual(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE), bob_channel.hm.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs()) self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs())
# Bob receives this signature message, and checks that this covers the # Bob receives this signature message, and checks that this covers the
@ -291,14 +293,11 @@ class TestChannel(unittest.TestCase):
self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL))) self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL)))
self.assertEqual(bob_channel.config[REMOTE].ctn, 0) self.assertEqual(bob_channel.config[REMOTE].ctn, 0)
self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc]) self.assertEqual(bob_channel.included_htlcs(LOCAL, RECEIVED, 1), [htlc])#
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 0), []) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 0), [])
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc]) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc])
self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 0), [])
self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc])
self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 0), []) self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 0), [])
self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 1), []) self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 1), [])
@ -323,7 +322,11 @@ class TestChannel(unittest.TestCase):
self.assertTrue(alice_channel.signature_fits(com())) self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com())) self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 3) # so far: Alice added htlc, Alice signed.
self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 2) # oldest unrevoked
self.assertEqual(len(alice_channel.pending_commitment(REMOTE).outputs()), 3) # latest
# Alice then processes this revocation, sending her own revocation for # Alice then processes this revocation, sending her own revocation for
# her prior commitment transaction. Alice shouldn't have any HTLCs to # her prior commitment transaction. Alice shouldn't have any HTLCs to

153
electrum/tests/test_lnhtlc.py

@ -14,42 +14,54 @@ class TestHTLCManager(unittest.TestCase):
B = HTLCManager() B = HTLCManager()
ah0, bh0 = H('A', 0), H('B', 0) ah0, bh0 = H('A', 0), H('B', 0)
B.recv_htlc(A.send_htlc(ah0)) B.recv_htlc(A.send_htlc(ah0))
self.assertTrue(B.expect_sig[RECEIVED])
self.assertTrue(A.expect_sig[SENT])
self.assertFalse(B.expect_sig[SENT])
self.assertFalse(A.expect_sig[RECEIVED])
self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1) self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1)
A.recv_htlc(B.send_htlc(bh0)) A.recv_htlc(B.send_htlc(bh0))
self.assertTrue(B.expect_sig[RECEIVED]) self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertTrue(A.expect_sig[SENT]) self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertTrue(A.expect_sig[SENT]) self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertTrue(B.expect_sig[RECEIVED]) self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, bh0)])
self.assertEqual(B.current_htlcs(LOCAL), [])
self.assertEqual(A.current_htlcs(LOCAL), [])
self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0)])
A.send_ctx() A.send_ctx()
B.recv_ctx() B.recv_ctx()
B.send_ctx() B.send_ctx()
A.recv_ctx() A.recv_ctx()
self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [])
self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)])
B.send_rev() B.send_rev()
A.recv_rev() A.recv_rev()
A.send_rev() A.send_rev()
B.recv_rev() B.recv_rev()
self.assertEqual(B.current_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.current_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)])
A.send_ctx()
B.recv_ctx()
B.send_ctx()
A.recv_ctx()
self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
B.send_rev()
A.recv_rev()
A.send_rev()
B.recv_rev()
self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
def test_single_htlc_full_lifecycle(self): def test_single_htlc_full_lifecycle(self):
def htlc_lifecycle(htlc_success: bool): def htlc_lifecycle(htlc_success: bool):
A = HTLCManager() A = HTLCManager()
B = HTLCManager() B = HTLCManager()
B.recv_htlc(A.send_htlc(H('A', 0))) B.recv_htlc(A.send_htlc(H('A', 0)))
self.assertEqual(len(B.pending_htlcs(REMOTE)), 0) self.assertEqual(len(B.get_htlcs_in_next_ctx(REMOTE)), 0)
self.assertEqual(len(A.pending_htlcs(REMOTE)), 1) self.assertEqual(len(A.get_htlcs_in_next_ctx(REMOTE)), 1)
self.assertEqual(len(B.pending_htlcs(LOCAL)), 1) self.assertEqual(len(B.get_htlcs_in_next_ctx(LOCAL)), 1)
self.assertEqual(len(A.pending_htlcs(LOCAL)), 0) self.assertEqual(len(A.get_htlcs_in_next_ctx(LOCAL)), 0)
A.send_ctx() A.send_ctx()
B.recv_ctx() B.recv_ctx()
B.send_rev() B.send_rev()
@ -58,8 +70,8 @@ class TestHTLCManager(unittest.TestCase):
A.recv_ctx() A.recv_ctx()
A.send_rev() A.send_rev()
B.recv_rev() B.recv_rev()
self.assertEqual(len(A.current_htlcs(LOCAL)), 1) self.assertEqual(len(A.get_htlcs_in_latest_ctx(LOCAL)), 1)
self.assertEqual(len(B.current_htlcs(LOCAL)), 1) self.assertEqual(len(B.get_htlcs_in_latest_ctx(LOCAL)), 1)
if htlc_success: if htlc_success:
B.send_settle(0) B.send_settle(0)
A.recv_settle(0) A.recv_settle(0)
@ -67,47 +79,47 @@ class TestHTLCManager(unittest.TestCase):
B.send_fail(0) B.send_fail(0)
A.recv_fail(0) A.recv_fail(0)
self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)]) self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)])
self.assertNotEqual(A.current_htlcs(LOCAL), []) self.assertNotEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertNotEqual(B.current_htlcs(REMOTE), []) self.assertNotEqual(B.get_htlcs_in_latest_ctx(REMOTE), [])
self.assertEqual(A.pending_htlcs(LOCAL), []) self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), [])
self.assertNotEqual(A.pending_htlcs(REMOTE), []) self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
self.assertEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE)) self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual(B.pending_htlcs(REMOTE), []) self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), [])
B.send_ctx() B.send_ctx()
A.recv_ctx() A.recv_ctx()
A.send_rev() # here pending_htlcs(REMOTE) should become empty A.send_rev() # here pending_htlcs(REMOTE) should become empty
self.assertEqual(A.pending_htlcs(REMOTE), []) self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
B.recv_rev() B.recv_rev()
A.send_ctx() A.send_ctx()
B.recv_ctx() B.recv_ctx()
B.send_rev() B.send_rev()
A.recv_rev() A.recv_rev()
self.assertEqual(B.current_htlcs(LOCAL), []) self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertEqual(A.current_htlcs(LOCAL), []) self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertEqual(A.current_htlcs(REMOTE), []) self.assertEqual(A.get_htlcs_in_latest_ctx(REMOTE), [])
self.assertEqual(B.current_htlcs(REMOTE), []) self.assertEqual(B.get_htlcs_in_latest_ctx(REMOTE), [])
self.assertEqual(len(A.all_settled_htlcs_ever(LOCAL)), int(htlc_success)) self.assertEqual(len(A.all_settled_htlcs_ever(LOCAL)), int(htlc_success))
self.assertEqual(len(A.sent_in_ctn(2)), int(htlc_success)) self.assertEqual(len(A.sent_in_ctn(2)), int(htlc_success))
self.assertEqual(len(B.received_in_ctn(2)), int(htlc_success)) self.assertEqual(len(B.received_in_ctn(2)), int(htlc_success))
A.recv_htlc(B.send_htlc(H('B', 0))) A.recv_htlc(B.send_htlc(H('B', 0)))
self.assertEqual(A.pending_htlcs(REMOTE), []) self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
self.assertNotEqual(A.pending_htlcs(LOCAL), []) self.assertNotEqual(A.get_htlcs_in_next_ctx(LOCAL), [])
self.assertNotEqual(B.pending_htlcs(REMOTE), []) self.assertNotEqual(B.get_htlcs_in_next_ctx(REMOTE), [])
self.assertEqual(B.pending_htlcs(LOCAL), []) self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), [])
B.send_ctx() B.send_ctx()
A.recv_ctx() A.recv_ctx()
A.send_rev() A.send_rev()
B.recv_rev() B.recv_rev()
self.assertNotEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE)) self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual(A.pending_htlcs(LOCAL), A.current_htlcs(LOCAL)) self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual(B.pending_htlcs(REMOTE), B.current_htlcs(REMOTE)) self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), B.get_htlcs_in_latest_ctx(REMOTE))
self.assertNotEqual(B.pending_htlcs(LOCAL), B.pending_htlcs(REMOTE)) self.assertNotEqual(B.get_htlcs_in_next_ctx(LOCAL), B.get_htlcs_in_next_ctx(REMOTE))
htlc_lifecycle(htlc_success=True) htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False) htlc_lifecycle(htlc_success=False)
@ -116,7 +128,8 @@ class TestHTLCManager(unittest.TestCase):
def htlc_lifecycle(htlc_success: bool): def htlc_lifecycle(htlc_success: bool):
A = HTLCManager() A = HTLCManager()
B = HTLCManager() B = HTLCManager()
B.recv_htlc(A.send_htlc(H('A', 0))) ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0))
A.send_ctx() A.send_ctx()
B.recv_ctx() B.recv_ctx()
B.send_rev() B.send_rev()
@ -127,11 +140,22 @@ class TestHTLCManager(unittest.TestCase):
else: else:
B.send_fail(0) B.send_fail(0)
A.recv_fail(0) A.recv_fail(0)
self.assertEqual(B.pending_htlcs(REMOTE), []) self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
B.send_ctx() B.send_ctx()
A.recv_ctx() A.recv_ctx()
A.send_rev() A.send_rev()
B.recv_rev() B.recv_rev()
self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([], A.get_htlcs_in_next_ctx(REMOTE))
htlc_lifecycle(htlc_success=True) htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False) htlc_lifecycle(htlc_success=False)
@ -144,13 +168,38 @@ class TestHTLCManager(unittest.TestCase):
B.send_rev() B.send_rev()
ah0 = H('A', 0) ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0)) B.recv_htlc(A.send_htlc(ah0))
self.assertEqual([], A.current_htlcs(LOCAL)) self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE)) self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.pending_htlcs(LOCAL)) self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([], A.pending_htlcs(REMOTE)) self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
A.recv_rev() A.recv_rev()
self.assertEqual([], A.current_htlcs(LOCAL)) self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE)) self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.pending_htlcs(LOCAL)) self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.pending_htlcs(REMOTE)) self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
A.send_ctx()
B.recv_ctx()
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
B.send_rev()
A.recv_rev()
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
B.send_ctx()
A.recv_ctx()
self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
A.send_rev()
B.recv_rev()
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))

Loading…
Cancel
Save