From 69497522630b95a632875cf87bbeec7d1d7b3242 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Fri, 13 Mar 2020 21:20:31 +0100 Subject: [PATCH] lnmsg: initial TLV implementation --- electrum/lnmsg.py | 272 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 239 insertions(+), 33 deletions(-) diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index 4e362efae..49afb7319 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -2,6 +2,7 @@ import os import csv import io from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional +from collections import OrderedDict class MalformedMsg(Exception): @@ -16,12 +17,56 @@ class UnexpectedEndOfStream(MalformedMsg): pass -def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None: +class FieldEncodingNotMinimal(MalformedMsg): + pass + + +class UnknownMandatoryTLVRecordType(MalformedMsg): + pass + + +def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int: cur_pos = fd.tell() end_pos = fd.seek(0, io.SEEK_END) fd.seek(cur_pos) - if end_pos - cur_pos < n: - raise UnexpectedEndOfStream(f"cur_pos={cur_pos}. end_pos={end_pos}. wants to read: {n}") + return end_pos - cur_pos + + +def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None: + nremaining = _num_remaining_bytes_to_read(fd) + if nremaining < n: + raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left") + + +def bigsize_from_int(i: int) -> bytes: + assert i >= 0, i + if i < 0xfd: + return int.to_bytes(i, length=1, byteorder="big", signed=False) + elif i < 0x1_0000: + return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False) + elif i < 0x1_0000_0000: + return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False) + else: + return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False) + + +def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]: + try: + first = fd.read(1)[0] + except IndexError: + return None # end of file + if first < 0xfd: + return first + elif first == 0xfd: + _assert_can_read_at_least_n_bytes(fd, 2) + return int.from_bytes(fd.read(2), byteorder="big", signed=False) + elif first == 0xfe: + _assert_can_read_at_least_n_bytes(fd, 4) + return int.from_bytes(fd.read(4), byteorder="big", signed=False) + elif first == 0xff: + _assert_can_read_at_least_n_bytes(fd, 8) + return int.from_bytes(fd.read(8), byteorder="big", signed=False) + raise Exception() def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]: @@ -32,22 +77,36 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, type_len = None if field_type == 'byte': type_len = 1 - elif field_type == 'u16': - type_len = 2 + elif field_type in ('u16', 'u32', 'u64'): + if field_type == 'u16': + type_len = 2 + elif field_type == 'u32': + type_len = 4 + else: + assert field_type == 'u64' + type_len = 8 assert count == 1, count _assert_can_read_at_least_n_bytes(fd, type_len) return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) - elif field_type == 'u32': - type_len = 4 + elif field_type in ('tu16', 'tu32', 'tu64'): + if field_type == 'tu16': + type_len = 2 + elif field_type == 'tu32': + type_len = 4 + else: + assert field_type == 'tu64' + type_len = 8 assert count == 1, count - _assert_can_read_at_least_n_bytes(fd, type_len) - return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) - elif field_type == 'u64': - type_len = 8 + raw = fd.read(type_len) + if len(raw) > 0 and raw[0] == 0x00: + raise FieldEncodingNotMinimal() + return int.from_bytes(raw, byteorder="big", signed=False) + elif field_type == 'varint': assert count == 1, count - _assert_can_read_at_least_n_bytes(fd, type_len) - return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) - # TODO tu16/tu32/tu64 + val = read_int_from_bigsize(fd) + if val is None: + raise UnexpectedEndOfStream() + return val elif field_type == 'chain_hash': type_len = 32 elif field_type == 'channel_id': @@ -82,7 +141,35 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int, type_len = 4 elif field_type == 'u64': type_len = 8 - # TODO tu16/tu32/tu64 + elif field_type in ('tu16', 'tu32', 'tu64'): + if field_type == 'tu16': + type_len = 2 + elif field_type == 'tu32': + type_len = 4 + else: + assert field_type == 'tu64' + type_len = 8 + assert count == 1, count + if isinstance(value, int): + value = int.to_bytes(value, length=type_len, byteorder="big", signed=False) + if not isinstance(value, (bytes, bytearray)): + raise Exception(f"can only write bytes into fd. got: {value!r}") + while len(value) > 0 and value[0] == 0x00: + value = value[1:] + nbytes_written = fd.write(value) + if nbytes_written != len(value): + raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") + return + elif field_type == 'varint': + assert count == 1, count + if isinstance(value, int): + value = bigsize_from_int(value) + if not isinstance(value, (bytes, bytearray)): + raise Exception(f"can only write bytes into fd. got: {value!r}") + nbytes_written = fd.write(value) + if nbytes_written != len(value): + raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") + return elif field_type == 'chain_hash': type_len = 32 elif field_type == 'channel_id': @@ -109,16 +196,55 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int, raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") +def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]: + if not fd: raise Exception() + tlv_type = _read_field(fd=fd, field_type="varint", count=1) + tlv_len = _read_field(fd=fd, field_type="varint", count=1) + tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len) + return tlv_type, tlv_val + + +def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None: + if not fd: raise Exception() + tlv_len = len(tlv_val) + _write_field(fd=fd, field_type="varint", count=1, value=tlv_type) + _write_field(fd=fd, field_type="varint", count=1, value=tlv_len) + _write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val) + + +def _resolve_field_count(field_count_str: str, *, vars_dict: dict) -> int: + if field_count_str == "": + field_count = 1 + elif field_count_str == "...": + raise NotImplementedError() # TODO... + else: + try: + field_count = int(field_count_str) + except ValueError: + field_count = vars_dict[field_count_str] + if isinstance(field_count, (bytes, bytearray)): + field_count = int.from_bytes(field_count, byteorder="big") + assert isinstance(field_count, int) + return field_count + + class LNSerializer: def __init__(self): + # TODO msg_type could be 'int' everywhere... self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]] self.msg_type_from_name = {} # type: Dict[str, bytes] + + self.in_tlv_stream_get_tlv_record_scheme_from_type = {} # type: Dict[str, Dict[int, List[Sequence[str]]]] + self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]] + self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]] + path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv") with open(path, newline='') as f: csvreader = csv.reader(f) for row in csvreader: #print(f">>> {row!r}") if row[0] == "msgtype": + # msgtype,,[,