You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

1287 lines
56 KiB

#!/usr/bin/env python3
"""
Lightning network interface for Electrum
Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
"""
from ecdsa.util import sigdecode_der, sigencode_string_canonize
from ecdsa import VerifyingKey
from ecdsa.curves import SECP256k1
import queue
import traceback
import itertools
import json
from collections import OrderedDict, defaultdict
import asyncio
import sys
import os
import time
import binascii
import hashlib
import hmac
import cryptography.hazmat.primitives.ciphers.aead as AEAD
from .bitcoin import (public_key_from_private_key, ser_to_point, point_to_ser,
string_to_number, deserialize_privkey, EC_KEY, rev_hex, int_to_hex,
push_script, script_num_to_hex,
add_number_to_script, var_int)
from . import bitcoin
from . import constants
from . import transaction
from .util import PrintError, bh2u, print_error, bfh, profiler
from .transaction import opcodes, Transaction
from collections import namedtuple, defaultdict
# hardcoded nodes
node_list = [
('ecdsa.net', '9735', '038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff'),
('77.58.162.148', '9735', '022bb78ab9df617aeaaf37f6644609abb7295fad0c20327bccd41f8d69173ccb49')
]
class LightningError(Exception):
pass
message_types = {}
def handlesingle(x, ma):
"""
Evaluate a term of the simple language used
to specify lightning message field lengths.
If `x` is an integer, it is returned as is,
otherwise it is treated as a variable and
looked up in `ma`.
It the value in `ma` was no integer, it is
assumed big-endian bytes and decoded.
Returns int
"""
try:
x = int(x)
except ValueError:
x = ma[x]
try:
x = int(x)
except ValueError:
x = int.from_bytes(x, byteorder="big")
return x
def calcexp(exp, ma):
"""
Evaluate simple mathematical expression given
in `exp` with variables assigned in the dict `ma`
Returns int
"""
exp = str(exp)
if "*" in exp:
assert "+" not in exp
result = 1
for term in exp.split("*"):
result *= handlesingle(term, ma)
return result
return sum(handlesingle(x, ma) for x in exp.split("+"))
def make_handler(k, v):
"""
Generate a message handler function (taking bytes)
for message type `k` with specification `v`
Check lib/lightning.json, `k` could be 'init',
and `v` could be
{ type: 16, payload: { 'gflen': ..., ... }, ... }
Returns function taking bytes
"""
def handler(data):
nonlocal k, v
ma = {}
pos = 0
for fieldname in v["payload"]:
poslenMap = v["payload"][fieldname]
if "feature" in poslenMap and pos==len(data):
continue
#print(poslenMap["position"], ma)
assert pos == calcexp(poslenMap["position"], ma)
length = poslenMap["length"]
length = calcexp(length, ma)
ma[fieldname] = data[pos:pos+length]
pos += length
assert pos == len(data), (k, pos, len(data))
return k, ma
return handler
path = os.path.join(os.path.dirname(__file__), 'lightning.json')
with open(path) as f:
structured = json.loads(f.read(), object_pairs_hook=OrderedDict)
for k in structured:
v = structured[k]
# these message types are skipped since their types collide
# (for example with pong, which also uses type=19)
# we don't need them yet
if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]:
continue
if len(v["payload"]) == 0:
continue
try:
num = int(v["type"])
except ValueError:
#print("skipping", k)
continue
byts = num.to_bytes(byteorder="big",length=2)
assert byts not in message_types, (byts, message_types[byts].__name__, k)
names = [x.__name__ for x in message_types.values()]
assert k + "_handler" not in names, (k, names)
message_types[byts] = make_handler(k, v)
message_types[byts].__name__ = k + "_handler"
assert message_types[b"\x00\x10"].__name__ == "init_handler"
def decode_msg(data):
"""
Decode Lightning message by reading the first
two bytes to determine message type.
Returns message type string and parsed message contents dict
"""
typ = data[:2]
k, parsed = message_types[typ](data[2:])
return k, parsed
def gen_msg(msg_type, **kwargs):
"""
Encode kwargs into a Lightning message (bytes)
of the type given in the msg_type string
"""
typ = structured[msg_type]
data = int(typ["type"]).to_bytes(byteorder="big", length=2)
lengths = {}
for k in typ["payload"]:
poslenMap = typ["payload"][k]
if "feature" in poslenMap: continue
leng = calcexp(poslenMap["length"], lengths)
try:
clone = dict(lengths)
clone.update(kwargs)
leng = calcexp(poslenMap["length"], clone)
except KeyError:
pass
try:
param = kwargs[k]
except KeyError:
param = 0
try:
if not isinstance(param, bytes):
assert isinstance(param, int), "field {} is neither bytes or int".format(k)
param = param.to_bytes(length=leng, byteorder="big")
except ValueError:
raise Exception("{} does not fit in {} bytes".format(k, leng))
lengths[k] = len(param)
if lengths[k] != leng:
raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng))
data += param
return data
def encode(n, s):
"""Return a bytestring version of the integer
value n, with a string length of s
"""
return n.to_bytes(length=s, byteorder="big")
def H256(data):
return hashlib.sha256(data).digest()
class HandshakeState(object):
prologue = b"lightning"
protocol_name = b"Noise_XK_secp256k1_ChaChaPoly_SHA256"
handshake_version = b"\x00"
def __init__(self, responder_pub):
self.responder_pub = responder_pub
self.h = H256(self.protocol_name)
self.ck = self.h
self.update(self.prologue)
self.update(self.responder_pub)
def update(self, data):
self.h = H256(self.h + data)
return self.h
def get_nonce_bytes(n):
"""BOLT 8 requires the nonce to be 12 bytes, 4 bytes leading
zeroes and 8 bytes little endian encoded 64 bit integer.
"""
nb = b"\x00"*4
#Encode the integer as an 8 byte byte-string
nb2 = encode(n, 8)
nb2 = bytearray(nb2)
#Little-endian is required here
nb2.reverse()
return nb + nb2
def aead_encrypt(k, nonce, associated_data, data):
nonce_bytes = get_nonce_bytes(nonce)
a = AEAD.ChaCha20Poly1305(k)
return a.encrypt(nonce_bytes, data, associated_data)
def aead_decrypt(k, nonce, associated_data, data):
nonce_bytes = get_nonce_bytes(nonce)
a = AEAD.ChaCha20Poly1305(k)
#raises InvalidTag exception if it's not valid
return a.decrypt(nonce_bytes, data, associated_data)
def get_bolt8_hkdf(salt, ikm):
"""RFC5869 HKDF instantiated in the specific form
used in Lightning BOLT 8:
Extract and expand to 64 bytes using HMAC-SHA256,
with info field set to a zero length string as per BOLT8
Return as two 32 byte fields.
"""
#Extract
prk = hmac.new(salt, msg=ikm, digestmod=hashlib.sha256).digest()
assert len(prk) == 32
#Expand
info = b""
T0 = b""
T1 = hmac.new(prk, T0 + info + b"\x01", digestmod=hashlib.sha256).digest()
T2 = hmac.new(prk, T1 + info + b"\x02", digestmod=hashlib.sha256).digest()
assert len(T1 + T2) == 64
return T1, T2
def get_ecdh(priv, pub):
s = string_to_number(priv)
pk = ser_to_point(pub)
pt = point_to_ser(pk * s)
return H256(pt)
def act1_initiator_message(hs, my_privkey):
#Get a new ephemeral key
epriv, epub = create_ephemeral_key(my_privkey)
hs.update(epub)
ss = get_ecdh(epriv, hs.responder_pub)
ck2, temp_k1 = get_bolt8_hkdf(hs.ck, ss)
hs.ck = ck2
c = aead_encrypt(temp_k1, 0, hs.h, b"")
#for next step if we do it
hs.update(c)
msg = hs.handshake_version + epub + c
assert len(msg) == 50
return msg
def privkey_to_pubkey(priv):
pub = public_key_from_private_key(priv[:32], True)
return bytes.fromhex(pub)
def create_ephemeral_key(privkey):
pub = privkey_to_pubkey(privkey)
return (privkey[:32], pub)
Keypair = namedtuple("Keypair", ["pubkey", "privkey"])
Outpoint = namedtuple("Outpoint", ["txid", "output_index"])
ChannelConfig = namedtuple("ChannelConfig", [
"payment_basepoint", "multisig_key", "htlc_basepoint", "delayed_basepoint", "revocation_basepoint",
"to_self_delay", "dust_limit_sat", "max_htlc_value_in_flight_msat", "max_accepted_htlcs"])
OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"])
RemoteState = namedtuple("RemoteState", ["ctn", "next_per_commitment_point", "amount_sat"])
LocalState = namedtuple("LocalState", ["ctn", "per_commitment_secret_seed", "amount_sat"])
ChannelConstraints = namedtuple("ChannelConstraints", ["feerate", "capacity", "is_initiator"])
OpenChannel = namedtuple("OpenChannel", ["channel_id", "funding_outpoint", "local_config", "remote_config", "remote_state", "local_state", "constraints"])
class KeypairGenerator:
def __init__(self, seed):
self.xprv, xpub = bitcoin.bip32_root(seed, "p2wpkh")
def get(self, keyfamily, index):
childxprv, childxpub = bitcoin.bip32_private_derivation(self.xprv, "m/", "m/42'/{keyfamily}'/{index}'".format(keyfamily=keyfamily, index=index))
_, _, _, _, child_c, child_cK = bitcoin.deserialize_xpub(childxpub)
_, _, _, _, _, k = bitcoin.deserialize_xprv(childxprv)
assert len(k) == 32
return Keypair(pubkey=child_cK, privkey=k)
def aiosafe(f):
async def f2(*args, **kwargs):
try:
return await f(*args, **kwargs)
except:
# if the loop isn't stopped
# run_forever in network.py would not return,
# the asyncioThread would not die,
# and we would block on shutdown
asyncio.get_event_loop().stop()
traceback.print_exc()
return f2
def get_obscured_ctn(ctn, local, remote):
mask = int.from_bytes(H256(local + remote)[-6:], byteorder="big")
return ctn ^ mask
def secret_to_pubkey(secret):
assert type(secret) is int
return point_to_ser(SECP256k1.generator * secret)
def derive_pubkey(basepoint, per_commitment_point):
p = ser_to_point(basepoint) + SECP256k1.generator * bitcoin.string_to_number(bitcoin.sha256(per_commitment_point + basepoint))
return point_to_ser(p)
def derive_privkey(secret, per_commitment_point):
assert type(secret) is int
basepoint = point_to_ser(SECP256k1.generator * secret)
basepoint = secret + bitcoin.string_to_number(bitcoin.sha256(per_commitment_point + basepoint))
basepoint %= SECP256k1.order
return basepoint
def derive_blinded_pubkey(basepoint, per_commitment_point):
k1 = ser_to_point(basepoint) * bitcoin.string_to_number(bitcoin.sha256(basepoint + per_commitment_point))
k2 = ser_to_point(per_commitment_point) * bitcoin.string_to_number(bitcoin.sha256(per_commitment_point + basepoint))
return point_to_ser(k1 + k2)
def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 47) -> bytes:
"""Generate per commitment secret."""
per_commitment_secret = bytearray(seed)
for bitindex in range(bits, -1, -1):
mask = 1 << bitindex
if i & mask:
per_commitment_secret[bitindex // 8] ^= 1 << (bitindex % 8)
per_commitment_secret = bytearray(bitcoin.sha256(per_commitment_secret))
return bytes(per_commitment_secret)
def overall_weight(num_htlc):
return 500 + 172 * num_htlc + 224
HTLC_TIMEOUT_WEIGHT = 663
HTLC_SUCCESS_WEIGHT = 703
def make_htlc_tx_output(amount_msat, local_feerate, revocationpubkey, local_delayedpubkey, success):
assert type(amount_msat) is int
assert type(local_feerate) is int
assert type(revocationpubkey) is bytes
assert type(local_delayedpubkey) is bytes
script = bytes([opcodes.OP_IF]) \
+ bfh(push_script(bh2u(revocationpubkey))) \
+ bytes([opcodes.OP_ELSE]) \
+ bitcoin.add_number_to_script(144) \
+ bytes([opcodes.OP_CSV, opcodes.OP_DROP]) \
+ bfh(push_script(bh2u(local_delayedpubkey))) \
+ bytes([opcodes.OP_ENDIF, opcodes.OP_CHECKSIG])
p2wsh = bitcoin.redeem_script_to_address('p2wsh', bh2u(script))
weight = HTLC_SUCCESS_WEIGHT if success else HTLC_TIMEOUT_WEIGHT
fee = local_feerate * weight
final_amount_sat = (amount_msat - fee) // 1000
assert final_amount_sat > 0
output = (bitcoin.TYPE_ADDRESS, p2wsh, final_amount_sat)
return output
def make_htlc_tx_witness(remotehtlcsig, localhtlcsig, payment_preimage, witness_script):
assert type(remotehtlcsig) is bytes
assert type(localhtlcsig) is bytes
assert type(payment_preimage) is bytes
assert type(witness_script) is bytes
return bfh(transaction.construct_witness([0, remotehtlcsig, localhtlcsig, payment_preimage, witness_script]))
def make_htlc_tx_inputs(htlc_output_txid, htlc_output_index, revocationpubkey, local_delayedpubkey, amount_msat, witness_script):
assert type(htlc_output_txid) is str
assert type(htlc_output_index) is int
assert type(revocationpubkey) is bytes
assert type(local_delayedpubkey) is bytes
assert type(amount_msat) is int
assert type(witness_script) is str
c_inputs = [{
'scriptSig': '',
'type': 'p2wsh',
'signatures': [],
'num_sig': 0,
'prevout_n': htlc_output_index,
'prevout_hash': htlc_output_txid,
'value': amount_msat // 1000,
'coinbase': False,
'sequence': 0x0,
'preimage_script': witness_script,
}]
return c_inputs
def make_htlc_tx(cltv_timeout, inputs, output):
assert type(cltv_timeout) is int
c_outputs = [output]
tx = Transaction.from_io(inputs, c_outputs, locktime=cltv_timeout, version=2)
tx.BIP_LI01_sort()
return tx
def make_offered_htlc(revocation_pubkey, remote_htlcpubkey, local_htlcpubkey, payment_hash):
assert type(revocation_pubkey) is bytes
assert type(remote_htlcpubkey) is bytes
assert type(local_htlcpubkey) is bytes
assert type(payment_hash) is bytes
return bytes([opcodes.OP_DUP, opcodes.OP_HASH160]) + bfh(push_script(bh2u(bitcoin.hash_160(revocation_pubkey))))\
+ bytes([opcodes.OP_EQUAL, opcodes.OP_IF, opcodes.OP_CHECKSIG, opcodes.OP_ELSE]) \
+ bfh(push_script(bh2u(remote_htlcpubkey)))\
+ bytes([opcodes.OP_SWAP, opcodes.OP_SIZE]) + bitcoin.add_number_to_script(32) + bytes([opcodes.OP_EQUAL, opcodes.OP_NOTIF, opcodes.OP_DROP])\
+ bitcoin.add_number_to_script(2) + bytes([opcodes.OP_SWAP]) + bfh(push_script(bh2u(local_htlcpubkey))) + bitcoin.add_number_to_script(2)\
+ bytes([opcodes.OP_CHECKMULTISIG, opcodes.OP_ELSE, opcodes.OP_HASH160])\
+ bfh(push_script(bh2u(bitcoin.ripemd(payment_hash)))) + bytes([opcodes.OP_EQUALVERIFY, opcodes.OP_CHECKSIG, opcodes.OP_ENDIF, opcodes.OP_ENDIF])
def make_received_htlc(revocation_pubkey, remote_htlcpubkey, local_htlcpubkey, payment_hash, cltv_expiry):
for i in [revocation_pubkey, remote_htlcpubkey, local_htlcpubkey, payment_hash]:
assert type(i) is bytes
assert type(cltv_expiry) is int
return bytes([opcodes.OP_DUP, opcodes.OP_HASH160]) \
+ bfh(push_script(bh2u(bitcoin.hash_160(revocation_pubkey)))) \
+ bytes([opcodes.OP_EQUAL, opcodes.OP_IF, opcodes.OP_CHECKSIG, opcodes.OP_ELSE]) \
+ bfh(push_script(bh2u(remote_htlcpubkey))) \
+ bytes([opcodes.OP_SWAP, opcodes.OP_SIZE]) \
+ bitcoin.add_number_to_script(32) \
+ bytes([opcodes.OP_EQUAL, opcodes.OP_IF, opcodes.OP_HASH160]) \
+ bfh(push_script(bh2u(bitcoin.ripemd(payment_hash)))) \
+ bytes([opcodes.OP_EQUALVERIFY]) \
+ bitcoin.add_number_to_script(2) \
+ bytes([opcodes.OP_SWAP]) \
+ bfh(push_script(bh2u(local_htlcpubkey))) \
+ bitcoin.add_number_to_script(2) \
+ bytes([opcodes.OP_CHECKMULTISIG, opcodes.OP_ELSE, opcodes.OP_DROP]) \
+ bitcoin.add_number_to_script(cltv_expiry) \
+ bytes([opcodes.OP_CLTV, opcodes.OP_DROP, opcodes.OP_CHECKSIG, opcodes.OP_ENDIF, opcodes.OP_ENDIF])
def make_commitment_using_open_channel(openchannel, ctn, for_us, pcp, local_sat, remote_sat, htlcs=[]):
conf = openchannel.local_config if for_us else openchannel.remote_config
other_conf = openchannel.local_config if not for_us else openchannel.remote_config
payment_pubkey = derive_pubkey(other_conf.payment_basepoint.pubkey, pcp)
remote_revocation_pubkey = derive_blinded_pubkey(other_conf.revocation_basepoint.pubkey, pcp)
return make_commitment(
ctn,
conf.multisig_key.pubkey,
other_conf.multisig_key.pubkey,
payment_pubkey,
openchannel.local_config.payment_basepoint.pubkey,
openchannel.remote_config.payment_basepoint.pubkey,
remote_revocation_pubkey,
derive_pubkey(conf.delayed_basepoint.pubkey, pcp),
other_conf.to_self_delay,
*openchannel.funding_outpoint,
openchannel.constraints.capacity,
local_sat,
remote_sat,
openchannel.local_config.dust_limit_sat,
openchannel.constraints.feerate,
for_us, htlcs=htlcs)
def make_commitment(ctn, local_funding_pubkey, remote_funding_pubkey, remote_payment_pubkey,
payment_basepoint, remote_payment_basepoint,
revocation_pubkey, delayed_pubkey, to_self_delay,
funding_txid, funding_pos, funding_sat,
local_amount, remote_amount,
dust_limit_sat, local_feerate, for_us, htlcs):
pubkeys = sorted([bh2u(local_funding_pubkey), bh2u(remote_funding_pubkey)])
obs = get_obscured_ctn(ctn, payment_basepoint, remote_payment_basepoint)
locktime = (0x20 << 24) + (obs & 0xffffff)
sequence = (0x80 << 24) + (obs >> 24)
print_error('locktime', locktime, hex(locktime))
# commitment tx input
c_inputs = [{
'type': 'p2wsh',
'x_pubkeys': pubkeys,
'signatures':[None, None],
'num_sig': 2,
'prevout_n': funding_pos,
'prevout_hash': funding_txid,
'value': funding_sat,
'coinbase': False,
'sequence':sequence
}]
# commitment tx outputs
local_script = bytes([opcodes.OP_IF]) + bfh(push_script(bh2u(revocation_pubkey))) + bytes([opcodes.OP_ELSE]) + add_number_to_script(to_self_delay) \
+ bytes([opcodes.OP_CSV, opcodes.OP_DROP]) + bfh(push_script(bh2u(delayed_pubkey))) + bytes([opcodes.OP_ENDIF, opcodes.OP_CHECKSIG])
local_address = bitcoin.redeem_script_to_address('p2wsh', bh2u(local_script))
remote_address = bitcoin.pubkey_to_address('p2wpkh', bh2u(remote_payment_pubkey))
# TODO trim htlc outputs here while also considering 2nd stage htlc transactions
fee = local_feerate * overall_weight(len(htlcs)) // 1000 # TODO incorrect if anything is trimmed
to_local = (bitcoin.TYPE_ADDRESS, local_address, local_amount - (fee if for_us else 0))
to_remote = (bitcoin.TYPE_ADDRESS, remote_address, remote_amount - (fee if not for_us else 0))
c_outputs = [to_local, to_remote]
for script, msat_amount in htlcs:
c_outputs += [(bitcoin.TYPE_ADDRESS, bitcoin.redeem_script_to_address('p2wsh', bh2u(script)), msat_amount // 1000)]
# trim outputs
c_outputs_filtered = list(filter(lambda x:x[2]>= dust_limit_sat, c_outputs))
assert sum(x[2] for x in c_outputs) <= funding_sat
# create commitment tx
tx = Transaction.from_io(c_inputs, c_outputs_filtered, locktime=locktime, version=2)
tx.BIP_LI01_sort()
tx.htlc_output_indices = {}
for idx, output in enumerate(c_outputs):
if output in tx.outputs():
# minus the first two outputs (to_local, to_remote)
tx.htlc_output_indices[idx - 2] = tx.outputs().index(output)
return tx
class Peer(PrintError):
def __init__(self, host, port, pubkey, request_initial_sync=False, network=None):
self.host = host
self.port = port
self.privkey = os.urandom(32) + b"\x01"
self.pubkey = pubkey
self.network = network
self.read_buffer = b''
self.ping_time = 0
self.futures = ["channel_accepted",
"funding_signed",
"local_funding_locked",
"remote_funding_locked",
"revoke_and_ack",
"commitment_signed"]
self.channel_accepted = defaultdict(asyncio.Future)
self.funding_signed = defaultdict(asyncio.Future)
self.local_funding_locked = defaultdict(asyncio.Future)
self.remote_funding_locked = defaultdict(asyncio.Future)
self.revoke_and_ack = defaultdict(asyncio.Future)
self.commitment_signed = defaultdict(asyncio.Future)
self.initialized = asyncio.Future()
self.localfeatures = (0x08 if request_initial_sync else 0)
# view of the network
self.nodes = {} # received node announcements
self.channel_db = ChannelDB()
self.path_finder = LNPathFinder(self.channel_db)
self.unfulfilled_htlcs = []
def diagnostic_name(self):
return self.host
def ping_if_required(self):
if time.time() - self.ping_time > 120:
self.send_message(gen_msg('ping', num_pong_bytes=4, byteslen=4))
self.ping_time = time.time()
def send_message(self, msg):
message_type, payload = decode_msg(msg)
self.print_error("Sending '%s'"%message_type.upper(), payload)
l = encode(len(msg), 2)
lc = aead_encrypt(self.sk, self.sn(), b'', l)
c = aead_encrypt(self.sk, self.sn(), b'', msg)
assert len(lc) == 18
assert len(c) == len(msg) + 16
self.writer.write(lc+c)
async def read_message(self):
rn_l, rk_l = self.rn()
rn_m, rk_m = self.rn()
while True:
s = await self.reader.read(2**10)
if not s:
raise Exception('connection closed')
self.read_buffer += s
if len(self.read_buffer) < 18:
continue
lc = self.read_buffer[:18]
l = aead_decrypt(rk_l, rn_l, b'', lc)
length = int.from_bytes(l, byteorder="big")
offset = 18 + length + 16
if len(self.read_buffer) < offset:
continue
c = self.read_buffer[18:offset]
self.read_buffer = self.read_buffer[offset:]
msg = aead_decrypt(rk_m, rn_m, b'', c)
return msg
async def handshake(self):
hs = HandshakeState(self.pubkey)
msg = act1_initiator_message(hs, self.privkey)
# act 1
self.writer.write(msg)
rspns = await self.reader.read(2**10)
assert len(rspns) == 50
hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:]
assert bytes([hver]) == hs.handshake_version
# act 2
hs.update(alice_epub)
myepriv, myepub = create_ephemeral_key(self.privkey)
ss = get_ecdh(myepriv, alice_epub)
ck, temp_k2 = get_bolt8_hkdf(hs.ck, ss)
hs.ck = ck
p = aead_decrypt(temp_k2, 0, hs.h, tag)
hs.update(tag)
# act 3
my_pubkey = privkey_to_pubkey(self.privkey)
c = aead_encrypt(temp_k2, 1, hs.h, my_pubkey)
hs.update(c)
ss = get_ecdh(self.privkey[:32], alice_epub)
ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
hs.ck = ck
t = aead_encrypt(temp_k3, 0, hs.h, b'')
self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'')
msg = hs.handshake_version + c + t
self.writer.write(msg)
# init counters
self._sn = 0
self._rn = 0
self.r_ck = ck
self.s_ck = ck
def rn(self):
o = self._rn, self.rk
self._rn += 1
if self._rn == 1000:
self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
self._rn = 0
return o
def sn(self):
o = self._sn
self._sn += 1
if self._sn == 1000:
self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
self._sn = 0
return o
def process_message(self, message):
message_type, payload = decode_msg(message)
try:
f = getattr(self, 'on_' + message_type)
except AttributeError:
self.print_error("Received '%s'" % message_type.upper(), payload)
return
# raw message is needed to check signature
if message_type=='node_announcement':
payload['raw'] = message
f(payload)
def on_error(self, payload):
for i in self.futures:
if payload["channel_id"] in getattr(self, i):
getattr(self, i)[payload["channel_id"]].set_exception(LightningError(payload["data"]))
return
self.print_error("no future found to resolve", payload)
def on_ping(self, payload):
l = int.from_bytes(payload['num_pong_bytes'], byteorder="big")
self.send_message(gen_msg('pong', byteslen=l))
def on_accept_channel(self, payload):
self.channel_accepted[payload["temporary_channel_id"]].set_result(payload)
def on_funding_signed(self, payload):
channel_id = int.from_bytes(payload['channel_id'], byteorder="big")
self.funding_signed[channel_id].set_result(payload)
def on_funding_locked(self, payload):
channel_id = int.from_bytes(payload['channel_id'], byteorder="big")
self.remote_funding_locked[channel_id].set_result(payload)
def on_node_announcement(self, payload):
pubkey = payload['node_id']
signature = payload['signature']
h = bitcoin.Hash(payload['raw'][66:])
if not bitcoin.verify_signature(pubkey, signature, h):
return False
self.s = payload['addresses']
def read(n):
data, self.s = self.s[0:n], self.s[n:]
return data
addresses = []
while self.s:
atype = ord(read(1))
if atype == 0:
pass
elif atype == 1:
ipv4_addr = '.'.join(map(lambda x: '%d'%x, read(4)))
port = int.from_bytes(read(2), byteorder="big")
x = ipv4_addr, port, binascii.hexlify(pubkey)
addresses.append((ipv4_addr, port))
elif atype == 2:
ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(4)])
port = int.from_bytes(read(2), byteorder="big")
addresses.append((ipv6_addr, port))
else:
pass
continue
alias = payload['alias'].rstrip(b'\x00')
self.nodes[pubkey] = {
'alias': alias,
'addresses': addresses
}
self.print_error('node announcement', binascii.hexlify(pubkey), alias, addresses)
def on_init(self, payload):
pass
def on_channel_update(self, payload):
self.channel_db.on_channel_update(payload)
def on_channel_announcement(self, payload):
self.channel_db.on_channel_announcement(payload)
#def open_channel(self, funding_sat, push_msat):
# self.send_message(gen_msg('open_channel', funding_sat=funding_sat, push_msat=push_msat))
@aiosafe
async def main_loop(self):
self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
await self.handshake()
# send init
self.send_message(gen_msg("init", gflen=0, lflen=1, localfeatures=self.localfeatures))
# read init
msg = await self.read_message()
self.process_message(msg)
# initialized
self.initialized.set_result(msg)
# loop
while True:
self.ping_if_required()
msg = await self.read_message()
self.process_message(msg)
# close socket
self.print_error('closing lnbase')
self.writer.close()
@aiosafe
async def channel_establishment_flow(self, wallet, config, funding_sat, push_msat, temp_channel_id, keypair_generator=None):
await self.initialized
if keypair_generator is None:
keypair_generator = KeypairGenerator(b"testseed")
# see lnd/keychain/derivation.go
keyfamilymultisig = 0
keyfamilyrevocationbase = 1
keyfamilyhtlcbase = 2
keyfamilypaymentbase = 3
keyfamilydelaybase = 4
keyfamilyrevocationroot = 5
keyfamilynodekey = 6 # TODO currently unused
# amounts
local_feerate = 20000
to_self_delay = 144
dust_limit_sat = 10
local_max_htlc_value_in_flight_msat = 500000 * 1000
local_max_accepted_htlcs = 5
# key derivation
local_config=ChannelConfig(
payment_basepoint=keypair_generator.get(keyfamilypaymentbase, 0),
multisig_key=keypair_generator.get(keyfamilymultisig, 0),
htlc_basepoint=keypair_generator.get(keyfamilyhtlcbase, 0),
delayed_basepoint=keypair_generator.get(keyfamilydelaybase, 0),
revocation_basepoint=keypair_generator.get(keyfamilyrevocationbase, 0),
to_self_delay=to_self_delay,
dust_limit_sat=dust_limit_sat,
max_htlc_value_in_flight_msat=local_max_htlc_value_in_flight_msat,
max_accepted_htlcs=local_max_accepted_htlcs
)
# TODO derive this?
per_commitment_secret_seed = 0x1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100.to_bytes(32, "big")
per_commitment_secret_index = 2**48 - 1
# for the first commitment transaction
per_commitment_secret_first = get_per_commitment_secret_from_seed(per_commitment_secret_seed, per_commitment_secret_index)
per_commitment_point_first = secret_to_pubkey(int.from_bytes(
per_commitment_secret_first,
byteorder="big"))
msg = gen_msg(
"open_channel",
temporary_channel_id=temp_channel_id,
chain_hash=bytes.fromhex(rev_hex(constants.net.GENESIS)),
funding_satoshis=funding_sat,
push_msat=push_msat,
dust_limit_satoshis=dust_limit_sat,
feerate_per_kw=local_feerate,
max_accepted_htlcs=local_max_accepted_htlcs,
funding_pubkey=local_config.multisig_key.pubkey,
revocation_basepoint=local_config.revocation_basepoint.pubkey,
htlc_basepoint=local_config.htlc_basepoint.pubkey,
payment_basepoint=local_config.payment_basepoint.pubkey,
delayed_payment_basepoint=local_config.delayed_basepoint.pubkey,
first_per_commitment_point=per_commitment_point_first,
to_self_delay=to_self_delay,
max_htlc_value_in_flight_msat=local_max_htlc_value_in_flight_msat
)
#self.channel_accepted[temp_channel_id] = asyncio.Future()
self.send_message(msg)
try:
payload = await self.channel_accepted[temp_channel_id]
finally:
del self.channel_accepted[temp_channel_id]
remote_per_commitment_point = payload['first_per_commitment_point']
remote_config=ChannelConfig(
payment_basepoint=OnlyPubkeyKeypair(payload['payment_basepoint']),
multisig_key=OnlyPubkeyKeypair(payload["funding_pubkey"]),
htlc_basepoint=OnlyPubkeyKeypair(payload['htlc_basepoint']),
delayed_basepoint=OnlyPubkeyKeypair(payload['delayed_payment_basepoint']),
revocation_basepoint=OnlyPubkeyKeypair(payload['revocation_basepoint']),
to_self_delay=int.from_bytes(payload['to_self_delay'], byteorder="big"),
dust_limit_sat=int.from_bytes(payload['dust_limit_satoshis'], byteorder="big"),
max_htlc_value_in_flight_msat=int.from_bytes(payload['max_htlc_value_in_flight_msat'], "big"),
max_accepted_htlcs=payload["max_accepted_htlcs"]
)
funding_txn_minimum_depth = int.from_bytes(payload['minimum_depth'], byteorder="big")
print('remote dust limit', remote_config.dust_limit_sat)
assert remote_config.dust_limit_sat < 600
assert int.from_bytes(payload['htlc_minimum_msat'], "big") < 600 * 1000
assert remote_config.max_htlc_value_in_flight_msat >= 500 * 1000 * 1000, remote_config.max_htlc_value_in_flight_msat
self.print_error('remote delay', remote_config.to_self_delay)
self.print_error('funding_txn_minimum_depth', funding_txn_minimum_depth)
# create funding tx
pubkeys = sorted([bh2u(local_config.multisig_key.pubkey), bh2u(remote_config.multisig_key.pubkey)])
redeem_script = transaction.multisig_script(pubkeys, 2)
funding_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script)
funding_output = (bitcoin.TYPE_ADDRESS, funding_address, funding_sat)
funding_tx = wallet.mktx([funding_output], None, config, 1000)
funding_txid = funding_tx.txid()
funding_index = funding_tx.outputs().index(funding_output)
# derive keys
local_payment_pubkey = derive_pubkey(local_config.payment_basepoint.pubkey, remote_per_commitment_point)
#local_payment_privkey = derive_privkey(base_secret, remote_per_commitment_point)
remote_payment_pubkey = derive_pubkey(remote_config.payment_basepoint.pubkey, per_commitment_point_first)
revocation_pubkey = derive_blinded_pubkey(local_config.revocation_basepoint.pubkey, remote_per_commitment_point)
remote_revocation_pubkey = derive_blinded_pubkey(remote_config.revocation_basepoint.pubkey, per_commitment_point_first)
local_delayedpubkey = derive_pubkey(local_config.delayed_basepoint.pubkey, per_commitment_point_first)
remote_delayedpubkey = derive_pubkey(remote_config.delayed_basepoint.pubkey, remote_per_commitment_point)
# compute amounts
htlcs = []
to_local_msat = funding_sat*1000 - push_msat
to_remote_msat = push_msat
local_amount = to_local_msat // 1000
remote_amount = to_remote_msat // 1000
# remote commitment transaction
remote_ctx = make_commitment(
0,
remote_config.multisig_key.pubkey, local_config.multisig_key.pubkey, local_payment_pubkey,
local_config.payment_basepoint.pubkey, remote_config.payment_basepoint.pubkey,
revocation_pubkey, remote_delayedpubkey, to_self_delay,
funding_txid, funding_index, funding_sat,
remote_amount, local_amount, remote_config.dust_limit_sat, local_feerate, False, htlcs=[])
remote_ctx.sign({bh2u(local_config.multisig_key.pubkey): (local_config.multisig_key.privkey, True)})
sig_index = pubkeys.index(bh2u(local_config.multisig_key.pubkey))
sig = bytes.fromhex(remote_ctx.inputs()[0]["signatures"][sig_index])
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
sig_64 = sigencode_string_canonize(r, s, SECP256k1.generator.order())
funding_txid_bytes = bytes.fromhex(funding_txid)[::-1]
channel_id = int.from_bytes(funding_txid_bytes, byteorder="big") ^ funding_index
self.send_message(gen_msg("funding_created", temporary_channel_id=temp_channel_id, funding_txid=funding_txid_bytes, funding_output_index=funding_index, signature=sig_64))
try:
payload = await self.funding_signed[channel_id]
finally:
del self.funding_signed[channel_id]
self.print_error('received funding_signed')
remote_sig = payload['signature']
# verify remote signature
local_ctx = make_commitment(
0,
local_config.multisig_key.pubkey, remote_config.multisig_key.pubkey, remote_payment_pubkey,
local_config.payment_basepoint.pubkey, remote_config.payment_basepoint.pubkey,
remote_revocation_pubkey, local_delayedpubkey, remote_config.to_self_delay,
funding_txid, funding_index, funding_sat,
local_amount, remote_amount, dust_limit_sat, local_feerate, True, htlcs=[])
pre_hash = bitcoin.Hash(bfh(local_ctx.serialize_preimage(0)))
if not bitcoin.verify_signature(remote_config.multisig_key.pubkey, remote_sig, pre_hash):
raise Exception('verifying remote signature failed.')
# broadcast funding tx
success, _txid = self.network.broadcast(funding_tx)
assert success, success
# wait until we see confirmations
def on_network_update(event, *args):
conf = wallet.get_tx_height(funding_txid)[1]
if conf >= funding_txn_minimum_depth:
async def set_local_funding_locked_result():
try:
self.local_funding_locked[channel_id].set_result(1)
except (asyncio.InvalidStateError, KeyError) as e:
# FIXME race condition if updates come in quickly, set_result might be called multiple times
# or self.local_funding_locked[channel_id] might be deleted already
self.print_error('local_funding_locked.set_result error for channel {}: {}'.format(channel_id, e))
asyncio.run_coroutine_threadsafe(set_local_funding_locked_result(), asyncio.get_event_loop())
self.network.unregister_callback(on_network_update)
self.network.register_callback(on_network_update, ['updated']) # thread safe
try:
await self.local_funding_locked[channel_id]
finally:
del self.local_funding_locked[channel_id]
per_commitment_secret_index -= 1
per_commitment_point_second = secret_to_pubkey(int.from_bytes(
get_per_commitment_secret_from_seed(per_commitment_secret_seed, per_commitment_secret_index),
byteorder="big"))
self.send_message(gen_msg("funding_locked", channel_id=channel_id, next_per_commitment_point=per_commitment_point_second))
# wait until we receive funding_locked
try:
remote_funding_locked_msg = await self.remote_funding_locked[channel_id]
finally:
del self.remote_funding_locked[channel_id]
self.print_error('Done waiting for remote_funding_locked', remote_funding_locked_msg)
return OpenChannel(
channel_id=channel_id,
funding_outpoint=Outpoint(funding_txid, funding_index),
local_config=local_config,
remote_config=remote_config,
remote_state=RemoteState(
ctn = 0,
next_per_commitment_point=remote_funding_locked_msg["next_per_commitment_point"],
amount_sat=remote_amount
),
local_state=LocalState(
ctn = 0,
per_commitment_secret_seed=per_commitment_secret_seed,
amount_sat=local_amount
),
constraints=ChannelConstraints(capacity=funding_sat, feerate=local_feerate, is_initiator=True)
)
async def receive_commitment_revoke_ack(self, openchannel, expected_received_sat, payment_preimage):
channel_id = openchannel.channel_id
local_per_commitment_secret_seed = openchannel.local_state.per_commitment_secret_seed
assert openchannel.local_state.ctn == 0
local_next_pcs_index = 2**48 - 2
remote_funding_pubkey = openchannel.remote_config.multisig_key.pubkey
remote_next_commitment_point = openchannel.remote_state.next_per_commitment_point
remote_revocation_basepoint = openchannel.remote_config.revocation_basepoint.pubkey
remote_htlc_basepoint = openchannel.remote_config.htlc_basepoint.pubkey
local_htlc_basepoint = openchannel.local_config.htlc_basepoint.pubkey
delayed_payment_basepoint = openchannel.local_config.delayed_basepoint.pubkey
revocation_basepoint = openchannel.local_config.revocation_basepoint.pubkey
remote_delayed_payment_basepoint = openchannel.remote_config.delayed_basepoint.pubkey
remote_delay = openchannel.remote_config.to_self_delay
remote_dust_limit_sat = openchannel.remote_config.dust_limit_sat
funding_privkey = openchannel.local_config.multisig_key.privkey
htlc_privkey = openchannel.local_config.htlc_basepoint.privkey
try:
commitment_signed_msg = await self.commitment_signed[channel_id]
finally:
del self.commitment_signed[channel_id]
local_next_per_commitment_secret = get_per_commitment_secret_from_seed(local_per_commitment_secret_seed, local_next_pcs_index)
local_next_per_commitment_point = secret_to_pubkey(int.from_bytes(
local_next_per_commitment_secret,
byteorder="big"))
remote_htlc_pubkey = derive_pubkey(remote_htlc_basepoint, local_next_per_commitment_point)
local_htlc_pubkey = derive_pubkey(local_htlc_basepoint, local_next_per_commitment_point)
htlc_id = int.from_bytes(self.unfulfilled_htlcs[0]["id"], "big")
assert htlc_id == 0, htlc_id
payment_hash = self.unfulfilled_htlcs[0]["payment_hash"]
cltv_expiry = int.from_bytes(self.unfulfilled_htlcs[0]["cltv_expiry"],"big")
# TODO verify sanity of their cltv expiry
amount_msat = int.from_bytes(self.unfulfilled_htlcs[0]["amount_msat"], "big")
remote_revocation_pubkey = derive_blinded_pubkey(remote_revocation_basepoint, local_next_per_commitment_point)
htlcs_in_local = [
(
make_received_htlc(remote_revocation_pubkey, remote_htlc_pubkey, local_htlc_pubkey, payment_hash, cltv_expiry),
amount_msat
)
]
new_commitment = make_commitment_using_open_channel(openchannel, 1, True, local_next_per_commitment_point,
openchannel.local_state.amount_sat,
openchannel.remote_state.amount_sat - expected_received_sat,
htlcs_in_local)
preimage_hex = new_commitment.serialize_preimage(0)
pre_hash = bitcoin.Hash(bfh(preimage_hex))
if not bitcoin.verify_signature(remote_funding_pubkey, commitment_signed_msg["signature"], pre_hash):
raise Exception('failed verifying signature of our updated commitment transaction')
htlc_sigs_len = len(commitment_signed_msg["htlc_signature"])
if htlc_sigs_len != 64:
raise Exception("unexpected number of htlc signatures: " + str(htlc_sigs_len))
# TODO verify htlc_signature
local_last_pcs_index = 2**48 - openchannel.local_state.ctn - 1
local_last_per_commitment_secret = get_per_commitment_secret_from_seed(local_per_commitment_secret_seed, local_last_pcs_index)
self.send_message(gen_msg("revoke_and_ack", channel_id=channel_id, per_commitment_secret=local_last_per_commitment_secret, next_per_commitment_point=local_next_per_commitment_point))
funding_pubkey = openchannel.local_config.multisig_key.pubkey
pubkeys = sorted([bh2u(funding_pubkey), bh2u(remote_funding_pubkey)])
their_local_htlc_pubkey = derive_pubkey(remote_htlc_basepoint, remote_next_commitment_point)
their_remote_htlc_pubkey = derive_pubkey(local_htlc_basepoint, remote_next_commitment_point)
their_remote_htlc_privkey_number = derive_privkey(int.from_bytes(htlc_privkey, "big"), remote_next_commitment_point)
their_remote_htlc_privkey = their_remote_htlc_privkey_number.to_bytes(32, "big")
# TODO check payment_hash
revocation_pubkey = derive_blinded_pubkey(revocation_basepoint, remote_next_commitment_point)
htlcs_in_remote = [(make_offered_htlc(revocation_pubkey, their_remote_htlc_pubkey, their_local_htlc_pubkey, payment_hash), amount_msat)]
remote_ctx = make_commitment_using_open_channel(openchannel, 1, False, remote_next_commitment_point,
openchannel.remote_state.amount_sat - expected_received_sat, openchannel.local_state.amount_sat, htlcs_in_remote)
remote_ctx.sign({bh2u(funding_pubkey): (funding_privkey, True)})
sig_index = pubkeys.index(bh2u(funding_pubkey))
sig = bytes.fromhex(remote_ctx.inputs()[0]["signatures"][sig_index])
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
sig_64 = sigencode_string_canonize(r, s, SECP256k1.generator.order())
remote_delayedpubkey = derive_pubkey(remote_delayed_payment_basepoint, remote_next_commitment_point)
htlc_tx_output = make_htlc_tx_output(
amount_msat = amount_msat,
local_feerate = openchannel.constraints.feerate,
revocationpubkey=revocation_pubkey,
local_delayedpubkey=remote_delayedpubkey,
success = False) # timeout for the one offering an HTLC
preimage_script = htlcs_in_remote[0][0]
htlc_output_txid = remote_ctx.txid()
htlc_tx_inputs = make_htlc_tx_inputs(
remote_ctx.txid(), remote_ctx.htlc_output_indices[0],
revocationpubkey=revocation_pubkey,
local_delayedpubkey=remote_delayedpubkey,
amount_msat=amount_msat,
witness_script=bh2u(preimage_script))
htlc_tx = make_htlc_tx(cltv_expiry, inputs=htlc_tx_inputs, output=htlc_tx_output)
# htlc_sig signs the HTLC transaction that spends from THEIR commitment transaction's offered_htlc output
sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey))
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
htlc_sig = sigencode_string_canonize(r, s, SECP256k1.generator.order())
self.send_message(gen_msg("commitment_signed", channel_id=channel_id, signature=sig_64, num_htlcs=1, htlc_signature=htlc_sig))
try:
revoke_and_ack_msg = await self.revoke_and_ack[channel_id]
finally:
del self.revoke_and_ack[channel_id]
# TODO check revoke_and_ack_msg contents
self.send_message(gen_msg("update_fulfill_htlc", channel_id=channel_id, id=0, payment_preimage=payment_preimage))
remote_next_commitment_point = revoke_and_ack_msg["next_per_commitment_point"]
# remote commitment transaction without htlcs
bare_ctx = make_commitment_using_open_channel(openchannel, 2, False, remote_next_commitment_point,
openchannel.remote_state.amount_sat - expected_received_sat, openchannel.local_state.amount_sat + expected_received_sat)
bare_ctx.sign({bh2u(funding_pubkey): (funding_privkey, True)})
sig_index = pubkeys.index(bh2u(funding_pubkey))
sig = bytes.fromhex(bare_ctx.inputs()[0]["signatures"][sig_index])
r, s = sigdecode_der(sig[:-1], SECP256k1.generator.order())
sig_64 = sigencode_string_canonize(r, s, SECP256k1.generator.order())
self.send_message(gen_msg("commitment_signed", channel_id=channel_id, signature=sig_64, num_htlcs=0))
try:
revoke_and_ack_msg = await self.revoke_and_ack[channel_id]
finally:
del self.revoke_and_ack[channel_id]
# TODO check revoke_and_ack results
try:
commitment_signed_msg = await self.commitment_signed[channel_id]
finally:
del self.commitment_signed[channel_id]
# TODO check commitment_signed results
local_last_per_commitment_secret = get_per_commitment_secret_from_seed(local_per_commitment_secret_seed, local_last_pcs_index - 1)
local_next_per_commitment_secret = get_per_commitment_secret_from_seed(local_per_commitment_secret_seed, local_last_pcs_index - 2)
local_next_per_commitment_point = secret_to_pubkey(int.from_bytes(
local_next_per_commitment_secret,
byteorder="big"))
self.send_message(gen_msg("revoke_and_ack", channel_id=channel_id, per_commitment_secret=local_last_per_commitment_secret, next_per_commitment_point=local_next_per_commitment_point))
def on_commitment_signed(self, payload):
self.print_error("commitment_signed", payload)
channel_id = int.from_bytes(payload['channel_id'], byteorder="big")
self.commitment_signed[channel_id].set_result(payload)
def on_update_add_htlc(self, payload):
# no onion routing for the moment: we assume we are the end node
self.print_error('on_update_add_htlc', payload)
assert self.unfulfilled_htlcs == []
self.unfulfilled_htlcs.append(payload)
def on_revoke_and_ack(self, payload):
channel_id = int.from_bytes(payload["channel_id"], "big")
self.revoke_and_ack[channel_id].set_result(payload)
# replacement for lightningCall
class LNWorker:
def __init__(self, wallet, network):
self.wallet = wallet
self.network = network
host, port, pubkey = network.config.get('lightning_peer', node_list[0])
pubkey = binascii.unhexlify(pubkey)
port = int(port)
self.peer = Peer(host, port, pubkey)
self.network.futures.append(asyncio.run_coroutine_threadsafe(self.peer.main_loop(), asyncio.get_event_loop()))
def openchannel(self):
# todo: get utxo from wallet
# submit coro to asyncio main loop
self.peer.open_channel()
class ChannelInfo(PrintError):
def __init__(self, channel_announcement_payload):
self.channel_id = channel_announcement_payload['short_channel_id']
self.node_id_1 = channel_announcement_payload['node_id_1']
self.node_id_2 = channel_announcement_payload['node_id_2']
self.capacity_sat = None
self.policy_node1 = None
self.policy_node2 = None
def set_capacity(self, capacity):
# TODO call this after looking up UTXO for funding txn on chain
self.capacity_sat = capacity
def on_channel_update(self, msg_payload):
assert self.channel_id == msg_payload['short_channel_id']
flags = int.from_bytes(msg_payload['flags'], byteorder="big")
direction = bool(flags & 1)
if direction == 0:
self.policy_node1 = ChannelInfoDirectedPolicy(msg_payload)
else:
self.policy_node2 = ChannelInfoDirectedPolicy(msg_payload)
self.print_error('channel update', binascii.hexlify(self.channel_id), flags)
def get_policy_for_node(self, node_id):
if node_id == self.node_id_1:
return self.policy_node1
elif node_id == self.node_id_2:
return self.policy_node2
else:
raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id))
class ChannelInfoDirectedPolicy:
def __init__(self, channel_update_payload):
self.cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
self.htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
self.fee_base_msat = channel_update_payload['fee_base_msat']
self.fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
class ChannelDB(PrintError):
def __init__(self):
self._id_to_channel_info = {}
self._channels_for_node = defaultdict(set) # node -> set(short_channel_id)
def get_channel_info(self, channel_id):
return self._id_to_channel_info.get(channel_id, None)
def get_channels_for_node(self, node_id):
"""Returns the set of channels that have node_id as one of the endpoints."""
return self._channels_for_node[node_id]
def on_channel_announcement(self, msg_payload):
short_channel_id = msg_payload['short_channel_id']
self.print_error('channel announcement', binascii.hexlify(short_channel_id))
channel_info = ChannelInfo(msg_payload)
self._id_to_channel_info[short_channel_id] = channel_info
self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
def on_channel_update(self, msg_payload):
short_channel_id = msg_payload['short_channel_id']
try:
channel_info = self._id_to_channel_info[short_channel_id]
except KeyError:
pass # ignore channel update
else:
channel_info.on_channel_update(msg_payload)
def remove_channel(self, short_channel_id):
try:
channel_info = self._id_to_channel_info[short_channel_id]
except KeyError:
self.print_error('cannot find channel {}'.format(short_channel_id))
return
self._id_to_channel_info.pop(short_channel_id, None)
for node in (channel_info.node_id_1, channel_info.node_id_2):
try:
self._channels_for_node[node].remove(short_channel_id)
except KeyError:
pass
class LNPathFinder(PrintError):
def __init__(self, channel_db):
self.channel_db = channel_db
def _edge_cost(self, short_channel_id, start_node, payment_amt_msat):
"""Heuristic cost of going through a channel.
direction: 0 or 1. --- 0 means node_id_1 -> node_id_2
"""
channel_info = self.channel_db.get_channel_info(short_channel_id)
if channel_info is None:
return float('inf')
channel_policy = channel_info.get_policy_for_node(start_node)
cltv_expiry_delta = channel_policy.cltv_expiry_delta
htlc_minimum_msat = channel_policy.htlc_minimum_msat
fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths
if payment_amt_msat is not None:
if payment_amt_msat < htlc_minimum_msat:
return float('inf') # payment amount too little
if channel_info.capacity_sat is not None and \
payment_amt_msat // 1000 > channel_info.capacity_sat:
return float('inf') # payment amount too large
amt = payment_amt_msat or 50000 * 1000 # guess for typical payment amount
fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000
# TODO revise
# paying 10 more satoshis ~ waiting one more block
fee_cost = fee_msat / 1000 / 10
cltv_cost = cltv_expiry_delta
return cltv_cost + fee_cost + 1
@profiler
def find_path_for_payment(self, from_node_id, to_node_id, amount_msat=None):
"""Return a path between from_node_id and to_node_id.
Returns a list of (node_id, short_channel_id) representing a path.
To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1]
"""
# TODO find multiple paths??
# run Dijkstra
distance_from_start = defaultdict(lambda: float('inf'))
distance_from_start[from_node_id] = 0
prev_node = {}
nodes_to_explore = queue.PriorityQueue()
nodes_to_explore.put((0, from_node_id))
while nodes_to_explore.qsize() > 0:
dist_to_cur_node, cur_node = nodes_to_explore.get()
if cur_node == to_node_id:
break
if dist_to_cur_node != distance_from_start[cur_node]:
# queue.PriorityQueue does not implement decrease_priority,
# so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self.channel_db.get_channels_for_node(cur_node):
channel_info = self.channel_db.get_channel_info(edge_channel_id)
node1, node2 = channel_info.node_id_1, channel_info.node_id_2
neighbour = node2 if node1 == cur_node else node1
alt_dist_to_neighbour = distance_from_start[cur_node] \
+ self._edge_cost(edge_channel_id, cur_node, amount_msat)
if alt_dist_to_neighbour < distance_from_start[neighbour]:
distance_from_start[neighbour] = alt_dist_to_neighbour
prev_node[neighbour] = cur_node, edge_channel_id
nodes_to_explore.put((alt_dist_to_neighbour, neighbour))
else:
return None # no path found
# backtrack from end to start
cur_node = to_node_id
path = [(cur_node, None)]
while cur_node != from_node_id:
cur_node, edge_taken = prev_node[cur_node]
path += [(cur_node, edge_taken)]
path.reverse()
return path