Browse Source

lnchan: make sign_next_commitment revert state

regtest_lnd
Janus 6 years ago
committed by SomberNight
parent
commit
bfa28e5562
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 80
      electrum/lnchan.py
  2. 17
      electrum/tests/test_lnchan.py

80
electrum/lnchan.py

@ -27,6 +27,7 @@ import binascii
import json
from enum import Enum, auto
from typing import Optional, Dict, List, Tuple
from copy import deepcopy
from .util import bfh, PrintError, bh2u
from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS
@ -79,21 +80,20 @@ class FeeUpdate(defaultdict):
return self.rate
# implicit return None
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])):
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
"""
This whole class body is so that if you pass a hex-string as payment_hash,
it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
"""
__slots__ = ()
def __new__(cls, *args, **kwargs):
if len(args) > 0:
args = list(args)
if type(args[1]) is str:
args[1] = bfh(args[1])
args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()}
return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
if 'locked_in' not in kwargs:
kwargs['locked_in'] = {LOCAL: None, REMOTE: None}
else:
kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in'].items()}
return super().__new__(cls, **kwargs)
def decodeAll(d, local):
@ -162,6 +162,7 @@ class Channel(PrintError):
'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
'settles': [], # List[HTLC_ID]
'fails': [], # List[HTLC_ID]
'locked_in': [], # List[HTLC_ID]
}
self.log = {LOCAL: template(), REMOTE: template()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
@ -269,7 +270,8 @@ class Channel(PrintError):
This docstring was adapted from LND.
"""
self.print_error("sign_next_commitment")
self.lock_in_htlc_changes(LOCAL)
old_logs = dict(self.lock_in_htlc_changes(LOCAL))
pending_remote_commitment = self.pending_remote_commitment
sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
@ -290,29 +292,28 @@ class Channel(PrintError):
htlc_sig = ecc.sig_string_from_der_sig(sig[:-1])
htlcsigs.append((pending_remote_commitment.htlc_output_indices[htlc.payment_hash], htlc_sig))
for pending_fee in self.fee_mgr:
if not self.constraints.is_initiator:
pending_fee[FUNDEE_SIGNED] = True
if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
pending_fee[FUNDER_SIGNED] = True
self.process_new_offchain_ctx(pending_remote_commitment, ours=False)
htlcsigs.sort()
htlcsigs = [x[1] for x in htlcsigs]
# we can't know if this message arrives.
# since we shouldn't actually throw away
# failed htlcs yet (or mark htlc locked in),
# roll back the changes that were made
self.log = old_logs
return sig_64, htlcsigs
def lock_in_htlc_changes(self, subject):
for sub in (LOCAL, REMOTE):
for htlc_id in self.log[-sub]['fails']:
adds = self.log[sub]['adds']
htlc = adds.pop(htlc_id)
self.log[-sub]['fails'].clear()
log = self.log[sub]
yield (sub, deepcopy(log))
for htlc_id in log['fails']:
log['adds'].pop(htlc_id)
log['fails'].clear()
for htlc in self.log[subject]['adds'].values():
if htlc.locked_in[subject] is None:
htlc.locked_in[subject] = self.config[subject].ctn
self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
def receive_new_commitment(self, sig, htlc_sigs):
"""
@ -328,7 +329,9 @@ class Channel(PrintError):
This docstring is from LND.
"""
self.print_error("receive_new_commitment")
self.lock_in_htlc_changes(REMOTE)
for _ in self.lock_in_htlc_changes(REMOTE): pass
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
pending_local_commitment = self.pending_local_commitment
@ -443,11 +446,20 @@ class Channel(PrintError):
def receive_revocation(self, revocation) -> Tuple[int, int]:
self.print_error("receive_revocation")
old_logs = dict(self.lock_in_htlc_changes(LOCAL))
cur_point = self.config[REMOTE].current_per_commitment_point
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
if cur_point != derived_point:
self.log = old_logs
raise Exception('revoked secret not for current point')
for pending_fee in self.fee_mgr:
if not self.constraints.is_initiator:
pending_fee[FUNDEE_SIGNED] = True
if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
pending_fee[FUNDER_SIGNED] = True
# FIXME not sure this is correct... but it seems to work
# if there are update_add_htlc msgs between commitment_signed and rev_ack,
# this might break
@ -462,11 +474,11 @@ class Channel(PrintError):
"""
old_amount = htlcsum(self.htlcs(subject, False))
for htlc_id in self.log[-subject]['settles']:
for htlc_id in self.log[subject]['settles']:
adds = self.log[subject]['adds']
htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat)
self.log[-subject]['settles'].clear()
self.log[subject]['settles'].clear()
return old_amount - htlcsum(self.htlcs(subject, False))
@ -588,13 +600,12 @@ class Channel(PrintError):
only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
"""
update_log = self.log[subject]
other_log = self.log[-subject]
res = []
for htlc in update_log['adds'].values():
locked_in = htlc.locked_in[subject]
settled = htlc.htlc_id in other_log['settles']
failed = htlc.htlc_id in other_log['fails']
if locked_in is None:
locked_in = htlc.htlc_id in update_log['locked_in']
settled = htlc.htlc_id in update_log['settles']
failed = htlc.htlc_id in update_log['fails']
if not locked_in:
continue
if only_pending == (settled or failed):
continue
@ -608,23 +619,23 @@ class Channel(PrintError):
self.print_error("settle_htlc")
htlc = self.log[REMOTE]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
self.log[LOCAL]['settles'].append(htlc_id)
self.log[REMOTE]['settles'].append(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle")
htlc = self.log[LOCAL]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
self.log[REMOTE]['settles'].append(htlc_id)
self.log[LOCAL]['settles'].append(htlc_id)
# we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id):
self.print_error("fail_htlc")
self.log[LOCAL]['fails'].append(htlc_id)
self.log[REMOTE]['fails'].append(htlc_id)
def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc")
self.log[REMOTE]['fails'].append(htlc_id)
self.log[LOCAL]['fails'].append(htlc_id)
@property
def current_height(self):
@ -654,8 +665,9 @@ class Channel(PrintError):
"""
removed = []
htlcs = []
for i in self.log[subject]['adds'].values():
locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None
log = self.log[subject]
for htlc_id, i in log['adds'].items():
locked_in = htlc_id in log['locked_in']
if locked_in:
htlcs.append(i._asdict())
else:

17
electrum/tests/test_lnchan.py

@ -396,6 +396,23 @@ class TestChannel(unittest.TestCase):
self.alice_channel.add_htlc(new)
self.assertIn('Not enough local balance', cm.exception.args[0])
def test_sign_commitment_is_pure(self):
force_state_transition(self.alice_channel, self.bob_channel)
self.htlc_dict['payment_hash'] = bitcoin.sha256(b'\x02' * 32)
aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
before_signing = self.alice_channel.to_save()
self.alice_channel.sign_next_commitment()
after_signing = self.alice_channel.to_save()
try:
self.assertEqual(before_signing, after_signing)
except:
try:
from deepdiff import DeepDiff
from pprint import pformat
except ImportError:
raise
raise Exception(pformat(DeepDiff(before_signing, after_signing)))
class TestAvailableToSpend(unittest.TestCase):
def test_DesyncHTLCs(self):
alice_channel, bob_channel = create_test_channels()

Loading…
Cancel
Save