Browse Source

ln: merge OpenChannel and HTLCStateMachine

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 7 years ago
committed by ThomasV
parent
commit
7a3551b5df
  1. 6
      gui/qt/channels_list.py
  2. 142
      lib/lnbase.py
  3. 206
      lib/lnhtlc.py
  4. 2
      lib/lnwatcher.py
  5. 101
      lib/lnworker.py
  6. 32
      lib/tests/test_lnhtlc.py

6
gui/qt/channels_list.py

@ -4,13 +4,13 @@ from PyQt5.QtWidgets import *
from electrum.util import inv_dict, bh2u, bfh from electrum.util import inv_dict, bh2u, bfh
from electrum.i18n import _ from electrum.i18n import _
from electrum.lnbase import OpenChannel from electrum.lnhtlc import HTLCStateMachine
from .util import MyTreeWidget, SortableTreeWidgetItem, WindowModalDialog, Buttons, OkButton, CancelButton from .util import MyTreeWidget, SortableTreeWidgetItem, WindowModalDialog, Buttons, OkButton, CancelButton
from .amountedit import BTCAmountEdit from .amountedit import BTCAmountEdit
class ChannelsList(MyTreeWidget): class ChannelsList(MyTreeWidget):
update_rows = QtCore.pyqtSignal() update_rows = QtCore.pyqtSignal()
update_single_row = QtCore.pyqtSignal(OpenChannel) update_single_row = QtCore.pyqtSignal(HTLCStateMachine)
def __init__(self, parent): def __init__(self, parent):
MyTreeWidget.__init__(self, parent, self.create_menu, [_('Node ID'), _('Balance'), _('Remote'), _('Status')], 0) MyTreeWidget.__init__(self, parent, self.create_menu, [_('Node ID'), _('Balance'), _('Remote'), _('Status')], 0)
@ -38,7 +38,7 @@ class ChannelsList(MyTreeWidget):
menu.addAction(_("Close channel"), close) menu.addAction(_("Close channel"), close)
menu.exec_(self.viewport().mapToGlobal(position)) menu.exec_(self.viewport().mapToGlobal(position))
@QtCore.pyqtSlot(OpenChannel) @QtCore.pyqtSlot(HTLCStateMachine)
def do_update_single_row(self, chan): def do_update_single_row(self, chan):
for i in range(self.topLevelItemCount()): for i in range(self.topLevelItemCount()):
item = self.topLevelItem(i) item = self.topLevelItem(i)

142
lib/lnbase.py

@ -4,12 +4,53 @@
Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8 Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
""" """
from collections import namedtuple, defaultdict, OrderedDict, defaultdict
Keypair = namedtuple("Keypair", ["pubkey", "privkey"])
Outpoint = namedtuple("Outpoint", ["txid", "output_index"])
ChannelConfig = namedtuple("ChannelConfig", [
"payment_basepoint", "multisig_key", "htlc_basepoint", "delayed_basepoint", "revocation_basepoint",
"to_self_delay", "dust_limit_sat", "max_htlc_value_in_flight_msat", "max_accepted_htlcs"])
OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"])
RemoteState = namedtuple("RemoteState", ["ctn", "next_per_commitment_point", "amount_msat", "revocation_store", "current_per_commitment_point", "next_htlc_id"])
LocalState = namedtuple("LocalState", ["ctn", "per_commitment_secret_seed", "amount_msat", "next_htlc_id", "funding_locked_received", "was_announced", "current_commitment_signature"])
ChannelConstraints = namedtuple("ChannelConstraints", ["feerate", "capacity", "is_initiator", "funding_txn_minimum_depth"])
#OpenChannel = namedtuple("OpenChannel", ["channel_id", "short_channel_id", "funding_outpoint", "local_config", "remote_config", "remote_state", "local_state", "constraints", "node_id"])
class RevocationStore:
""" taken from lnd """
def __init__(self):
self.buckets = [None] * 48
self.index = 2**48 - 1
def add_next_entry(self, hsh):
new_element = ShachainElement(index=self.index, secret=hsh)
bucket = count_trailing_zeros(self.index)
for i in range(0, bucket):
this_bucket = self.buckets[i]
e = shachain_derive(new_element, this_bucket.index)
if e != this_bucket:
raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index))
self.buckets[bucket] = new_element
self.index -= 1
def serialize(self):
return {"index": self.index, "buckets": [[bh2u(k.secret), k.index] if k is not None else None for k in self.buckets]}
@staticmethod
def from_json_obj(decoded_json_obj):
store = RevocationStore()
decode = lambda to_decode: ShachainElement(bfh(to_decode[0]), int(to_decode[1]))
store.buckets = [k if k is None else decode(k) for k in decoded_json_obj["buckets"]]
store.index = decoded_json_obj["index"]
return store
def __eq__(self, o):
return type(o) is RevocationStore and self.serialize() == o.serialize()
def __hash__(self):
return hash(json.dumps(self.serialize(), sort_keys=True))
from ecdsa.util import sigdecode_der, sigencode_string_canonize, sigdecode_string from ecdsa.util import sigdecode_der, sigencode_string_canonize, sigdecode_string
from ecdsa.curves import SECP256k1 from ecdsa.curves import SECP256k1
import queue import queue
import traceback import traceback
import json import json
from collections import OrderedDict, defaultdict
import asyncio import asyncio
from concurrent.futures import FIRST_COMPLETED from concurrent.futures import FIRST_COMPLETED
import os import os
@ -18,7 +59,6 @@ import binascii
import hashlib import hashlib
import hmac import hmac
from typing import Sequence, Union, Tuple from typing import Sequence, Union, Tuple
from collections import namedtuple, defaultdict
import cryptography.hazmat.primitives.ciphers.aead as AEAD import cryptography.hazmat.primitives.ciphers.aead as AEAD
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
@ -274,17 +314,6 @@ def create_ephemeral_key(privkey):
pub = privkey_to_pubkey(privkey) pub = privkey_to_pubkey(privkey)
return (privkey[:32], pub) return (privkey[:32], pub)
Keypair = namedtuple("Keypair", ["pubkey", "privkey"])
Outpoint = namedtuple("Outpoint", ["txid", "output_index"])
ChannelConfig = namedtuple("ChannelConfig", [
"payment_basepoint", "multisig_key", "htlc_basepoint", "delayed_basepoint", "revocation_basepoint",
"to_self_delay", "dust_limit_sat", "max_htlc_value_in_flight_msat", "max_accepted_htlcs"])
OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"])
RemoteState = namedtuple("RemoteState", ["ctn", "next_per_commitment_point", "amount_msat", "revocation_store", "current_per_commitment_point", "next_htlc_id"])
LocalState = namedtuple("LocalState", ["ctn", "per_commitment_secret_seed", "amount_msat", "next_htlc_id", "funding_locked_received", "was_announced", "current_commitment_signature"])
ChannelConstraints = namedtuple("ChannelConstraints", ["feerate", "capacity", "is_initiator", "funding_txn_minimum_depth"])
OpenChannel = namedtuple("OpenChannel", ["channel_id", "short_channel_id", "funding_outpoint", "local_config", "remote_config", "remote_state", "local_state", "constraints", "node_id"])
def aiosafe(f): def aiosafe(f):
async def f2(*args, **kwargs): async def f2(*args, **kwargs):
@ -887,14 +916,14 @@ class Peer(PrintError):
# remote commitment transaction # remote commitment transaction
channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index) channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index)
their_revocation_store = RevocationStore() their_revocation_store = RevocationStore()
chan = OpenChannel( chan = {
node_id=self.pubkey, "node_id": self.pubkey,
channel_id=channel_id, "channel_id": channel_id,
short_channel_id=None, "short_channel_id": None,
funding_outpoint=Outpoint(funding_txid, funding_index), "funding_outpoint": Outpoint(funding_txid, funding_index),
local_config=local_config, "local_config": local_config,
remote_config=remote_config, "remote_config": remote_config,
remote_state=RemoteState( "remote_state": RemoteState(
ctn = -1, ctn = -1,
next_per_commitment_point=remote_per_commitment_point, next_per_commitment_point=remote_per_commitment_point,
current_per_commitment_point=None, current_per_commitment_point=None,
@ -902,7 +931,7 @@ class Peer(PrintError):
revocation_store=their_revocation_store, revocation_store=their_revocation_store,
next_htlc_id = 0 next_htlc_id = 0
), ),
local_state=LocalState( "local_state": LocalState(
ctn = -1, ctn = -1,
per_commitment_secret_seed=per_commitment_secret_seed, per_commitment_secret_seed=per_commitment_secret_seed,
amount_msat=local_amount, amount_msat=local_amount,
@ -911,8 +940,8 @@ class Peer(PrintError):
was_announced = False, was_announced = False,
current_commitment_signature = None current_commitment_signature = None
), ),
constraints=ChannelConstraints(capacity=funding_sat, feerate=local_feerate, is_initiator=True, funding_txn_minimum_depth=funding_txn_minimum_depth) "constraints": ChannelConstraints(capacity=funding_sat, feerate=local_feerate, is_initiator=True, funding_txn_minimum_depth=funding_txn_minimum_depth)
) }
m = HTLCStateMachine(chan) m = HTLCStateMachine(chan)
sig_64, _ = m.sign_next_commitment() sig_64, _ = m.sign_next_commitment()
self.send_message(gen_msg("funding_created", self.send_message(gen_msg("funding_created",
@ -927,7 +956,8 @@ class Peer(PrintError):
# broadcast funding tx # broadcast funding tx
success, _txid = self.network.broadcast_transaction(funding_tx) success, _txid = self.network.broadcast_transaction(funding_tx)
assert success, success assert success, success
m.state = chan._replace(remote_state=chan.remote_state._replace(ctn=0),local_state=chan.local_state._replace(ctn=0, current_commitment_signature=remote_sig)) m.remote_state = m.remote_state._replace(ctn=0)
m.local_state = m.local_state._replace(ctn=0, current_commitment_signature=remote_sig)
return m return m
@aiosafe @aiosafe
@ -943,7 +973,7 @@ class Peer(PrintError):
)) ))
await self.channel_reestablished[chan_id] await self.channel_reestablished[chan_id]
self.channel_state[chan_id] = 'OPENING' self.channel_state[chan_id] = 'OPENING'
if chan.local_state.funding_locked_received and chan.state.short_channel_id: if chan.local_state.funding_locked_received and chan.short_channel_id:
self.mark_open(chan) self.mark_open(chan)
self.network.trigger_callback('channel', chan) self.network.trigger_callback('channel', chan)
@ -988,9 +1018,10 @@ class Peer(PrintError):
their_next_point = payload["next_per_commitment_point"] their_next_point = payload["next_per_commitment_point"]
new_remote_state = chan.remote_state._replace(next_per_commitment_point=their_next_point, current_per_commitment_point=our_next_point) new_remote_state = chan.remote_state._replace(next_per_commitment_point=their_next_point, current_per_commitment_point=our_next_point)
new_local_state = chan.local_state._replace(funding_locked_received = True) new_local_state = chan.local_state._replace(funding_locked_received = True)
chan.state = chan.state._replace(remote_state=new_remote_state, local_state=new_local_state) chan.remote_state=new_remote_state
chan.local_state=new_local_state
self.lnworker.save_channel(chan) self.lnworker.save_channel(chan)
if chan.state.short_channel_id: if chan.short_channel_id:
self.mark_open(chan) self.mark_open(chan)
def on_network_update(self, chan, funding_tx_depth): def on_network_update(self, chan, funding_tx_depth):
@ -1000,7 +1031,7 @@ class Peer(PrintError):
Runs on the Network thread. Runs on the Network thread.
""" """
if not chan.local_state.was_announced and funding_tx_depth >= 6: if not chan.local_state.was_announced and funding_tx_depth >= 6:
chan.state = chan.state._replace(local_state=chan.local_state._replace(was_announced=True)) chan.local_state=chan.local_state._replace(was_announced=True)
coro = self.handle_announcements(chan) coro = self.handle_announcements(chan)
self.lnworker.save_channel(chan) self.lnworker.save_channel(chan)
asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
@ -1035,7 +1066,7 @@ class Peer(PrintError):
len=0, len=0,
#features not set (defaults to zeros) #features not set (defaults to zeros)
chain_hash=bytes.fromhex(rev_hex(constants.net.GENESIS)), chain_hash=bytes.fromhex(rev_hex(constants.net.GENESIS)),
short_channel_id=chan.state.short_channel_id, short_channel_id=chan.short_channel_id,
node_id_1=node_ids[0], node_id_1=node_ids[0],
node_id_2=node_ids[1], node_id_2=node_ids[1],
bitcoin_key_1=bitcoin_keys[0], bitcoin_key_1=bitcoin_keys[0],
@ -1051,12 +1082,12 @@ class Peer(PrintError):
return return
assert chan.local_state.funding_locked_received assert chan.local_state.funding_locked_received
self.channel_state[chan.channel_id] = "OPEN" self.channel_state[chan.channel_id] = "OPEN"
self.network.trigger_callback('channel', chan.state) self.network.trigger_callback('channel', chan)
# add channel to database # add channel to database
sorted_keys = list(sorted([self.pubkey, self.lnworker.pubkey])) sorted_keys = list(sorted([self.pubkey, self.lnworker.pubkey]))
self.channel_db.on_channel_announcement({"short_channel_id": chan.state.short_channel_id, "node_id_1": sorted_keys[0], "node_id_2": sorted_keys[1]}) self.channel_db.on_channel_announcement({"short_channel_id": chan.short_channel_id, "node_id_1": sorted_keys[0], "node_id_2": sorted_keys[1]})
self.channel_db.on_channel_update({"short_channel_id": chan.state.short_channel_id, 'flags': b'\x01', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'}) self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'flags': b'\x01', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'})
self.channel_db.on_channel_update({"short_channel_id": chan.state.short_channel_id, 'flags': b'\x00', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'}) self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'flags': b'\x00', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'})
self.print_error("CHANNEL OPENING COMPLETED") self.print_error("CHANNEL OPENING COMPLETED")
@ -1077,7 +1108,7 @@ class Peer(PrintError):
len=0, len=0,
#features not set (defaults to zeros) #features not set (defaults to zeros)
chain_hash=bytes.fromhex(rev_hex(constants.net.GENESIS)), chain_hash=bytes.fromhex(rev_hex(constants.net.GENESIS)),
short_channel_id=chan.state.short_channel_id, short_channel_id=chan.short_channel_id,
node_id_1=node_ids[0], node_id_1=node_ids[0],
node_id_2=node_ids[1], node_id_2=node_ids[1],
bitcoin_key_1=bitcoin_keys[0], bitcoin_key_1=bitcoin_keys[0],
@ -1089,7 +1120,7 @@ class Peer(PrintError):
node_signature = ecc.ECPrivkey(self.privkey).sign(h, sigencode_string_canonize, sigdecode_string) node_signature = ecc.ECPrivkey(self.privkey).sign(h, sigencode_string_canonize, sigdecode_string)
self.send_message(gen_msg("announcement_signatures", self.send_message(gen_msg("announcement_signatures",
channel_id=chan.channel_id, channel_id=chan.channel_id,
short_channel_id=chan.state.short_channel_id, short_channel_id=chan.short_channel_id,
node_signature=node_signature, node_signature=node_signature,
bitcoin_signature=bitcoin_signature bitcoin_signature=bitcoin_signature
)) ))
@ -1186,7 +1217,7 @@ class Peer(PrintError):
self.revoke(chan) self.revoke(chan)
# TODO process above commitment transactions # TODO process above commitment transactions
bare_ctx = make_commitment_using_open_channel(chan.state, chan.remote_state.ctn + 1, False, chan.remote_state.next_per_commitment_point, bare_ctx = make_commitment_using_open_channel(chan, chan.remote_state.ctn + 1, False, chan.remote_state.next_per_commitment_point,
msat_remote, msat_local) msat_remote, msat_local)
sig_64 = sign_and_get_sig_string(bare_ctx, chan.local_config, chan.remote_config) sig_64 = sign_and_get_sig_string(bare_ctx, chan.local_config, chan.remote_config)
@ -1248,9 +1279,9 @@ class Peer(PrintError):
self.send_message(gen_msg("update_fulfill_htlc", channel_id=channel_id, id=htlc_id, payment_preimage=payment_preimage)) self.send_message(gen_msg("update_fulfill_htlc", channel_id=channel_id, id=htlc_id, payment_preimage=payment_preimage))
# remote commitment transaction without htlcs # remote commitment transaction without htlcs
bare_ctx = make_commitment_using_open_channel(m.state, m.state.remote_state.ctn + 1, False, m.state.remote_state.next_per_commitment_point, bare_ctx = make_commitment_using_open_channel(m, m.remote_state.ctn + 1, False, m.remote_state.next_per_commitment_point,
m.state.remote_state.amount_msat - expected_received_msat, m.state.local_state.amount_msat + expected_received_msat) m.remote_state.amount_msat - expected_received_msat, m.local_state.amount_msat + expected_received_msat)
sig_64 = sign_and_get_sig_string(bare_ctx, m.state.local_config, m.state.remote_config) sig_64 = sign_and_get_sig_string(bare_ctx, m.local_config, m.remote_config)
self.send_message(gen_msg("commitment_signed", channel_id=channel_id, signature=sig_64, num_htlcs=0)) self.send_message(gen_msg("commitment_signed", channel_id=channel_id, signature=sig_64, num_htlcs=0))
await self.receive_revoke(chan) await self.receive_revoke(chan)
@ -1265,7 +1296,7 @@ class Peer(PrintError):
self.print_error("commitment_signed", payload) self.print_error("commitment_signed", payload)
channel_id = payload['channel_id'] channel_id = payload['channel_id']
chan = self.channels[channel_id] chan = self.channels[channel_id]
chan.state = chan.state._replace(local_state=chan.local_state._replace(current_commitment_signature=payload['signature'])) chan.local_state=chan.local_state._replace(current_commitment_signature=payload['signature'])
self.lnworker.save_channel(chan) self.lnworker.save_channel(chan)
self.commitment_signed[channel_id].put_nowait(payload) self.commitment_signed[channel_id].put_nowait(payload)
@ -1312,32 +1343,3 @@ def count_trailing_zeros(index):
ShachainElement = namedtuple("ShachainElement", ["secret", "index"]) ShachainElement = namedtuple("ShachainElement", ["secret", "index"])
ShachainElement.__str__ = lambda self: "ShachainElement(" + bh2u(self.secret) + "," + str(self.index) + ")" ShachainElement.__str__ = lambda self: "ShachainElement(" + bh2u(self.secret) + "," + str(self.index) + ")"
class RevocationStore:
""" taken from lnd """
def __init__(self):
self.buckets = [None] * 48
self.index = 2**48 - 1
def add_next_entry(self, hsh):
new_element = ShachainElement(index=self.index, secret=hsh)
bucket = count_trailing_zeros(self.index)
for i in range(0, bucket):
this_bucket = self.buckets[i]
e = shachain_derive(new_element, this_bucket.index)
if e != this_bucket:
raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index))
self.buckets[bucket] = new_element
self.index -= 1
def serialize(self):
return {"index": self.index, "buckets": [[bh2u(k.secret), k.index] if k is not None else None for k in self.buckets]}
@staticmethod
def from_json_obj(decoded_json_obj):
store = RevocationStore()
decode = lambda to_decode: ShachainElement(bfh(to_decode[0]), int(to_decode[1]))
store.buckets = [k if k is None else decode(k) for k in decoded_json_obj["buckets"]]
store.index = decoded_json_obj["index"]
return store
def __eq__(self, o):
return type(o) is RevocationStore and self.serialize() == o.serialize()
def __hash__(self):
return hash(json.dumps(self.serialize(), sort_keys=True))

206
lib/lnhtlc.py

@ -1,4 +1,6 @@
# ported from lnd 42de4400bff5105352d0552155f73589166d162b # ported from lnd 42de4400bff5105352d0552155f73589166d162b
import binascii
import json
from ecdsa.util import sigencode_string_canonize, sigdecode_der from ecdsa.util import sigencode_string_canonize, sigdecode_der
from .util import bfh, PrintError from .util import bfh, PrintError
from .bitcoin import Hash from .bitcoin import Hash
@ -7,6 +9,7 @@ from ecdsa.curves import SECP256k1
from .crypto import sha256 from .crypto import sha256
from . import ecc from . import ecc
from . import lnbase from . import lnbase
from .lnbase import Outpoint, ChannelConfig, LocalState, RemoteState, Keypair, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore
HTLC_TIMEOUT_WEIGHT = lnbase.HTLC_TIMEOUT_WEIGHT HTLC_TIMEOUT_WEIGHT = lnbase.HTLC_TIMEOUT_WEIGHT
HTLC_SUCCESS_WEIGHT = lnbase.HTLC_SUCCESS_WEIGHT HTLC_SUCCESS_WEIGHT = lnbase.HTLC_SUCCESS_WEIGHT
@ -38,6 +41,24 @@ class UpdateAddHtlc:
def __repr__(self): def __repr__(self):
return "UpdateAddHtlc" + str(self.as_tuple()) return "UpdateAddHtlc" + str(self.as_tuple())
is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key")
def maybeDecode(k, v):
assert type(v) is not list
if k in ["node_id", "channel_id", "short_channel_id", "pubkey", "privkey", "current_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed", "current_commitment_signature"] and v is not None:
return binascii.unhexlify(v)
return v
def decodeAll(v):
return {i: maybeDecode(i, j) for i, j in v.items()} if isinstance(v, dict) else v
def typeWrap(k, v, local):
if is_key(k):
if local:
return Keypair(**v)
else:
return OnlyPubkeyKeypair(**v)
return v
class HTLCStateMachine(PrintError): class HTLCStateMachine(PrintError):
def lookup_htlc(self, log, htlc_id): def lookup_htlc(self, log, htlc_id):
@ -52,7 +73,33 @@ class HTLCStateMachine(PrintError):
return str(self.name) return str(self.name)
def __init__(self, state, name = None): def __init__(self, state, name = None):
self.state = state self.local_config = state["local_config"]
if type(self.local_config) is not ChannelConfig:
new_local_config = {k: typeWrap(k, decodeAll(v), True) for k, v in self.local_config.items()}
self.local_config = ChannelConfig(**new_local_config)
self.remote_config = state["remote_config"]
if type(self.remote_config) is not ChannelConfig:
new_remote_config = {k: typeWrap(k, decodeAll(v), False) for k, v in self.remote_config.items()}
self.remote_config = ChannelConfig(**new_remote_config)
self.local_state = state["local_state"]
if type(self.local_state) is not LocalState:
self.local_state = LocalState(**decodeAll(self.local_state))
self.remote_state = state["remote_state"]
if type(self.remote_state) is not RemoteState:
self.remote_state = RemoteState(**decodeAll(self.remote_state))
if type(self.remote_state.revocation_store) is not RevocationStore:
self.remote_state = self.remote_state._replace(revocation_store = RevocationStore.from_json_obj(self.remote_state.revocation_store))
self.channel_id = maybeDecode("channel_id", state["channel_id"]) if type(state["channel_id"]) is not bytes else state["channel_id"]
self.constraints = ChannelConstraints(**decodeAll(state["constraints"])) if type(state["constraints"]) is not ChannelConstraints else state["constraints"]
self.funding_outpoint = Outpoint(**decodeAll(state["funding_outpoint"])) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"]
self.node_id = maybeDecode("node_id", state["node_id"]) if type(state["node_id"]) is not bytes else state["node_id"]
self.short_channel_id = maybeDecode("short_channel_id", state["short_channel_id"]) if type(state["short_channel_id"]) is not bytes else state["short_channel_id"]
self.local_update_log = [] self.local_update_log = []
self.remote_update_log = [] self.remote_update_log = []
@ -70,8 +117,8 @@ class HTLCStateMachine(PrintError):
assert type(htlc) is UpdateAddHtlc assert type(htlc) is UpdateAddHtlc
self.local_update_log.append(htlc) self.local_update_log.append(htlc)
self.print_error("add_htlc") self.print_error("add_htlc")
htlc_id = self.state.local_state.next_htlc_id htlc_id = self.local_state.next_htlc_id
self.state = self.state._replace(local_state=self.state.local_state._replace(next_htlc_id=htlc_id + 1)) self.local_state=self.local_state._replace(next_htlc_id=htlc_id + 1)
htlc.htlc_id = htlc_id htlc.htlc_id = htlc_id
return htlc_id return htlc_id
@ -84,8 +131,8 @@ class HTLCStateMachine(PrintError):
self.print_error("receive_htlc") self.print_error("receive_htlc")
assert type(htlc) is UpdateAddHtlc assert type(htlc) is UpdateAddHtlc
self.remote_update_log.append(htlc) self.remote_update_log.append(htlc)
htlc_id = self.state.remote_state.next_htlc_id htlc_id = self.remote_state.next_htlc_id
self.state = self.state._replace(remote_state=self.state.remote_state._replace(next_htlc_id=htlc_id + 1)) self.remote_state=self.remote_state._replace(next_htlc_id=htlc_id + 1)
htlc.htlc_id = htlc_id htlc.htlc_id = htlc_id
return htlc_id return htlc_id
@ -105,14 +152,14 @@ class HTLCStateMachine(PrintError):
from .lnbase import sign_and_get_sig_string, derive_privkey, make_htlc_tx_with_open_channel from .lnbase import sign_and_get_sig_string, derive_privkey, make_htlc_tx_with_open_channel
for htlc in self.local_update_log: for htlc in self.local_update_log:
if not type(htlc) is UpdateAddHtlc: continue if not type(htlc) is UpdateAddHtlc: continue
if htlc.l_locked_in is None: htlc.l_locked_in = self.state.local_state.ctn if htlc.l_locked_in is None: htlc.l_locked_in = self.local_state.ctn
self.print_error("sign_next_commitment") self.print_error("sign_next_commitment")
sig_64 = sign_and_get_sig_string(self.remote_commitment, self.state.local_config, self.state.remote_config) sig_64 = sign_and_get_sig_string(self.remote_commitment, self.local_config, self.remote_config)
their_remote_htlc_privkey_number = derive_privkey( their_remote_htlc_privkey_number = derive_privkey(
int.from_bytes(self.state.local_config.htlc_basepoint.privkey, 'big'), int.from_bytes(self.local_config.htlc_basepoint.privkey, 'big'),
self.state.remote_state.next_per_commitment_point) self.remote_state.next_per_commitment_point)
their_remote_htlc_privkey = their_remote_htlc_privkey_number.to_bytes(32, 'big') their_remote_htlc_privkey = their_remote_htlc_privkey_number.to_bytes(32, 'big')
for_us = False for_us = False
@ -122,11 +169,11 @@ class HTLCStateMachine(PrintError):
assert len(htlcs) <= 1 assert len(htlcs) <= 1
for htlc in htlcs: for htlc in htlcs:
weight = lnbase.HTLC_SUCCESS_WEIGHT if we_receive else lnbase.HTLC_TIMEOUT_WEIGHT weight = lnbase.HTLC_SUCCESS_WEIGHT if we_receive else lnbase.HTLC_TIMEOUT_WEIGHT
if htlc.amount_msat // 1000 - weight * (self.state.constraints.feerate // 1000) < self.state.remote_config.dust_limit_sat: if htlc.amount_msat // 1000 - weight * (self.constraints.feerate // 1000) < self.remote_config.dust_limit_sat:
continue continue
original_htlc_output_index = 0 original_htlc_output_index = 0
args = [self.state.remote_state.next_per_commitment_point, for_us, we_receive, htlc.amount_msat + htlc.total_fee, htlc.cltv_expiry, htlc.payment_hash, self.remote_commitment, original_htlc_output_index] args = [self.remote_state.next_per_commitment_point, for_us, we_receive, htlc.amount_msat + htlc.total_fee, htlc.cltv_expiry, htlc.payment_hash, self.remote_commitment, original_htlc_output_index]
htlc_tx = make_htlc_tx_with_open_channel(self.state, *args) htlc_tx = make_htlc_tx_with_open_channel(self, *args)
sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey)) sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey))
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order()) r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order()) htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order())
@ -150,12 +197,12 @@ class HTLCStateMachine(PrintError):
self.print_error("receive_new_commitment") self.print_error("receive_new_commitment")
for htlc in self.remote_update_log: for htlc in self.remote_update_log:
if not type(htlc) is UpdateAddHtlc: continue if not type(htlc) is UpdateAddHtlc: continue
if htlc.r_locked_in is None: htlc.r_locked_in = self.state.remote_state.ctn if htlc.r_locked_in is None: htlc.r_locked_in = self.remote_state.ctn
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
preimage_hex = self.local_commitment.serialize_preimage(0) preimage_hex = self.local_commitment.serialize_preimage(0)
pre_hash = Hash(bfh(preimage_hex)) pre_hash = Hash(bfh(preimage_hex))
if not ecc.verify_signature(self.state.remote_config.multisig_key.pubkey, sig, pre_hash): if not ecc.verify_signature(self.remote_config.multisig_key.pubkey, sig, pre_hash):
raise Exception('failed verifying signature of our updated commitment transaction: ' + str(sig)) raise Exception('failed verifying signature of our updated commitment transaction: ' + str(sig))
_, this_point, _ = self.points _, this_point, _ = self.points
@ -166,9 +213,9 @@ class HTLCStateMachine(PrintError):
payment_hash = self.htlcs_in_remote[0].payment_hash payment_hash = self.htlcs_in_remote[0].payment_hash
amount_msat = self.htlcs_in_remote[0].amount_msat amount_msat = self.htlcs_in_remote[0].amount_msat
cltv_expiry = self.htlcs_in_remote[0].cltv_expiry cltv_expiry = self.htlcs_in_remote[0].cltv_expiry
htlc_tx = make_htlc_tx_with_open_channel(self.state, this_point, True, we_receive, amount_msat, cltv_expiry, payment_hash, self.local_commitment, 0) htlc_tx = make_htlc_tx_with_open_channel(self, this_point, True, we_receive, amount_msat, cltv_expiry, payment_hash, self.local_commitment, 0)
pre_hash = Hash(bfh(htlc_tx.serialize_preimage(0))) pre_hash = Hash(bfh(htlc_tx.serialize_preimage(0)))
remote_htlc_pubkey = derive_pubkey(self.state.remote_config.htlc_basepoint.pubkey, this_point) remote_htlc_pubkey = derive_pubkey(self.remote_config.htlc_basepoint.pubkey, this_point)
if not ecc.verify_signature(remote_htlc_pubkey, htlc_sigs[0], pre_hash): if not ecc.verify_signature(remote_htlc_pubkey, htlc_sigs[0], pre_hash):
raise Exception("failed verifying signature an HTLC tx spending from one of our commit tx'es HTLC outputs") raise Exception("failed verifying signature an HTLC tx spending from one of our commit tx'es HTLC outputs")
@ -192,15 +239,13 @@ class HTLCStateMachine(PrintError):
if self.pending_feerate is not None: if self.pending_feerate is not None:
new_feerate = self.pending_feerate new_feerate = self.pending_feerate
else: else:
new_feerate = self.state.constraints.feerate new_feerate = self.constraints.feerate
self.state = self.state._replace( self.local_state=self.local_state._replace(
local_state=self.state.local_state._replace( ctn=self.local_state.ctn + 1
ctn=self.state.local_state.ctn + 1
),
constraints=self.state.constraints._replace(
feerate=new_feerate
) )
self.constraints=self.constraints._replace(
feerate=new_feerate
) )
return RevokeAndAck(last_secret, next_point), "current htlcs" return RevokeAndAck(last_secret, next_point), "current htlcs"
@ -208,14 +253,13 @@ class HTLCStateMachine(PrintError):
@property @property
def points(self): def points(self):
from .lnbase import get_per_commitment_secret_from_seed, secret_to_pubkey from .lnbase import get_per_commitment_secret_from_seed, secret_to_pubkey
chan = self.state last_small_num = self.local_state.ctn
last_small_num = chan.local_state.ctn
next_small_num = last_small_num + 2 next_small_num = last_small_num + 2
this_small_num = last_small_num + 1 this_small_num = last_small_num + 1
last_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-last_small_num-1) last_secret = get_per_commitment_secret_from_seed(self.local_state.per_commitment_secret_seed, 2**48-last_small_num-1)
this_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-this_small_num-1) this_secret = get_per_commitment_secret_from_seed(self.local_state.per_commitment_secret_seed, 2**48-this_small_num-1)
this_point = secret_to_pubkey(int.from_bytes(this_secret, 'big')) this_point = secret_to_pubkey(int.from_bytes(this_secret, 'big'))
next_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-next_small_num-1) next_secret = get_per_commitment_secret_from_seed(self.local_state.per_commitment_secret_seed, 2**48-next_small_num-1)
next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big')) next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big'))
return last_secret, this_point, next_point return last_secret, this_point, next_point
@ -272,22 +316,20 @@ class HTLCStateMachine(PrintError):
received_fees = sum(x.total_fee for x in to_remove) received_fees = sum(x.total_fee for x in to_remove)
self.state.remote_state.revocation_store.add_next_entry(revocation.per_commitment_secret) self.remote_state.revocation_store.add_next_entry(revocation.per_commitment_secret)
next_point = self.state.remote_state.next_per_commitment_point next_point = self.remote_state.next_per_commitment_point
print("RECEIVED", received_this_batch) print("RECEIVED", received_this_batch)
print("SENT", sent_this_batch) print("SENT", sent_this_batch)
self.state = self.state._replace( self.remote_state=self.remote_state._replace(
remote_state=self.state.remote_state._replace( ctn=self.remote_state.ctn + 1,
ctn=self.state.remote_state.ctn + 1,
current_per_commitment_point=next_point, current_per_commitment_point=next_point,
next_per_commitment_point=revocation.next_per_commitment_point, next_per_commitment_point=revocation.next_per_commitment_point,
amount_msat=self.state.remote_state.amount_msat + (sent_this_batch - received_this_batch) + sent_fees - received_fees amount_msat=self.remote_state.amount_msat + (sent_this_batch - received_this_batch) + sent_fees - received_fees
),
local_state=self.state.local_state._replace(
amount_msat = self.state.local_state.amount_msat + (received_this_batch - sent_this_batch) - sent_fees + received_fees
) )
self.local_state=self.local_state._replace(
amount_msat = self.local_state.amount_msat + (received_this_batch - sent_this_batch) - sent_fees + received_fees
) )
@staticmethod @staticmethod
@ -306,9 +348,9 @@ class HTLCStateMachine(PrintError):
htlc_value_remote, total_fee_remote = self.htlcsum(self.htlcs_in_remote) htlc_value_remote, total_fee_remote = self.htlcsum(self.htlcs_in_remote)
total_fee_local += local_settled_fee total_fee_local += local_settled_fee
total_fee_remote += remote_settled_fee total_fee_remote += remote_settled_fee
local_msat = self.state.local_state.amount_msat -\ local_msat = self.local_state.amount_msat -\
htlc_value_local + remote_settled_value - local_settled_value htlc_value_local + remote_settled_value - local_settled_value
remote_msat = self.state.remote_state.amount_msat -\ remote_msat = self.remote_state.amount_msat -\
htlc_value_remote + local_settled_value - remote_settled_value htlc_value_remote + local_settled_value - remote_settled_value
return remote_msat, total_fee_remote, local_msat, total_fee_local return remote_msat, total_fee_remote, local_msat, total_fee_local
@ -319,17 +361,17 @@ class HTLCStateMachine(PrintError):
assert local_msat >= 0 assert local_msat >= 0
assert remote_msat >= 0 assert remote_msat >= 0
this_point = self.state.remote_state.next_per_commitment_point this_point = self.remote_state.next_per_commitment_point
remote_htlc_pubkey = derive_pubkey(self.state.remote_config.htlc_basepoint.pubkey, this_point) remote_htlc_pubkey = derive_pubkey(self.remote_config.htlc_basepoint.pubkey, this_point)
local_htlc_pubkey = derive_pubkey(self.state.local_config.htlc_basepoint.pubkey, this_point) local_htlc_pubkey = derive_pubkey(self.local_config.htlc_basepoint.pubkey, this_point)
local_revocation_pubkey = derive_blinded_pubkey(self.state.local_config.revocation_basepoint.pubkey, this_point) local_revocation_pubkey = derive_blinded_pubkey(self.local_config.revocation_basepoint.pubkey, this_point)
trimmed = 0 trimmed = 0
htlcs_in_local = [] htlcs_in_local = []
for htlc in self.htlcs_in_local: for htlc in self.htlcs_in_local:
if htlc.amount_msat // 1000 - lnbase.HTLC_SUCCESS_WEIGHT * (self.state.constraints.feerate // 1000) < self.state.remote_config.dust_limit_sat: if htlc.amount_msat // 1000 - lnbase.HTLC_SUCCESS_WEIGHT * (self.constraints.feerate // 1000) < self.remote_config.dust_limit_sat:
trimmed += htlc.amount_msat // 1000 trimmed += htlc.amount_msat // 1000
continue continue
htlcs_in_local.append( htlcs_in_local.append(
@ -337,13 +379,13 @@ class HTLCStateMachine(PrintError):
htlcs_in_remote = [] htlcs_in_remote = []
for htlc in self.htlcs_in_remote: for htlc in self.htlcs_in_remote:
if htlc.amount_msat // 1000 - lnbase.HTLC_TIMEOUT_WEIGHT * (self.state.constraints.feerate // 1000) < self.state.remote_config.dust_limit_sat: if htlc.amount_msat // 1000 - lnbase.HTLC_TIMEOUT_WEIGHT * (self.constraints.feerate // 1000) < self.remote_config.dust_limit_sat:
trimmed += htlc.amount_msat // 1000 trimmed += htlc.amount_msat // 1000
continue continue
htlcs_in_remote.append( htlcs_in_remote.append(
( make_offered_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash), htlc.amount_msat + htlc.total_fee)) ( make_offered_htlc(local_revocation_pubkey, local_htlc_pubkey, remote_htlc_pubkey, htlc.payment_hash), htlc.amount_msat + htlc.total_fee))
commit = make_commitment_using_open_channel(self.state, self.state.remote_state.ctn + 1, commit = make_commitment_using_open_channel(self, self.remote_state.ctn + 1,
False, this_point, False, this_point,
remote_msat - total_fee_remote, local_msat - total_fee_local, htlcs_in_local + htlcs_in_remote, trimmed) remote_msat - total_fee_remote, local_msat - total_fee_local, htlcs_in_local + htlcs_in_remote, trimmed)
return commit return commit
@ -357,15 +399,15 @@ class HTLCStateMachine(PrintError):
_, this_point, _ = self.points _, this_point, _ = self.points
remote_htlc_pubkey = derive_pubkey(self.state.remote_config.htlc_basepoint.pubkey, this_point) remote_htlc_pubkey = derive_pubkey(self.remote_config.htlc_basepoint.pubkey, this_point)
local_htlc_pubkey = derive_pubkey(self.state.local_config.htlc_basepoint.pubkey, this_point) local_htlc_pubkey = derive_pubkey(self.local_config.htlc_basepoint.pubkey, this_point)
remote_revocation_pubkey = derive_blinded_pubkey(self.state.remote_config.revocation_basepoint.pubkey, this_point) remote_revocation_pubkey = derive_blinded_pubkey(self.remote_config.revocation_basepoint.pubkey, this_point)
trimmed = 0 trimmed = 0
htlcs_in_local = [] htlcs_in_local = []
for htlc in self.htlcs_in_local: for htlc in self.htlcs_in_local:
if htlc.amount_msat // 1000 - lnbase.HTLC_TIMEOUT_WEIGHT * (self.state.constraints.feerate // 1000) < self.state.local_config.dust_limit_sat: if htlc.amount_msat // 1000 - lnbase.HTLC_TIMEOUT_WEIGHT * (self.constraints.feerate // 1000) < self.local_config.dust_limit_sat:
trimmed += htlc.amount_msat // 1000 trimmed += htlc.amount_msat // 1000
continue continue
htlcs_in_local.append( htlcs_in_local.append(
@ -373,13 +415,13 @@ class HTLCStateMachine(PrintError):
htlcs_in_remote = [] htlcs_in_remote = []
for htlc in self.htlcs_in_remote: for htlc in self.htlcs_in_remote:
if htlc.amount_msat // 1000 - lnbase.HTLC_SUCCESS_WEIGHT * (self.state.constraints.feerate // 1000) < self.state.local_config.dust_limit_sat: if htlc.amount_msat // 1000 - lnbase.HTLC_SUCCESS_WEIGHT * (self.constraints.feerate // 1000) < self.local_config.dust_limit_sat:
trimmed += htlc.amount_msat // 1000 trimmed += htlc.amount_msat // 1000
continue continue
htlcs_in_remote.append( htlcs_in_remote.append(
( make_received_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat + htlc.total_fee)) ( make_received_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, htlc.payment_hash, htlc.cltv_expiry), htlc.amount_msat + htlc.total_fee))
commit = make_commitment_using_open_channel(self.state, self.state.local_state.ctn + 1, commit = make_commitment_using_open_channel(self, self.local_state.ctn + 1,
True, this_point, True, this_point,
local_msat - total_fee_local, remote_msat - total_fee_remote, htlcs_in_local + htlcs_in_remote, trimmed) local_msat - total_fee_local, remote_msat - total_fee_remote, htlcs_in_local + htlcs_in_remote, trimmed)
return commit return commit
@ -392,7 +434,7 @@ class HTLCStateMachine(PrintError):
for htlc in update_log: for htlc in update_log:
if type(htlc) is not UpdateAddHtlc: if type(htlc) is not UpdateAddHtlc:
continue continue
height = (self.state.local_state.ctn if subject == "remote" else self.state.remote_state.ctn) height = (self.local_state.ctn if subject == "remote" else self.remote_state.ctn)
locked_in = (htlc.r_locked_in if subject == "remote" else htlc.l_locked_in) locked_in = (htlc.r_locked_in if subject == "remote" else htlc.l_locked_in)
if locked_in is None or just_unsettled == (SettleHtlc(htlc.htlc_id) in other_log): if locked_in is None or just_unsettled == (SettleHtlc(htlc.htlc_id) in other_log):
@ -432,15 +474,15 @@ class HTLCStateMachine(PrintError):
@property @property
def l_current_height(self): def l_current_height(self):
return self.state.local_state.ctn return self.local_state.ctn
@property @property
def r_current_height(self): def r_current_height(self):
return self.state.remote_state.ctn return self.remote_state.ctn
@property @property
def local_commit_fee(self): def local_commit_fee(self):
return self.state.constraints.capacity - sum(x[2] for x in self.local_commitment.outputs()) return self.constraints.capacity - sum(x[2] for x in self.local_commitment.outputs())
def update_fee(self, fee): def update_fee(self, fee):
self.pending_feerate = fee self.pending_feerate = fee
@ -448,22 +490,36 @@ class HTLCStateMachine(PrintError):
def receive_update_fee(self, fee): def receive_update_fee(self, fee):
self.pending_feerate = fee self.pending_feerate = fee
@property def to_save(self):
def local_state(self): return {
return self.state.local_state "local_config": self.local_config,
"remote_config": self.remote_config,
@property "local_state": self.local_state,
def remote_state(self): "remote_state": self.remote_state,
return self.state.remote_state "channel_id": self.channel_id,
"short_channel_id": self.short_channel_id,
@property "constraints": self.constraints,
def remote_config(self): "funding_outpoint": self.funding_outpoint,
return self.state.remote_config "node_id": self.node_id,
"channel_id": self.channel_id
@property }
def local_config(self):
return self.state.local_config def serialize(self):
namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
@property serialized_channel = {k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in self.to_save().items()}
def channel_id(self): class MyJsonEncoder(json.JSONEncoder):
return self.state.channel_id def default(self, o):
if isinstance(o, bytes):
return binascii.hexlify(o).decode("ascii")
if isinstance(o, RevocationStore):
return o.serialize()
return super(MyJsonEncoder, self)
dumped = MyJsonEncoder().encode(serialized_channel)
roundtripped = json.loads(dumped)
reconstructed = HTLCStateMachine(roundtripped)
if reconstructed.to_save() != self.to_save():
raise Exception("Channels did not roundtrip serialization without changes:\n" + repr(reconstructed.to_save()) + "\n" + repr(self.to_save()))
return roundtripped
def __str__(self):
return self.serialize()

2
lib/lnwatcher.py

@ -15,7 +15,7 @@ class LNWatcher(PrintError):
return response['params'], response['result'] return response['params'], response['result']
def watch_channel(self, chan, callback): def watch_channel(self, chan, callback):
script = funding_output_script(chan.state.local_config, chan.state.remote_config) script = funding_output_script(chan.local_config, chan.remote_config)
funding_address = redeem_script_to_address('p2wsh', script) funding_address = redeem_script_to_address('p2wsh', script)
self.watched_channels[funding_address] = chan, callback self.watched_channels[funding_address] = chan, callback
self.network.subscribe_to_addresses([funding_address], self.on_address_status) self.network.subscribe_to_addresses([funding_address], self.on_address_status)

101
lib/lnworker.py

@ -10,67 +10,12 @@ from . import constants
from .bitcoin import sha256, COIN from .bitcoin import sha256, COIN
from .util import bh2u, bfh, PrintError from .util import bh2u, bfh, PrintError
from .constants import set_testnet, set_simnet from .constants import set_testnet, set_simnet
from .lnbase import Peer, Outpoint, ChannelConfig, LocalState, RemoteState, Keypair, OnlyPubkeyKeypair, OpenChannel, ChannelConstraints, RevocationStore, calc_short_channel_id, privkey_to_pubkey from .lnbase import Peer, calc_short_channel_id, privkey_to_pubkey
from .lightning_payencode.lnaddr import lnencode, LnAddr, lndecode from .lightning_payencode.lnaddr import lnencode, LnAddr, lndecode
from .ecc import ECPrivkey, CURVE_ORDER, der_sig_from_sig_string from .ecc import ECPrivkey, CURVE_ORDER, der_sig_from_sig_string
from .transaction import Transaction from .transaction import Transaction
from .lnhtlc import HTLCStateMachine from .lnhtlc import HTLCStateMachine
from .lnbase import Outpoint
is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key")
def maybeDecode(k, v):
if k in ["node_id", "channel_id", "short_channel_id", "pubkey", "privkey", "current_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed", "current_commitment_signature"] and v is not None:
return binascii.unhexlify(v)
return v
def decodeAll(v):
return {i: maybeDecode(i, j) for i, j in v.items()} if isinstance(v, dict) else v
def typeWrap(k, v, local):
if is_key(k):
if local:
return Keypair(**v)
else:
return OnlyPubkeyKeypair(**v)
return v
def reconstruct_namedtuples(openingchannel):
openingchannel = decodeAll(openingchannel)
openingchannel=OpenChannel(**openingchannel)
openingchannel = openingchannel._replace(funding_outpoint=Outpoint(**openingchannel.funding_outpoint))
new_local_config = {k: typeWrap(k, decodeAll(v), True) for k, v in openingchannel.local_config.items()}
openingchannel = openingchannel._replace(local_config=ChannelConfig(**new_local_config))
new_remote_config = {k: typeWrap(k, decodeAll(v), False) for k, v in openingchannel.remote_config.items()}
openingchannel = openingchannel._replace(remote_config=ChannelConfig(**new_remote_config))
new_local_state = decodeAll(openingchannel.local_state)
openingchannel = openingchannel._replace(local_state=LocalState(**new_local_state))
new_remote_state = decodeAll(openingchannel.remote_state)
new_remote_state["revocation_store"] = RevocationStore.from_json_obj(new_remote_state["revocation_store"])
openingchannel = openingchannel._replace(remote_state=RemoteState(**new_remote_state))
openingchannel = openingchannel._replace(constraints=ChannelConstraints(**openingchannel.constraints))
return openingchannel
def serialize_channels(channels_dict):
serialized_channels = []
for chan in channels_dict.values():
namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
serialized_channels.append({k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in chan.state._asdict().items()})
class MyJsonEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, bytes):
return binascii.hexlify(o).decode("ascii")
if isinstance(o, RevocationStore):
return o.serialize()
return super(MyJsonEncoder, self)
dumped = MyJsonEncoder().encode(serialized_channels)
roundtripped = json.loads(dumped)
reconstructed = set(reconstruct_namedtuples(x) for x in roundtripped)
if reconstructed != set(x.state for x in channels_dict.values()):
raise Exception("Channels did not roundtrip serialization without changes:\n" + repr(reconstructed) + "\n" + repr(channels_dict))
return roundtripped
# hardcoded nodes # hardcoded nodes
node_list = [ node_list = [
@ -91,7 +36,7 @@ class LNWorker(PrintError):
self.pubkey = ECPrivkey(self.privkey).get_public_key_bytes() self.pubkey = ECPrivkey(self.privkey).get_public_key_bytes()
self.config = network.config self.config = network.config
self.peers = {} self.peers = {}
self.channels = {x.channel_id: HTLCStateMachine(x) for x in map(reconstruct_namedtuples, wallet.storage.get("channels", []))} self.channels = {x.channel_id: x for x in map(HTLCStateMachine, wallet.storage.get("channels", []))}
self.invoices = wallet.storage.get('lightning_invoices', {}) self.invoices = wallet.storage.get('lightning_invoices', {})
peer_list = network.config.get('lightning_peers', node_list) peer_list = network.config.get('lightning_peers', node_list)
self.channel_state = {chan.channel_id: "DISCONNECTED" for chan in self.channels.values()} self.channel_state = {chan.channel_id: "DISCONNECTED" for chan in self.channels.values()}
@ -105,7 +50,7 @@ class LNWorker(PrintError):
def channels_for_peer(self, node_id): def channels_for_peer(self, node_id):
assert type(node_id) is bytes assert type(node_id) is bytes
return {x: y for (x, y) in self.channels.items() if y.state.node_id == node_id} return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
def add_peer(self, host, port, node_id): def add_peer(self, host, port, node_id):
peer = Peer(self, host, int(port), node_id, request_initial_sync=self.config.get("request_initial_sync", True)) peer = Peer(self, host, int(port), node_id, request_initial_sync=self.config.get("request_initial_sync", True))
@ -120,12 +65,12 @@ class LNWorker(PrintError):
self.channels[openchannel.channel_id] = openchannel self.channels[openchannel.channel_id] = openchannel
for node_id, peer in self.peers.items(): for node_id, peer in self.peers.items():
peer.channels = self.channels_for_peer(node_id) peer.channels = self.channels_for_peer(node_id)
if openchannel.state.remote_state.next_per_commitment_point == openchannel.state.remote_state.current_per_commitment_point: if openchannel.remote_state.next_per_commitment_point == openchannel.remote_state.current_per_commitment_point:
raise Exception("Tried to save channel with next_point == current_point, this should not happen") raise Exception("Tried to save channel with next_point == current_point, this should not happen")
dumped = serialize_channels(self.channels) dumped = [x.serialize() for x in self.channels.values()]
self.wallet.storage.put("channels", dumped) self.wallet.storage.put("channels", dumped)
self.wallet.storage.write() self.wallet.storage.write()
self.network.trigger_callback('channel', openchannel.state) self.network.trigger_callback('channel', openchannel)
def save_short_chan_id(self, chan): def save_short_chan_id(self, chan):
""" """
@ -134,31 +79,31 @@ class LNWorker(PrintError):
If the Funding TX has not been mined, return None If the Funding TX has not been mined, return None
""" """
assert self.channel_state[chan.channel_id] in ["OPEN", "OPENING"] assert self.channel_state[chan.channel_id] in ["OPEN", "OPENING"]
peer = self.peers[chan.state.node_id] peer = self.peers[chan.node_id]
conf = self.wallet.get_tx_height(chan.state.funding_outpoint.txid)[1] conf = self.wallet.get_tx_height(chan.funding_outpoint.txid)[1]
if conf >= chan.state.constraints.funding_txn_minimum_depth: if conf >= chan.constraints.funding_txn_minimum_depth:
block_height, tx_pos = self.wallet.get_txpos(chan.state.funding_outpoint.txid) block_height, tx_pos = self.wallet.get_txpos(chan.funding_outpoint.txid)
if tx_pos == -1: if tx_pos == -1:
self.print_error('funding tx is not yet SPV verified.. but there are ' self.print_error('funding tx is not yet SPV verified.. but there are '
'already enough confirmations (currently {})'.format(conf)) 'already enough confirmations (currently {})'.format(conf))
return False return False
chan.state = chan.state._replace(short_channel_id = calc_short_channel_id(block_height, tx_pos, chan.state.funding_outpoint.output_index)) chan.short_channel_id = calc_short_channel_id(block_height, tx_pos, chan.funding_outpoint.output_index)
self.save_channel(chan) self.save_channel(chan)
return True return True
return False return False
def on_channel_utxos(self, chan, utxos): def on_channel_utxos(self, chan, utxos):
outpoints = [Outpoint(x["tx_hash"], x["tx_pos"]) for x in utxos] outpoints = [Outpoint(x["tx_hash"], x["tx_pos"]) for x in utxos]
if chan.state.funding_outpoint not in outpoints: if chan.funding_outpoint not in outpoints:
self.channel_state[chan.channel_id] = "CLOSED" self.channel_state[chan.channel_id] = "CLOSED"
elif self.channel_state[chan.channel_id] == 'DISCONNECTED': elif self.channel_state[chan.channel_id] == 'DISCONNECTED':
peer = self.peers[chan.state.node_id] peer = self.peers[chan.node_id]
coro = peer.reestablish_channel(chan) coro = peer.reestablish_channel(chan)
asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
def on_network_update(self, event, *args): def on_network_update(self, event, *args):
for chan in self.channels.values(): for chan in self.channels.values():
peer = self.peers[chan.state.node_id] peer = self.peers[chan.node_id]
if self.channel_state[chan.channel_id] == "OPENING": if self.channel_state[chan.channel_id] == "OPENING":
res = self.save_short_chan_id(chan) res = self.save_short_chan_id(chan)
if not res: if not res:
@ -167,7 +112,7 @@ class LNWorker(PrintError):
# this results in the channel being marked OPEN # this results in the channel being marked OPEN
peer.funding_locked(chan) peer.funding_locked(chan)
elif self.channel_state[chan.channel_id] == "OPEN": elif self.channel_state[chan.channel_id] == "OPEN":
conf = self.wallet.get_tx_height(chan.state.funding_outpoint.txid)[1] conf = self.wallet.get_tx_height(chan.funding_outpoint.txid)[1]
peer.on_network_update(chan, conf) peer.on_network_update(chan, conf)
async def _open_channel_coroutine(self, node_id, amount_sat, push_sat, password): async def _open_channel_coroutine(self, node_id, amount_sat, push_sat, password):
@ -200,7 +145,7 @@ class LNWorker(PrintError):
node_id, short_channel_id = path[0] node_id, short_channel_id = path[0]
peer = self.peers[node_id] peer = self.peers[node_id]
for chan in self.channels.values(): for chan in self.channels.values():
if chan.state.short_channel_id == short_channel_id: if chan.short_channel_id == short_channel_id:
break break
else: else:
raise Exception("ChannelDB returned path with short_channel_id that is not in channel list") raise Exception("ChannelDB returned path with short_channel_id that is not in channel list")
@ -228,18 +173,18 @@ class LNWorker(PrintError):
self.wallet.storage.write() self.wallet.storage.write()
def list_channels(self): def list_channels(self):
return serialize_channels(self.channels) return [str(x) for x in self.channels]
def close_channel(self, chan_id): def close_channel(self, chan_id):
chan = self.channels[chan_id] chan = self.channels[chan_id]
# local_commitment always gives back the next expected local_commitment, # local_commitment always gives back the next expected local_commitment,
# but in this case, we want the current one. So substract one ctn number # but in this case, we want the current one. So substract one ctn number
old_state = chan.state old_local_state = chan.local_state
chan.state = chan.state._replace(local_state=chan.state.local_state._replace(ctn=chan.state.local_state.ctn - 1)) chan.local_state=chan.local_state._replace(ctn=chan.local_state.ctn - 1)
tx = chan.local_commitment tx = chan.local_commitment
chan.state = old_state chan.local_state = old_local_state
tx.sign({bh2u(chan.state.local_config.multisig_key.pubkey): (chan.state.local_config.multisig_key.privkey, True)}) tx.sign({bh2u(chan.local_config.multisig_key.pubkey): (chan.local_config.multisig_key.privkey, True)})
remote_sig = chan.state.local_state.current_commitment_signature remote_sig = chan.local_state.current_commitment_signature
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01" remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
none_idx = tx._inputs[0]["signatures"].index(None) none_idx = tx._inputs[0]["signatures"].index(None)
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig)) tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))

32
lib/tests/test_lnhtlc.py

@ -36,13 +36,13 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
max_accepted_htlcs=5 max_accepted_htlcs=5
) )
return lnbase.OpenChannel( return {
channel_id=channel_id, "channel_id":channel_id,
short_channel_id=channel_id[:8], "short_channel_id":channel_id[:8],
funding_outpoint=lnbase.Outpoint(funding_txid, funding_index), "funding_outpoint":lnbase.Outpoint(funding_txid, funding_index),
local_config=local_config, "local_config":local_config,
remote_config=remote_config, "remote_config":remote_config,
remote_state=lnbase.RemoteState( "remote_state":lnbase.RemoteState(
ctn = 0, ctn = 0,
next_per_commitment_point=nex, next_per_commitment_point=nex,
current_per_commitment_point=cur, current_per_commitment_point=cur,
@ -50,7 +50,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
revocation_store=their_revocation_store, revocation_store=their_revocation_store,
next_htlc_id = 0 next_htlc_id = 0
), ),
local_state=lnbase.LocalState( "local_state":lnbase.LocalState(
ctn = 0, ctn = 0,
per_commitment_secret_seed=seed, per_commitment_secret_seed=seed,
amount_msat=local_amount, amount_msat=local_amount,
@ -59,9 +59,9 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
was_announced=False, was_announced=False,
current_commitment_signature=None current_commitment_signature=None
), ),
constraints=lnbase.ChannelConstraints(capacity=funding_sat, feerate=local_feerate, is_initiator=is_initiator, funding_txn_minimum_depth=3), "constraints":lnbase.ChannelConstraints(capacity=funding_sat, feerate=local_feerate, is_initiator=is_initiator, funding_txn_minimum_depth=3),
node_id=other_node_id "node_id":other_node_id
) }
def bip32(sequence): def bip32(sequence):
xprv, xpub = bitcoin.bip32_root(b"9dk", 'standard') xprv, xpub = bitcoin.bip32_root(b"9dk", 'standard')
@ -184,8 +184,8 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
self.assertEqual(alice_channel.total_msat_received, bobSent, "alice has incorrect milli-satoshis received") self.assertEqual(alice_channel.total_msat_received, bobSent, "alice has incorrect milli-satoshis received")
self.assertEqual(bob_channel.total_msat_sent, bobSent, "bob has incorrect milli-satoshis sent") self.assertEqual(bob_channel.total_msat_sent, bobSent, "bob has incorrect milli-satoshis sent")
self.assertEqual(bob_channel.total_msat_received, aliceSent, "bob has incorrect milli-satoshis received") self.assertEqual(bob_channel.total_msat_received, aliceSent, "bob has incorrect milli-satoshis received")
self.assertEqual(bob_channel.state.local_state.ctn, 1, "bob has incorrect commitment height") self.assertEqual(bob_channel.local_state.ctn, 1, "bob has incorrect commitment height")
self.assertEqual(alice_channel.state.local_state.ctn, 1, "alice has incorrect commitment height") self.assertEqual(alice_channel.local_state.ctn, 1, "alice has incorrect commitment height")
# Both commitment transactions should have three outputs, and one of # Both commitment transactions should have three outputs, and one of
# them should be exactly the amount of the HTLC. # them should be exactly the amount of the HTLC.
@ -238,7 +238,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
paymentPreimage = b"\x01" * 32 paymentPreimage = b"\x01" * 32
paymentHash = bitcoin.sha256(paymentPreimage) paymentHash = bitcoin.sha256(paymentPreimage)
fee_per_kw = alice_channel.state.constraints.feerate fee_per_kw = alice_channel.constraints.feerate
self.assertEqual(fee_per_kw, 6000) self.assertEqual(fee_per_kw, 6000)
htlcAmt = 500 + lnbase.HTLC_TIMEOUT_WEIGHT * (fee_per_kw // 1000) htlcAmt = 500 + lnbase.HTLC_TIMEOUT_WEIGHT * (fee_per_kw // 1000)
self.assertEqual(htlcAmt, 4478) self.assertEqual(htlcAmt, 4478)
@ -283,9 +283,9 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
alice_sig, alice_htlc_sigs = alice_channel.sign_next_commitment() alice_sig, alice_htlc_sigs = alice_channel.sign_next_commitment()
bob_channel.receive_new_commitment(alice_sig, alice_htlc_sigs) bob_channel.receive_new_commitment(alice_sig, alice_htlc_sigs)
self.assertNotEqual(fee, alice_channel.state.constraints.feerate) self.assertNotEqual(fee, alice_channel.constraints.feerate)
rev, _ = alice_channel.revoke_current_commitment() rev, _ = alice_channel.revoke_current_commitment()
self.assertEqual(fee, alice_channel.state.constraints.feerate) self.assertEqual(fee, alice_channel.constraints.feerate)
bob_channel.receive_revocation(rev) bob_channel.receive_revocation(rev)
def force_state_transition(chanA, chanB): def force_state_transition(chanA, chanB):

Loading…
Cancel
Save