diff --git a/electrum/lnbase.py b/electrum/lnbase.py index 0b032ea4f..abb065286 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -37,6 +37,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED) from .lntransport import LNTransport, LNTransportBase +from .lnmsg import encode_msg, decode_msg if TYPE_CHECKING: from .lnworker import LNWorker @@ -48,153 +49,6 @@ def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[b i = int.from_bytes(funding_txid_bytes, 'big') ^ funding_index return i.to_bytes(32, 'big'), funding_txid_bytes - -message_types = {} - -def handlesingle(x, ma: dict) -> int: - """ - Evaluate a term of the simple language used - to specify lightning message field lengths. - - If `x` is an integer, it is returned as is, - otherwise it is treated as a variable and - looked up in `ma`. - - If the value in `ma` was no integer, it is - assumed big-endian bytes and decoded. - - Returns int - """ - try: - x = int(x) - except ValueError: - x = ma[x] - try: - x = int(x) - except ValueError: - x = int.from_bytes(x, byteorder='big') - return x - -def calcexp(exp, ma: dict) -> int: - """ - Evaluate simple mathematical expression given - in `exp` with variables assigned in the dict `ma` - - Returns int - """ - exp = str(exp) - if "*" in exp: - assert "+" not in exp - result = 1 - for term in exp.split("*"): - result *= handlesingle(term, ma) - return result - return sum(handlesingle(x, ma) for x in exp.split("+")) - -def make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]: - """ - Generate a message handler function (taking bytes) - for message type `k` with specification `v` - - Check lib/lightning.json, `k` could be 'init', - and `v` could be - - { type: 16, payload: { 'gflen': ..., ... }, ... } - - Returns function taking bytes - """ - def handler(data: bytes) -> Tuple[str, dict]: - nonlocal k, v - ma = {} - pos = 0 - for fieldname in v["payload"]: - poslenMap = v["payload"][fieldname] - if "feature" in poslenMap and pos == len(data): - continue - #print(poslenMap["position"], ma) - assert pos == calcexp(poslenMap["position"], ma) - length = poslenMap["length"] - length = calcexp(length, ma) - ma[fieldname] = data[pos:pos+length] - pos += length - assert pos == len(data), (k, pos, len(data)) - return k, ma - return handler - -path = os.path.join(os.path.dirname(__file__), 'lightning.json') -with open(path) as f: - structured = json.loads(f.read(), object_pairs_hook=OrderedDict) - -for k in structured: - v = structured[k] - # these message types are skipped since their types collide - # (for example with pong, which also uses type=19) - # we don't need them yet - if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]: - continue - if len(v["payload"]) == 0: - continue - try: - num = int(v["type"]) - except ValueError: - #print("skipping", k) - continue - byts = num.to_bytes(2, 'big') - assert byts not in message_types, (byts, message_types[byts].__name__, k) - names = [x.__name__ for x in message_types.values()] - assert k + "_handler" not in names, (k, names) - message_types[byts] = make_handler(k, v) - message_types[byts].__name__ = k + "_handler" - -assert message_types[b"\x00\x10"].__name__ == "init_handler" - -def decode_msg(data: bytes) -> Tuple[str, dict]: - """ - Decode Lightning message by reading the first - two bytes to determine message type. - - Returns message type string and parsed message contents dict - """ - typ = data[:2] - k, parsed = message_types[typ](data[2:]) - return k, parsed - -def gen_msg(msg_type: str, **kwargs) -> bytes: - """ - Encode kwargs into a Lightning message (bytes) - of the type given in the msg_type string - """ - typ = structured[msg_type] - data = int(typ["type"]).to_bytes(2, 'big') - lengths = {} - for k in typ["payload"]: - poslenMap = typ["payload"][k] - if "feature" in poslenMap: continue - leng = calcexp(poslenMap["length"], lengths) - try: - clone = dict(lengths) - clone.update(kwargs) - leng = calcexp(poslenMap["length"], clone) - except KeyError: - pass - try: - param = kwargs[k] - except KeyError: - param = 0 - try: - if not isinstance(param, bytes): - assert isinstance(param, int), "field {} is neither bytes or int".format(k) - param = param.to_bytes(leng, 'big') - except ValueError: - raise Exception("{} does not fit in {} bytes".format(k, leng)) - lengths[k] = len(param) - if lengths[k] != leng: - raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng)) - data += param - return data - - - class Peer(PrintError): def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase, @@ -229,7 +83,7 @@ class Peer(PrintError): def send_message(self, message_name: str, **kwargs): assert type(message_name) is str self.print_error("Sending '%s'"%message_name.upper()) - self.transport.send_bytes(gen_msg(message_name, **kwargs)) + self.transport.send_bytes(encode_msg(message_name, **kwargs)) async def initialize(self): if isinstance(self.transport, LNTransport): @@ -872,7 +726,7 @@ class Peer(PrintError): else: node_ids = self.node_ids - chan_ann = gen_msg("channel_announcement", + chan_ann = encode_msg("channel_announcement", len=0, #features not set (defaults to zeros) chain_hash=constants.net.rev_genesis_bytes(), diff --git a/electrum/lnchannelverifier.py b/electrum/lnchannelverifier.py index 54e7c519a..21c74ff15 100644 --- a/electrum/lnchannelverifier.py +++ b/electrum/lnchannelverifier.py @@ -39,6 +39,7 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure from .transaction import Transaction from .interface import GracefulDisconnect from .crypto import sha256d +from .lnmsg import encode_msg if TYPE_CHECKING: from .network import Network @@ -184,7 +185,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool: - msg_bytes = lnbase.gen_msg('channel_announcement', **chan_ann) + msg_bytes = encode_msg('channel_announcement', **chan_ann) pre_hash = msg_bytes[2+256:] h = sha256d(pre_hash) pubkeys = [chan_ann['node_id_1'], chan_ann['node_id_2'], chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']] @@ -196,7 +197,7 @@ def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool: def verify_sig_for_channel_update(chan_upd: dict, node_id: bytes) -> bool: - msg_bytes = lnbase.gen_msg('channel_update', **chan_upd) + msg_bytes = encode_msg('channel_update', **chan_upd) pre_hash = msg_bytes[2+64:] h = sha256d(pre_hash) sig = chan_upd['signature'] diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py new file mode 100644 index 000000000..b198168fd --- /dev/null +++ b/electrum/lnmsg.py @@ -0,0 +1,155 @@ +import json +import os +from typing import Callable, Tuple +from collections import OrderedDict + +def _eval_length_term(x, ma: dict) -> int: + """ + Evaluate a term of the simple language used + to specify lightning message field lengths. + + If `x` is an integer, it is returned as is, + otherwise it is treated as a variable and + looked up in `ma`. + + If the value in `ma` was no integer, it is + assumed big-endian bytes and decoded. + + Returns evaluated result as int + """ + try: + x = int(x) + except ValueError: + x = ma[x] + try: + x = int(x) + except ValueError: + x = int.from_bytes(x, byteorder='big') + return x + +def _eval_exp_with_ctx(exp, ctx: dict) -> int: + """ + Evaluate simple mathematical expression given + in `exp` with context (variables assigned) + from the dict `ctx`. + + Returns evaluated result as int + """ + exp = str(exp) + if "*" in exp: + assert "+" not in exp + result = 1 + for term in exp.split("*"): + result *= _eval_length_term(term, ctx) + return result + return sum(_eval_length_term(x, ctx) for x in exp.split("+")) + +def _make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]: + """ + Generate a message handler function (taking bytes) + for message type `k` with specification `v` + + Check lib/lightning.json, `k` could be 'init', + and `v` could be + + { type: 16, payload: { 'gflen': ..., ... }, ... } + + Returns function taking bytes + """ + def handler(data: bytes) -> Tuple[str, dict]: + nonlocal k, v + ma = {} + pos = 0 + for fieldname in v["payload"]: + poslenMap = v["payload"][fieldname] + if "feature" in poslenMap and pos == len(data): + continue + assert pos == _eval_exp_with_ctx(poslenMap["position"], ma) + length = poslenMap["length"] + length = _eval_exp_with_ctx(length, ma) + ma[fieldname] = data[pos:pos+length] + pos += length + assert pos == len(data), (k, pos, len(data)) + return k, ma + return handler + +class LNSerializer: + def __init__(self): + message_types = {} + path = os.path.join(os.path.dirname(__file__), 'lightning.json') + with open(path) as f: + structured = json.loads(f.read(), object_pairs_hook=OrderedDict) + + for k in structured: + v = structured[k] + # these message types are skipped since their types collide + # (for example with pong, which also uses type=19) + # we don't need them yet + if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]: + continue + if len(v["payload"]) == 0: + continue + try: + num = int(v["type"]) + except ValueError: + #print("skipping", k) + continue + byts = num.to_bytes(2, 'big') + assert byts not in message_types, (byts, message_types[byts].__name__, k) + names = [x.__name__ for x in message_types.values()] + assert k + "_handler" not in names, (k, names) + message_types[byts] = _make_handler(k, v) + message_types[byts].__name__ = k + "_handler" + + assert message_types[b"\x00\x10"].__name__ == "init_handler" + self.structured = structured + self.message_types = message_types + + def encode_msg(self, msg_type : str, **kwargs) -> bytes: + """ + Encode kwargs into a Lightning message (bytes) + of the type given in the msg_type string + """ + typ = self.structured[msg_type] + data = int(typ["type"]).to_bytes(2, 'big') + lengths = {} + for k in typ["payload"]: + poslenMap = typ["payload"][k] + if "feature" in poslenMap: continue + leng = _eval_exp_with_ctx(poslenMap["length"], lengths) + try: + clone = dict(lengths) + clone.update(kwargs) + leng = _eval_exp_with_ctx(poslenMap["length"], clone) + except KeyError: + pass + try: + param = kwargs[k] + except KeyError: + param = 0 + try: + if not isinstance(param, bytes): + assert isinstance(param, int), "field {} is neither bytes or int".format(k) + param = param.to_bytes(leng, 'big') + except ValueError: + raise Exception("{} does not fit in {} bytes".format(k, leng)) + lengths[k] = len(param) + if lengths[k] != leng: + raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng)) + data += param + return data + + def decode_msg(self, data : bytes) -> Tuple[str, dict]: + """ + Decode Lightning message by reading the first + two bytes to determine message type. + + Returns message type string and parsed message contents dict + """ + typ = data[:2] + k, parsed = self.message_types[typ](data[2:]) + return k, parsed + +_inst = LNSerializer() +encode_msg = _inst.encode_msg +decode_msg = _inst.decode_msg diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py index ebddfd354..2e9a552ea 100644 --- a/electrum/tests/test_lnbase.py +++ b/electrum/tests/test_lnbase.py @@ -13,12 +13,13 @@ from electrum.lnaddr import lnencode, LnAddr, lndecode from electrum.bitcoin import COIN, sha256 from electrum.util import bh2u -from electrum.lnbase import Peer, decode_msg, gen_msg +from electrum.lnbase import Peer from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import PaymentFailure from electrum.lnrouter import ChannelDB, LNPathFinder from electrum.lnworker import LNWorker +from electrum.lnmsg import encode_msg, decode_msg from .test_lnchan import create_test_channels @@ -135,7 +136,7 @@ class NoFeaturesTransport(MockTransport): decoded = decode_msg(data) print(decoded) if decoded[0] == 'init': - self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")) + self.queue.put_nowait(encode_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")) class PutIntoOthersQueueTransport(MockTransport): def __init__(self):