Browse Source

move lightning message encoding to new lnmsg module

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
5f1feee331
  1. 152
      electrum/lnbase.py
  2. 5
      electrum/lnchannelverifier.py
  3. 155
      electrum/lnmsg.py
  4. 5
      electrum/tests/test_lnbase.py

152
electrum/lnbase.py

@ -37,6 +37,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED, MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED,
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED) MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED)
from .lntransport import LNTransport, LNTransportBase from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg
if TYPE_CHECKING: if TYPE_CHECKING:
from .lnworker import LNWorker from .lnworker import LNWorker
@ -48,153 +49,6 @@ def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[b
i = int.from_bytes(funding_txid_bytes, 'big') ^ funding_index i = int.from_bytes(funding_txid_bytes, 'big') ^ funding_index
return i.to_bytes(32, 'big'), funding_txid_bytes return i.to_bytes(32, 'big'), funding_txid_bytes
message_types = {}
def handlesingle(x, ma: dict) -> int:
"""
Evaluate a term of the simple language used
to specify lightning message field lengths.
If `x` is an integer, it is returned as is,
otherwise it is treated as a variable and
looked up in `ma`.
If the value in `ma` was no integer, it is
assumed big-endian bytes and decoded.
Returns int
"""
try:
x = int(x)
except ValueError:
x = ma[x]
try:
x = int(x)
except ValueError:
x = int.from_bytes(x, byteorder='big')
return x
def calcexp(exp, ma: dict) -> int:
"""
Evaluate simple mathematical expression given
in `exp` with variables assigned in the dict `ma`
Returns int
"""
exp = str(exp)
if "*" in exp:
assert "+" not in exp
result = 1
for term in exp.split("*"):
result *= handlesingle(term, ma)
return result
return sum(handlesingle(x, ma) for x in exp.split("+"))
def make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]:
"""
Generate a message handler function (taking bytes)
for message type `k` with specification `v`
Check lib/lightning.json, `k` could be 'init',
and `v` could be
{ type: 16, payload: { 'gflen': ..., ... }, ... }
Returns function taking bytes
"""
def handler(data: bytes) -> Tuple[str, dict]:
nonlocal k, v
ma = {}
pos = 0
for fieldname in v["payload"]:
poslenMap = v["payload"][fieldname]
if "feature" in poslenMap and pos == len(data):
continue
#print(poslenMap["position"], ma)
assert pos == calcexp(poslenMap["position"], ma)
length = poslenMap["length"]
length = calcexp(length, ma)
ma[fieldname] = data[pos:pos+length]
pos += length
assert pos == len(data), (k, pos, len(data))
return k, ma
return handler
path = os.path.join(os.path.dirname(__file__), 'lightning.json')
with open(path) as f:
structured = json.loads(f.read(), object_pairs_hook=OrderedDict)
for k in structured:
v = structured[k]
# these message types are skipped since their types collide
# (for example with pong, which also uses type=19)
# we don't need them yet
if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]:
continue
if len(v["payload"]) == 0:
continue
try:
num = int(v["type"])
except ValueError:
#print("skipping", k)
continue
byts = num.to_bytes(2, 'big')
assert byts not in message_types, (byts, message_types[byts].__name__, k)
names = [x.__name__ for x in message_types.values()]
assert k + "_handler" not in names, (k, names)
message_types[byts] = make_handler(k, v)
message_types[byts].__name__ = k + "_handler"
assert message_types[b"\x00\x10"].__name__ == "init_handler"
def decode_msg(data: bytes) -> Tuple[str, dict]:
"""
Decode Lightning message by reading the first
two bytes to determine message type.
Returns message type string and parsed message contents dict
"""
typ = data[:2]
k, parsed = message_types[typ](data[2:])
return k, parsed
def gen_msg(msg_type: str, **kwargs) -> bytes:
"""
Encode kwargs into a Lightning message (bytes)
of the type given in the msg_type string
"""
typ = structured[msg_type]
data = int(typ["type"]).to_bytes(2, 'big')
lengths = {}
for k in typ["payload"]:
poslenMap = typ["payload"][k]
if "feature" in poslenMap: continue
leng = calcexp(poslenMap["length"], lengths)
try:
clone = dict(lengths)
clone.update(kwargs)
leng = calcexp(poslenMap["length"], clone)
except KeyError:
pass
try:
param = kwargs[k]
except KeyError:
param = 0
try:
if not isinstance(param, bytes):
assert isinstance(param, int), "field {} is neither bytes or int".format(k)
param = param.to_bytes(leng, 'big')
except ValueError:
raise Exception("{} does not fit in {} bytes".format(k, leng))
lengths[k] = len(param)
if lengths[k] != leng:
raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng))
data += param
return data
class Peer(PrintError): class Peer(PrintError):
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase, def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
@ -229,7 +83,7 @@ class Peer(PrintError):
def send_message(self, message_name: str, **kwargs): def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str assert type(message_name) is str
self.print_error("Sending '%s'"%message_name.upper()) self.print_error("Sending '%s'"%message_name.upper())
self.transport.send_bytes(gen_msg(message_name, **kwargs)) self.transport.send_bytes(encode_msg(message_name, **kwargs))
async def initialize(self): async def initialize(self):
if isinstance(self.transport, LNTransport): if isinstance(self.transport, LNTransport):
@ -872,7 +726,7 @@ class Peer(PrintError):
else: else:
node_ids = self.node_ids node_ids = self.node_ids
chan_ann = gen_msg("channel_announcement", chan_ann = encode_msg("channel_announcement",
len=0, len=0,
#features not set (defaults to zeros) #features not set (defaults to zeros)
chain_hash=constants.net.rev_genesis_bytes(), chain_hash=constants.net.rev_genesis_bytes(),

5
electrum/lnchannelverifier.py

@ -39,6 +39,7 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
from .transaction import Transaction from .transaction import Transaction
from .interface import GracefulDisconnect from .interface import GracefulDisconnect
from .crypto import sha256d from .crypto import sha256d
from .lnmsg import encode_msg
if TYPE_CHECKING: if TYPE_CHECKING:
from .network import Network from .network import Network
@ -184,7 +185,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool: def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool:
msg_bytes = lnbase.gen_msg('channel_announcement', **chan_ann) msg_bytes = encode_msg('channel_announcement', **chan_ann)
pre_hash = msg_bytes[2+256:] pre_hash = msg_bytes[2+256:]
h = sha256d(pre_hash) h = sha256d(pre_hash)
pubkeys = [chan_ann['node_id_1'], chan_ann['node_id_2'], chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']] pubkeys = [chan_ann['node_id_1'], chan_ann['node_id_2'], chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']]
@ -196,7 +197,7 @@ def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool:
def verify_sig_for_channel_update(chan_upd: dict, node_id: bytes) -> bool: def verify_sig_for_channel_update(chan_upd: dict, node_id: bytes) -> bool:
msg_bytes = lnbase.gen_msg('channel_update', **chan_upd) msg_bytes = encode_msg('channel_update', **chan_upd)
pre_hash = msg_bytes[2+64:] pre_hash = msg_bytes[2+64:]
h = sha256d(pre_hash) h = sha256d(pre_hash)
sig = chan_upd['signature'] sig = chan_upd['signature']

155
electrum/lnmsg.py

@ -0,0 +1,155 @@
import json
import os
from typing import Callable, Tuple
from collections import OrderedDict
def _eval_length_term(x, ma: dict) -> int:
"""
Evaluate a term of the simple language used
to specify lightning message field lengths.
If `x` is an integer, it is returned as is,
otherwise it is treated as a variable and
looked up in `ma`.
If the value in `ma` was no integer, it is
assumed big-endian bytes and decoded.
Returns evaluated result as int
"""
try:
x = int(x)
except ValueError:
x = ma[x]
try:
x = int(x)
except ValueError:
x = int.from_bytes(x, byteorder='big')
return x
def _eval_exp_with_ctx(exp, ctx: dict) -> int:
"""
Evaluate simple mathematical expression given
in `exp` with context (variables assigned)
from the dict `ctx`.
Returns evaluated result as int
"""
exp = str(exp)
if "*" in exp:
assert "+" not in exp
result = 1
for term in exp.split("*"):
result *= _eval_length_term(term, ctx)
return result
return sum(_eval_length_term(x, ctx) for x in exp.split("+"))
def _make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]:
"""
Generate a message handler function (taking bytes)
for message type `k` with specification `v`
Check lib/lightning.json, `k` could be 'init',
and `v` could be
{ type: 16, payload: { 'gflen': ..., ... }, ... }
Returns function taking bytes
"""
def handler(data: bytes) -> Tuple[str, dict]:
nonlocal k, v
ma = {}
pos = 0
for fieldname in v["payload"]:
poslenMap = v["payload"][fieldname]
if "feature" in poslenMap and pos == len(data):
continue
assert pos == _eval_exp_with_ctx(poslenMap["position"], ma)
length = poslenMap["length"]
length = _eval_exp_with_ctx(length, ma)
ma[fieldname] = data[pos:pos+length]
pos += length
assert pos == len(data), (k, pos, len(data))
return k, ma
return handler
class LNSerializer:
def __init__(self):
message_types = {}
path = os.path.join(os.path.dirname(__file__), 'lightning.json')
with open(path) as f:
structured = json.loads(f.read(), object_pairs_hook=OrderedDict)
for k in structured:
v = structured[k]
# these message types are skipped since their types collide
# (for example with pong, which also uses type=19)
# we don't need them yet
if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]:
continue
if len(v["payload"]) == 0:
continue
try:
num = int(v["type"])
except ValueError:
#print("skipping", k)
continue
byts = num.to_bytes(2, 'big')
assert byts not in message_types, (byts, message_types[byts].__name__, k)
names = [x.__name__ for x in message_types.values()]
assert k + "_handler" not in names, (k, names)
message_types[byts] = _make_handler(k, v)
message_types[byts].__name__ = k + "_handler"
assert message_types[b"\x00\x10"].__name__ == "init_handler"
self.structured = structured
self.message_types = message_types
def encode_msg(self, msg_type : str, **kwargs) -> bytes:
"""
Encode kwargs into a Lightning message (bytes)
of the type given in the msg_type string
"""
typ = self.structured[msg_type]
data = int(typ["type"]).to_bytes(2, 'big')
lengths = {}
for k in typ["payload"]:
poslenMap = typ["payload"][k]
if "feature" in poslenMap: continue
leng = _eval_exp_with_ctx(poslenMap["length"], lengths)
try:
clone = dict(lengths)
clone.update(kwargs)
leng = _eval_exp_with_ctx(poslenMap["length"], clone)
except KeyError:
pass
try:
param = kwargs[k]
except KeyError:
param = 0
try:
if not isinstance(param, bytes):
assert isinstance(param, int), "field {} is neither bytes or int".format(k)
param = param.to_bytes(leng, 'big')
except ValueError:
raise Exception("{} does not fit in {} bytes".format(k, leng))
lengths[k] = len(param)
if lengths[k] != leng:
raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng))
data += param
return data
def decode_msg(self, data : bytes) -> Tuple[str, dict]:
"""
Decode Lightning message by reading the first
two bytes to determine message type.
Returns message type string and parsed message contents dict
"""
typ = data[:2]
k, parsed = self.message_types[typ](data[2:])
return k, parsed
_inst = LNSerializer()
encode_msg = _inst.encode_msg
decode_msg = _inst.decode_msg

5
electrum/tests/test_lnbase.py

@ -13,12 +13,13 @@ from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256 from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u from electrum.util import bh2u
from electrum.lnbase import Peer, decode_msg, gen_msg from electrum.lnbase import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure from electrum.lnutil import PaymentFailure
from electrum.lnrouter import ChannelDB, LNPathFinder from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker from electrum.lnworker import LNWorker
from electrum.lnmsg import encode_msg, decode_msg
from .test_lnchan import create_test_channels from .test_lnchan import create_test_channels
@ -135,7 +136,7 @@ class NoFeaturesTransport(MockTransport):
decoded = decode_msg(data) decoded = decode_msg(data)
print(decoded) print(decoded)
if decoded[0] == 'init': if decoded[0] == 'init':
self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")) self.queue.put_nowait(encode_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
class PutIntoOthersQueueTransport(MockTransport): class PutIntoOthersQueueTransport(MockTransport):
def __init__(self): def __init__(self):

Loading…
Cancel
Save