Browse Source

lnchan: make sign_next_commitment revert state

regtest_lnd
Janus 6 years ago
committed by SomberNight
parent
commit
bfa28e5562
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 80
      electrum/lnchan.py
  2. 17
      electrum/tests/test_lnchan.py

80
electrum/lnchan.py

@ -27,6 +27,7 @@ import binascii
import json import json
from enum import Enum, auto from enum import Enum, auto
from typing import Optional, Dict, List, Tuple from typing import Optional, Dict, List, Tuple
from copy import deepcopy
from .util import bfh, PrintError, bh2u from .util import bfh, PrintError, bh2u
from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS
@ -79,21 +80,20 @@ class FeeUpdate(defaultdict):
return self.rate return self.rate
# implicit return None # implicit return None
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])): class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
"""
This whole class body is so that if you pass a hex-string as payment_hash,
it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
"""
__slots__ = () __slots__ = ()
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if len(args) > 0: if len(args) > 0:
args = list(args) args = list(args)
if type(args[1]) is str: if type(args[1]) is str:
args[1] = bfh(args[1]) args[1] = bfh(args[1])
args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()}
return super().__new__(cls, *args) return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str: if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash']) kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
if 'locked_in' not in kwargs:
kwargs['locked_in'] = {LOCAL: None, REMOTE: None}
else:
kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in'].items()}
return super().__new__(cls, **kwargs) return super().__new__(cls, **kwargs)
def decodeAll(d, local): def decodeAll(d, local):
@ -162,6 +162,7 @@ class Channel(PrintError):
'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc] 'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
'settles': [], # List[HTLC_ID] 'settles': [], # List[HTLC_ID]
'fails': [], # List[HTLC_ID] 'fails': [], # List[HTLC_ID]
'locked_in': [], # List[HTLC_ID]
} }
self.log = {LOCAL: template(), REMOTE: template()} self.log = {LOCAL: template(), REMOTE: template()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]: for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
@ -269,7 +270,8 @@ class Channel(PrintError):
This docstring was adapted from LND. This docstring was adapted from LND.
""" """
self.print_error("sign_next_commitment") self.print_error("sign_next_commitment")
self.lock_in_htlc_changes(LOCAL)
old_logs = dict(self.lock_in_htlc_changes(LOCAL))
pending_remote_commitment = self.pending_remote_commitment pending_remote_commitment = self.pending_remote_commitment
sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE]) sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
@ -290,29 +292,28 @@ class Channel(PrintError):
htlc_sig = ecc.sig_string_from_der_sig(sig[:-1]) htlc_sig = ecc.sig_string_from_der_sig(sig[:-1])
htlcsigs.append((pending_remote_commitment.htlc_output_indices[htlc.payment_hash], htlc_sig)) htlcsigs.append((pending_remote_commitment.htlc_output_indices[htlc.payment_hash], htlc_sig))
for pending_fee in self.fee_mgr:
if not self.constraints.is_initiator:
pending_fee[FUNDEE_SIGNED] = True
if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
pending_fee[FUNDER_SIGNED] = True
self.process_new_offchain_ctx(pending_remote_commitment, ours=False) self.process_new_offchain_ctx(pending_remote_commitment, ours=False)
htlcsigs.sort() htlcsigs.sort()
htlcsigs = [x[1] for x in htlcsigs] htlcsigs = [x[1] for x in htlcsigs]
# we can't know if this message arrives.
# since we shouldn't actually throw away
# failed htlcs yet (or mark htlc locked in),
# roll back the changes that were made
self.log = old_logs
return sig_64, htlcsigs return sig_64, htlcsigs
def lock_in_htlc_changes(self, subject): def lock_in_htlc_changes(self, subject):
for sub in (LOCAL, REMOTE): for sub in (LOCAL, REMOTE):
for htlc_id in self.log[-sub]['fails']: log = self.log[sub]
adds = self.log[sub]['adds'] yield (sub, deepcopy(log))
htlc = adds.pop(htlc_id) for htlc_id in log['fails']:
self.log[-sub]['fails'].clear() log['adds'].pop(htlc_id)
log['fails'].clear()
for htlc in self.log[subject]['adds'].values(): self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
if htlc.locked_in[subject] is None:
htlc.locked_in[subject] = self.config[subject].ctn
def receive_new_commitment(self, sig, htlc_sigs): def receive_new_commitment(self, sig, htlc_sigs):
""" """
@ -328,7 +329,9 @@ class Channel(PrintError):
This docstring is from LND. This docstring is from LND.
""" """
self.print_error("receive_new_commitment") self.print_error("receive_new_commitment")
self.lock_in_htlc_changes(REMOTE)
for _ in self.lock_in_htlc_changes(REMOTE): pass
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
pending_local_commitment = self.pending_local_commitment pending_local_commitment = self.pending_local_commitment
@ -443,11 +446,20 @@ class Channel(PrintError):
def receive_revocation(self, revocation) -> Tuple[int, int]: def receive_revocation(self, revocation) -> Tuple[int, int]:
self.print_error("receive_revocation") self.print_error("receive_revocation")
old_logs = dict(self.lock_in_htlc_changes(LOCAL))
cur_point = self.config[REMOTE].current_per_commitment_point cur_point = self.config[REMOTE].current_per_commitment_point
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True) derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
if cur_point != derived_point: if cur_point != derived_point:
self.log = old_logs
raise Exception('revoked secret not for current point') raise Exception('revoked secret not for current point')
for pending_fee in self.fee_mgr:
if not self.constraints.is_initiator:
pending_fee[FUNDEE_SIGNED] = True
if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
pending_fee[FUNDER_SIGNED] = True
# FIXME not sure this is correct... but it seems to work # FIXME not sure this is correct... but it seems to work
# if there are update_add_htlc msgs between commitment_signed and rev_ack, # if there are update_add_htlc msgs between commitment_signed and rev_ack,
# this might break # this might break
@ -462,11 +474,11 @@ class Channel(PrintError):
""" """
old_amount = htlcsum(self.htlcs(subject, False)) old_amount = htlcsum(self.htlcs(subject, False))
for htlc_id in self.log[-subject]['settles']: for htlc_id in self.log[subject]['settles']:
adds = self.log[subject]['adds'] adds = self.log[subject]['adds']
htlc = adds.pop(htlc_id) htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat) self.settled[subject].append(htlc.amount_msat)
self.log[-subject]['settles'].clear() self.log[subject]['settles'].clear()
return old_amount - htlcsum(self.htlcs(subject, False)) return old_amount - htlcsum(self.htlcs(subject, False))
@ -588,13 +600,12 @@ class Channel(PrintError):
only_pending: require the htlc's settlement to be pending (needs additional signatures/acks) only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
""" """
update_log = self.log[subject] update_log = self.log[subject]
other_log = self.log[-subject]
res = [] res = []
for htlc in update_log['adds'].values(): for htlc in update_log['adds'].values():
locked_in = htlc.locked_in[subject] locked_in = htlc.htlc_id in update_log['locked_in']
settled = htlc.htlc_id in other_log['settles'] settled = htlc.htlc_id in update_log['settles']
failed = htlc.htlc_id in other_log['fails'] failed = htlc.htlc_id in update_log['fails']
if locked_in is None: if not locked_in:
continue continue
if only_pending == (settled or failed): if only_pending == (settled or failed):
continue continue
@ -608,23 +619,23 @@ class Channel(PrintError):
self.print_error("settle_htlc") self.print_error("settle_htlc")
htlc = self.log[REMOTE]['adds'][htlc_id] htlc = self.log[REMOTE]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage) assert htlc.payment_hash == sha256(preimage)
self.log[LOCAL]['settles'].append(htlc_id) self.log[REMOTE]['settles'].append(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices # not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id): def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle") self.print_error("receive_htlc_settle")
htlc = self.log[LOCAL]['adds'][htlc_id] htlc = self.log[LOCAL]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage) assert htlc.payment_hash == sha256(preimage)
self.log[REMOTE]['settles'].append(htlc_id) self.log[LOCAL]['settles'].append(htlc_id)
# we don't save the preimage because we don't need to forward it anyway # we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id): def fail_htlc(self, htlc_id):
self.print_error("fail_htlc") self.print_error("fail_htlc")
self.log[LOCAL]['fails'].append(htlc_id) self.log[REMOTE]['fails'].append(htlc_id)
def receive_fail_htlc(self, htlc_id): def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc") self.print_error("receive_fail_htlc")
self.log[REMOTE]['fails'].append(htlc_id) self.log[LOCAL]['fails'].append(htlc_id)
@property @property
def current_height(self): def current_height(self):
@ -654,8 +665,9 @@ class Channel(PrintError):
""" """
removed = [] removed = []
htlcs = [] htlcs = []
for i in self.log[subject]['adds'].values(): log = self.log[subject]
locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None for htlc_id, i in log['adds'].items():
locked_in = htlc_id in log['locked_in']
if locked_in: if locked_in:
htlcs.append(i._asdict()) htlcs.append(i._asdict())
else: else:

17
electrum/tests/test_lnchan.py

@ -396,6 +396,23 @@ class TestChannel(unittest.TestCase):
self.alice_channel.add_htlc(new) self.alice_channel.add_htlc(new)
self.assertIn('Not enough local balance', cm.exception.args[0]) self.assertIn('Not enough local balance', cm.exception.args[0])
def test_sign_commitment_is_pure(self):
force_state_transition(self.alice_channel, self.bob_channel)
self.htlc_dict['payment_hash'] = bitcoin.sha256(b'\x02' * 32)
aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
before_signing = self.alice_channel.to_save()
self.alice_channel.sign_next_commitment()
after_signing = self.alice_channel.to_save()
try:
self.assertEqual(before_signing, after_signing)
except:
try:
from deepdiff import DeepDiff
from pprint import pformat
except ImportError:
raise
raise Exception(pformat(DeepDiff(before_signing, after_signing)))
class TestAvailableToSpend(unittest.TestCase): class TestAvailableToSpend(unittest.TestCase):
def test_DesyncHTLCs(self): def test_DesyncHTLCs(self):
alice_channel, bob_channel = create_test_channels() alice_channel, bob_channel = create_test_channels()

Loading…
Cancel
Save