Browse Source

lnhtlc: (fix) was locking in too many updates during commit/revoke

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
SomberNight 6 years ago
committed by ThomasV
parent
commit
7431aac5cd
  1. 10
      electrum/lnchannel.py
  2. 106
      electrum/lnhtlc.py
  3. 10
      electrum/lnpeer.py
  4. 17
      electrum/tests/test_lnchannel.py
  5. 153
      electrum/tests/test_lnhtlc.py

10
electrum/lnchannel.py

@ -113,6 +113,10 @@ def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()}
class Channel(Logger):
# note: try to avoid naming ctns/ctxs/etc as "current" and "pending".
# they are ambiguous. Use "oldest_unrevoked" or "latest" or "next".
# TODO enforce this ^
def diagnostic_name(self):
if self.name:
return str(self.name)
@ -154,7 +158,9 @@ class Channel(Logger):
self.remote_commitment_to_be_revoked.deserialize(True)
log = state.get('log')
self.hm = HTLCManager(self.config[LOCAL].ctn, self.config[REMOTE].ctn, log)
self.hm = HTLCManager(local_ctn=self.config[LOCAL].ctn,
remote_ctn=self.config[REMOTE].ctn,
log=log)
self.name = name
Logger.__init__(self)
@ -209,6 +215,7 @@ class Channel(Logger):
return self.force_closed or self.get_state() in ['CLOSED', 'CLOSING']
def _check_can_pay(self, amount_msat: int) -> None:
# TODO check if this method uses correct ctns (should use "latest" + 1)
if self.is_closed():
raise PaymentFailure('Channel closed')
if self.get_state() != 'OPEN':
@ -525,6 +532,7 @@ class Channel(Logger):
not be used in the UI cause it fluctuates (commit fee)
"""
# FIXME whose balance? whose ctx?
# FIXME confusing/mixing ctns (should probably use latest_ctn + 1; not oldest_unrevoked + 1)
assert type(subject) is HTLCOwner
return self.balance_minus_outgoing_htlcs(subject, ctx_owner=subject)\
- self.config[-subject].reserve_sat * 1000\

106
electrum/lnhtlc.py

@ -1,46 +1,45 @@
from copy import deepcopy
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, List
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u
class HTLCManager:
def __init__(self, local_ctn=0, remote_ctn=0, log=None):
def __init__(self, *, local_ctn=0, remote_ctn=0, log=None):
# self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub
self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn}
# ctx_pending[sub] is True iff sub has sent commitment_signed but did not receive revoke_and_ack
# ctx_pending[sub] is True iff sub has received commitment_signed but did not send revoke_and_ack (sub has multiple unrevoked ctxs)
self.ctx_pending = {LOCAL:False, REMOTE: False} # FIXME does this need to be persisted?
# expect_sig[SENT/RECEIVED] is True iff HTLCs have been sent/received but the corresponding commitment_signed has not been received/sent
self.expect_sig = {SENT: False, RECEIVED: False}
if log is None:
initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}}
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
else:
assert type(log) is dict
log = {HTLCOwner(int(x)): y for x, y in deepcopy(log).items()}
log = {HTLCOwner(int(sub)): action for sub, action in deepcopy(log).items()}
for sub in (LOCAL, REMOTE):
log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()}
coerceHtlcOwner2IntMap = lambda x: {HTLCOwner(int(y)): z for y, z in x.items()}
coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()}
# "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
log[sub]['locked_in'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['locked_in'].items()}
log[sub]['settles'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['settles'].items()}
log[sub]['fails'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['fails'].items()}
log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
self.log = log
def ctn_latest(self, sub):
def ctn_latest(self, sub: HTLCOwner) -> int:
"""Return the ctn for the latest (newest that has a valid sig) ctx of sub"""
return self.ctn[sub] + int(self.ctx_pending[sub])
def to_save(self):
x = deepcopy(self.log)
log = deepcopy(self.log)
for sub in (LOCAL, REMOTE):
# adds
d = {}
for htlc_id, htlc in x[sub]['adds'].items():
for htlc_id, htlc in log[sub]['adds'].items():
d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
x[sub]['adds'] = d
return x
log[sub]['adds'] = d
return log
def channel_open_finished(self):
self.ctn = {LOCAL: 0, REMOTE: 0}
@ -48,53 +47,55 @@ class HTLCManager:
def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
htlc_id = htlc.htlc_id
adds = self.log[LOCAL]['adds']
assert type(adds) is not str
adds[htlc_id] = htlc
self.log[LOCAL]['adds'][htlc_id] = htlc
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1}
self.expect_sig[SENT] = True
return htlc
def recv_htlc(self, htlc: UpdateAddHtlc) -> None:
htlc_id = htlc.htlc_id
self.log[REMOTE]['adds'][htlc_id] = htlc
l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None}
self.expect_sig[RECEIVED] = True
self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None}
def send_settle(self, htlc_id: int) -> None:
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
def recv_settle(self, htlc_id: int) -> None:
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def send_fail(self, htlc_id: int) -> None:
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
def recv_fail(self, htlc_id: int) -> None:
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def send_ctx(self) -> None:
assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE])
self.ctx_pending[REMOTE] = True
for locked_in in self.log[REMOTE]['locked_in'].values():
if locked_in[REMOTE] is None:
locked_in[REMOTE] = self.ctn_latest(REMOTE)
self.expect_sig[SENT] = False
def recv_ctx(self) -> None:
assert self.ctn_latest(LOCAL) == self.ctn[LOCAL], (self.ctn_latest(LOCAL), self.ctn[LOCAL])
self.ctx_pending[LOCAL] = True
for locked_in in self.log[LOCAL]['locked_in'].values():
if locked_in[LOCAL] is None:
locked_in[LOCAL] = self.ctn_latest(LOCAL)
self.expect_sig[RECEIVED] = False
def send_rev(self) -> None:
self.ctn[LOCAL] += 1
self.ctx_pending[LOCAL] = False
for ctns in self.log[REMOTE]['locked_in'].values():
if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[LOCAL][log_action].items():
if ctns[REMOTE] is None:
for ctns in self.log[LOCAL][log_action].values():
if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
def recv_rev(self) -> None:
self.ctn[REMOTE] += 1
self.ctx_pending[REMOTE] = False
for htlc_id, ctns in self.log[LOCAL]['locked_in'].items():
if ctns[LOCAL] is None:
#assert ctns[REMOTE] == self.ctn[REMOTE] # FIXME I don't think this assert is correct
for ctns in self.log[LOCAL]['locked_in'].values():
if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[REMOTE][log_action].items():
if ctns[LOCAL] is None:
for ctns in self.log[REMOTE][log_action].values():
if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
@ -113,13 +114,7 @@ class HTLCManager:
# party is the proposer of the HTLCs
party = subject if direction == SENT else subject.inverted()
for htlc_id, ctns in self.log[party]['locked_in'].items():
htlc_height = ctns[subject]
if htlc_height is None:
expect_sig = self.expect_sig[RECEIVED if party != LOCAL else SENT]
include = not expect_sig and ctns[-subject] <= ctn
else:
include = htlc_height <= ctn
if include:
if ctns[subject] is not None and ctns[subject] <= ctn:
settles = self.log[party]['settles']
fails = self.log[party]['fails']
not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn
@ -138,23 +133,20 @@ class HTLCManager:
l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)]
return l
def current_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
"""Return the list of HTLCs in subject's oldest unrevoked ctx."""
def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner
ctn = self.ctn[subject]
return self.htlcs(subject, ctn)
def pending_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
"""Return the list of HTLCs in subject's next ctx (one after oldest unrevoked)."""
def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner
ctn = self.ctn[subject] + 1
ctn = self.ctn_latest(subject)
return self.htlcs(subject, ctn)
def send_settle(self, htlc_id: int) -> None:
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
def recv_settle(self, htlc_id: int) -> None:
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def get_htlcs_in_next_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner
ctn = self.ctn_latest(subject) + 1
return self.htlcs(subject, ctn)
def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]:
@ -194,9 +186,3 @@ class HTLCManager:
return [self.log[LOCAL]['adds'][htlc_id]
for htlc_id, ctns in self.log[LOCAL]['settles'].items()
if ctns[LOCAL] == ctn]
def send_fail(self, htlc_id: int) -> None:
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
def recv_fail(self, htlc_id: int) -> None:
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}

10
electrum/lnpeer.py

@ -1024,12 +1024,12 @@ class Peer(Logger):
def maybe_send_commitment(self, chan: Channel):
ctn_to_sign = chan.get_current_ctn(REMOTE) + 1
# if there are no changes, we will not (and must not) send a new commitment
pending, current = chan.hm.pending_htlcs(REMOTE), chan.hm.current_htlcs(REMOTE)
if (pending == current
next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE)
if (next_htlcs == latest_htlcs
and chan.pending_feerate(REMOTE) == chan.constraints.feerate) \
or ctn_to_sign == self.sent_commitment_for_ctn_last[chan]:
return
self.logger.info(f'send_commitment. old number htlcs: {len(current)}, new number htlcs: {len(pending)}')
self.logger.info(f'send_commitment. old number htlcs: {len(latest_htlcs)}, new number htlcs: {len(next_htlcs)}')
sig_64, htlc_sigs = chan.sign_next_commitment()
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
self.sent_commitment_for_ctn_last[chan] = ctn_to_sign
@ -1087,8 +1087,8 @@ class Peer(Logger):
channel_id = payload['channel_id']
chan = self.channels[channel_id]
# make sure there were changes to the ctx, otherwise the remote peer is misbehaving
if (chan.hm.pending_htlcs(LOCAL) == chan.hm.current_htlcs(LOCAL)
and chan.pending_feerate(LOCAL) == chan.constraints.feerate):
if (chan.hm.get_htlcs_in_next_ctx(LOCAL) == chan.hm.get_htlcs_in_latest_ctx(LOCAL)
and chan.pending_feerate(LOCAL) == chan.constraints.feerate):
raise RemoteMisbehaving('received commitment_signed without pending changes')
# make sure ctn is new
ctn_to_recv = chan.get_current_ctn(LOCAL) + 1

17
electrum/tests/test_lnchannel.py

@ -226,6 +226,8 @@ class TestChannel(unittest.TestCase):
self.bob_channel.add_htlc(self.htlc_dict)
self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3)
self.alice_channel.revoke_current_commitment()
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4)
def test_SimpleAddSettleWorkflow(self):
@ -279,8 +281,8 @@ class TestChannel(unittest.TestCase):
self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(next(iter(alice_channel.hm.pending_htlcs(REMOTE)))[0], RECEIVED)
self.assertEqual(alice_channel.hm.pending_htlcs(REMOTE), bob_channel.hm.pending_htlcs(LOCAL))
self.assertEqual(next(iter(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE)))[0], RECEIVED)
self.assertEqual(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE), bob_channel.hm.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs())
# Bob receives this signature message, and checks that this covers the
@ -291,14 +293,11 @@ class TestChannel(unittest.TestCase):
self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL)))
self.assertEqual(bob_channel.config[REMOTE].ctn, 0)
self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc])
self.assertEqual(bob_channel.included_htlcs(LOCAL, RECEIVED, 1), [htlc])#
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 0), [])
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc])
self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 0), [])
self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc])
self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 0), [])
self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 1), [])
@ -323,7 +322,11 @@ class TestChannel(unittest.TestCase):
self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 3)
# so far: Alice added htlc, Alice signed.
self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 2) # oldest unrevoked
self.assertEqual(len(alice_channel.pending_commitment(REMOTE).outputs()), 3) # latest
# Alice then processes this revocation, sending her own revocation for
# her prior commitment transaction. Alice shouldn't have any HTLCs to

153
electrum/tests/test_lnhtlc.py

@ -14,42 +14,54 @@ class TestHTLCManager(unittest.TestCase):
B = HTLCManager()
ah0, bh0 = H('A', 0), H('B', 0)
B.recv_htlc(A.send_htlc(ah0))
self.assertTrue(B.expect_sig[RECEIVED])
self.assertTrue(A.expect_sig[SENT])
self.assertFalse(B.expect_sig[SENT])
self.assertFalse(A.expect_sig[RECEIVED])
self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1)
A.recv_htlc(B.send_htlc(bh0))
self.assertTrue(B.expect_sig[RECEIVED])
self.assertTrue(A.expect_sig[SENT])
self.assertTrue(A.expect_sig[SENT])
self.assertTrue(B.expect_sig[RECEIVED])
self.assertEqual(B.current_htlcs(LOCAL), [])
self.assertEqual(A.current_htlcs(LOCAL), [])
self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0)])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, bh0)])
A.send_ctx()
B.recv_ctx()
B.send_ctx()
A.recv_ctx()
self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [])
self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)])
B.send_rev()
A.recv_rev()
A.send_rev()
B.recv_rev()
self.assertEqual(B.current_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.current_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)])
A.send_ctx()
B.recv_ctx()
B.send_ctx()
A.recv_ctx()
self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
B.send_rev()
A.recv_rev()
A.send_rev()
B.recv_rev()
self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
def test_single_htlc_full_lifecycle(self):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager()
B = HTLCManager()
B.recv_htlc(A.send_htlc(H('A', 0)))
self.assertEqual(len(B.pending_htlcs(REMOTE)), 0)
self.assertEqual(len(A.pending_htlcs(REMOTE)), 1)
self.assertEqual(len(B.pending_htlcs(LOCAL)), 1)
self.assertEqual(len(A.pending_htlcs(LOCAL)), 0)
self.assertEqual(len(B.get_htlcs_in_next_ctx(REMOTE)), 0)
self.assertEqual(len(A.get_htlcs_in_next_ctx(REMOTE)), 1)
self.assertEqual(len(B.get_htlcs_in_next_ctx(LOCAL)), 1)
self.assertEqual(len(A.get_htlcs_in_next_ctx(LOCAL)), 0)
A.send_ctx()
B.recv_ctx()
B.send_rev()
@ -58,8 +70,8 @@ class TestHTLCManager(unittest.TestCase):
A.recv_ctx()
A.send_rev()
B.recv_rev()
self.assertEqual(len(A.current_htlcs(LOCAL)), 1)
self.assertEqual(len(B.current_htlcs(LOCAL)), 1)
self.assertEqual(len(A.get_htlcs_in_latest_ctx(LOCAL)), 1)
self.assertEqual(len(B.get_htlcs_in_latest_ctx(LOCAL)), 1)
if htlc_success:
B.send_settle(0)
A.recv_settle(0)
@ -67,47 +79,47 @@ class TestHTLCManager(unittest.TestCase):
B.send_fail(0)
A.recv_fail(0)
self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)])
self.assertNotEqual(A.current_htlcs(LOCAL), [])
self.assertNotEqual(B.current_htlcs(REMOTE), [])
self.assertNotEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertNotEqual(B.get_htlcs_in_latest_ctx(REMOTE), [])
self.assertEqual(A.pending_htlcs(LOCAL), [])
self.assertNotEqual(A.pending_htlcs(REMOTE), [])
self.assertEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE))
self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), [])
self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual(B.pending_htlcs(REMOTE), [])
self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), [])
B.send_ctx()
A.recv_ctx()
A.send_rev() # here pending_htlcs(REMOTE) should become empty
self.assertEqual(A.pending_htlcs(REMOTE), [])
self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
B.recv_rev()
A.send_ctx()
B.recv_ctx()
B.send_rev()
A.recv_rev()
self.assertEqual(B.current_htlcs(LOCAL), [])
self.assertEqual(A.current_htlcs(LOCAL), [])
self.assertEqual(A.current_htlcs(REMOTE), [])
self.assertEqual(B.current_htlcs(REMOTE), [])
self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
self.assertEqual(A.get_htlcs_in_latest_ctx(REMOTE), [])
self.assertEqual(B.get_htlcs_in_latest_ctx(REMOTE), [])
self.assertEqual(len(A.all_settled_htlcs_ever(LOCAL)), int(htlc_success))
self.assertEqual(len(A.sent_in_ctn(2)), int(htlc_success))
self.assertEqual(len(B.received_in_ctn(2)), int(htlc_success))
A.recv_htlc(B.send_htlc(H('B', 0)))
self.assertEqual(A.pending_htlcs(REMOTE), [])
self.assertNotEqual(A.pending_htlcs(LOCAL), [])
self.assertNotEqual(B.pending_htlcs(REMOTE), [])
self.assertEqual(B.pending_htlcs(LOCAL), [])
self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
self.assertNotEqual(A.get_htlcs_in_next_ctx(LOCAL), [])
self.assertNotEqual(B.get_htlcs_in_next_ctx(REMOTE), [])
self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), [])
B.send_ctx()
A.recv_ctx()
A.send_rev()
B.recv_rev()
self.assertNotEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE))
self.assertEqual(A.pending_htlcs(LOCAL), A.current_htlcs(LOCAL))
self.assertEqual(B.pending_htlcs(REMOTE), B.current_htlcs(REMOTE))
self.assertNotEqual(B.pending_htlcs(LOCAL), B.pending_htlcs(REMOTE))
self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), B.get_htlcs_in_latest_ctx(REMOTE))
self.assertNotEqual(B.get_htlcs_in_next_ctx(LOCAL), B.get_htlcs_in_next_ctx(REMOTE))
htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False)
@ -116,7 +128,8 @@ class TestHTLCManager(unittest.TestCase):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager()
B = HTLCManager()
B.recv_htlc(A.send_htlc(H('A', 0)))
ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0))
A.send_ctx()
B.recv_ctx()
B.send_rev()
@ -127,11 +140,22 @@ class TestHTLCManager(unittest.TestCase):
else:
B.send_fail(0)
A.recv_fail(0)
self.assertEqual(B.pending_htlcs(REMOTE), [])
self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
B.send_ctx()
A.recv_ctx()
A.send_rev()
B.recv_rev()
self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([], A.get_htlcs_in_next_ctx(REMOTE))
htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False)
@ -144,13 +168,38 @@ class TestHTLCManager(unittest.TestCase):
B.send_rev()
ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0))
self.assertEqual([], A.current_htlcs(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE))
self.assertEqual([], A.pending_htlcs(LOCAL))
self.assertEqual([], A.pending_htlcs(REMOTE))
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
A.recv_rev()
self.assertEqual([], A.current_htlcs(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.pending_htlcs(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.pending_htlcs(REMOTE))
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
A.send_ctx()
B.recv_ctx()
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
B.send_rev()
A.recv_rev()
self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
B.send_ctx()
A.recv_ctx()
self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
A.send_rev()
B.recv_rev()
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))

Loading…
Cancel
Save