Browse Source

lnmsg: initial TLV implementation

hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
6949752263
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 272
      electrum/lnmsg.py

272
electrum/lnmsg.py

@ -2,6 +2,7 @@ import os
import csv import csv
import io import io
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
from collections import OrderedDict
class MalformedMsg(Exception): class MalformedMsg(Exception):
@ -16,12 +17,56 @@ class UnexpectedEndOfStream(MalformedMsg):
pass pass
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None: class FieldEncodingNotMinimal(MalformedMsg):
pass
class UnknownMandatoryTLVRecordType(MalformedMsg):
pass
def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
cur_pos = fd.tell() cur_pos = fd.tell()
end_pos = fd.seek(0, io.SEEK_END) end_pos = fd.seek(0, io.SEEK_END)
fd.seek(cur_pos) fd.seek(cur_pos)
if end_pos - cur_pos < n: return end_pos - cur_pos
raise UnexpectedEndOfStream(f"cur_pos={cur_pos}. end_pos={end_pos}. wants to read: {n}")
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
nremaining = _num_remaining_bytes_to_read(fd)
if nremaining < n:
raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
def bigsize_from_int(i: int) -> bytes:
assert i >= 0, i
if i < 0xfd:
return int.to_bytes(i, length=1, byteorder="big", signed=False)
elif i < 0x1_0000:
return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False)
elif i < 0x1_0000_0000:
return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False)
else:
return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)
def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
try:
first = fd.read(1)[0]
except IndexError:
return None # end of file
if first < 0xfd:
return first
elif first == 0xfd:
_assert_can_read_at_least_n_bytes(fd, 2)
return int.from_bytes(fd.read(2), byteorder="big", signed=False)
elif first == 0xfe:
_assert_can_read_at_least_n_bytes(fd, 4)
return int.from_bytes(fd.read(4), byteorder="big", signed=False)
elif first == 0xff:
_assert_can_read_at_least_n_bytes(fd, 8)
return int.from_bytes(fd.read(8), byteorder="big", signed=False)
raise Exception()
def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]: def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]:
@ -32,22 +77,36 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes,
type_len = None type_len = None
if field_type == 'byte': if field_type == 'byte':
type_len = 1 type_len = 1
elif field_type == 'u16': elif field_type in ('u16', 'u32', 'u64'):
type_len = 2 if field_type == 'u16':
type_len = 2
elif field_type == 'u32':
type_len = 4
else:
assert field_type == 'u64'
type_len = 8
assert count == 1, count assert count == 1, count
_assert_can_read_at_least_n_bytes(fd, type_len) _assert_can_read_at_least_n_bytes(fd, type_len)
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
elif field_type == 'u32': elif field_type in ('tu16', 'tu32', 'tu64'):
type_len = 4 if field_type == 'tu16':
type_len = 2
elif field_type == 'tu32':
type_len = 4
else:
assert field_type == 'tu64'
type_len = 8
assert count == 1, count assert count == 1, count
_assert_can_read_at_least_n_bytes(fd, type_len) raw = fd.read(type_len)
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) if len(raw) > 0 and raw[0] == 0x00:
elif field_type == 'u64': raise FieldEncodingNotMinimal()
type_len = 8 return int.from_bytes(raw, byteorder="big", signed=False)
elif field_type == 'varint':
assert count == 1, count assert count == 1, count
_assert_can_read_at_least_n_bytes(fd, type_len) val = read_int_from_bigsize(fd)
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) if val is None:
# TODO tu16/tu32/tu64 raise UnexpectedEndOfStream()
return val
elif field_type == 'chain_hash': elif field_type == 'chain_hash':
type_len = 32 type_len = 32
elif field_type == 'channel_id': elif field_type == 'channel_id':
@ -82,7 +141,35 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
type_len = 4 type_len = 4
elif field_type == 'u64': elif field_type == 'u64':
type_len = 8 type_len = 8
# TODO tu16/tu32/tu64 elif field_type in ('tu16', 'tu32', 'tu64'):
if field_type == 'tu16':
type_len = 2
elif field_type == 'tu32':
type_len = 4
else:
assert field_type == 'tu64'
type_len = 8
assert count == 1, count
if isinstance(value, int):
value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
if not isinstance(value, (bytes, bytearray)):
raise Exception(f"can only write bytes into fd. got: {value!r}")
while len(value) > 0 and value[0] == 0x00:
value = value[1:]
nbytes_written = fd.write(value)
if nbytes_written != len(value):
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
return
elif field_type == 'varint':
assert count == 1, count
if isinstance(value, int):
value = bigsize_from_int(value)
if not isinstance(value, (bytes, bytearray)):
raise Exception(f"can only write bytes into fd. got: {value!r}")
nbytes_written = fd.write(value)
if nbytes_written != len(value):
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
return
elif field_type == 'chain_hash': elif field_type == 'chain_hash':
type_len = 32 type_len = 32
elif field_type == 'channel_id': elif field_type == 'channel_id':
@ -109,16 +196,55 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
if not fd: raise Exception()
tlv_type = _read_field(fd=fd, field_type="varint", count=1)
tlv_len = _read_field(fd=fd, field_type="varint", count=1)
tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len)
return tlv_type, tlv_val
def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
if not fd: raise Exception()
tlv_len = len(tlv_val)
_write_field(fd=fd, field_type="varint", count=1, value=tlv_type)
_write_field(fd=fd, field_type="varint", count=1, value=tlv_len)
_write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
def _resolve_field_count(field_count_str: str, *, vars_dict: dict) -> int:
if field_count_str == "":
field_count = 1
elif field_count_str == "...":
raise NotImplementedError() # TODO...
else:
try:
field_count = int(field_count_str)
except ValueError:
field_count = vars_dict[field_count_str]
if isinstance(field_count, (bytes, bytearray)):
field_count = int.from_bytes(field_count, byteorder="big")
assert isinstance(field_count, int)
return field_count
class LNSerializer: class LNSerializer:
def __init__(self): def __init__(self):
# TODO msg_type could be 'int' everywhere...
self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]] self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]]
self.msg_type_from_name = {} # type: Dict[str, bytes] self.msg_type_from_name = {} # type: Dict[str, bytes]
self.in_tlv_stream_get_tlv_record_scheme_from_type = {} # type: Dict[str, Dict[int, List[Sequence[str]]]]
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv") path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
with open(path, newline='') as f: with open(path, newline='') as f:
csvreader = csv.reader(f) csvreader = csv.reader(f)
for row in csvreader: for row in csvreader:
#print(f">>> {row!r}") #print(f">>> {row!r}")
if row[0] == "msgtype": if row[0] == "msgtype":
# msgtype,<msgname>,<value>[,<option>]
msg_type_name = row[1] msg_type_name = row[1]
msg_type_int = int(row[2]) msg_type_int = int(row[2])
msg_type_bytes = msg_type_int.to_bytes(2, 'big') msg_type_bytes = msg_type_int.to_bytes(2, 'big')
@ -128,11 +254,106 @@ class LNSerializer:
self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)] self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
self.msg_type_from_name[msg_type_name] = msg_type_bytes self.msg_type_from_name[msg_type_name] = msg_type_bytes
elif row[0] == "msgdata": elif row[0] == "msgdata":
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
assert msg_type_name == row[1] assert msg_type_name == row[1]
self.msg_scheme_from_type[msg_type_bytes].append(tuple(row)) self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
elif row[0] == "tlvtype":
# tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
tlv_stream_name = row[1]
tlv_record_name = row[2]
tlv_record_type = int(row[3])
row[3] = tlv_record_type
if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
f"stream={tlv_stream_name}")
elif row[0] == "tlvdata":
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
assert tlv_stream_name == row[1]
assert tlv_record_name == row[2]
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
else: else:
pass # TODO pass # TODO
def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
if tlv_record_name not in kwargs:
continue
with io.BytesIO() as tlv_record_fd:
for row in scheme:
if row[0] == "tlvtype":
pass
elif row[0] == "tlvdata":
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
assert tlv_stream_name == row[1]
assert tlv_record_name == row[2]
field_name = row[3]
field_type = row[4]
field_count_str = row[5]
field_count = _resolve_field_count(field_count_str, vars_dict=kwargs[tlv_record_name])
field_value = kwargs[tlv_record_name][field_name]
_write_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
else:
pass # TODO
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
parsed = {} # type: Dict[str, Dict[str, Any]]
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
last_seen_tlv_record_type = -1 # type: int
while _num_remaining_bytes_to_read(fd) > 0:
tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
if not (tlv_record_type > last_seen_tlv_record_type):
raise MalformedMsg("TLV records must be monotonically increasing by type")
last_seen_tlv_record_type = tlv_record_type
try:
scheme = scheme_map[tlv_record_type]
except KeyError:
if tlv_record_type % 2 == 0:
# unknown "even" type: hard fail
raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
else:
# unknown "odd" type: skip it
continue
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
parsed[tlv_record_name] = {}
with io.BytesIO(tlv_record_val) as tlv_record_fd:
for row in scheme:
#print(f"row: {row!r}")
if row[0] == "tlvtype":
pass
elif row[0] == "tlvdata":
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
assert tlv_stream_name == row[1]
assert tlv_record_name == row[2]
field_name = row[3]
field_type = row[4]
field_count_str = row[5]
field_count = _resolve_field_count(field_count_str, vars_dict=parsed[tlv_record_name])
#print(f">> count={field_count}. parsed={parsed}")
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count)
else:
pass # TODO
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
return parsed
def encode_msg(self, msg_type: str, **kwargs) -> bytes: def encode_msg(self, msg_type: str, **kwargs) -> bytes:
""" """
Encode kwargs into a Lightning message (bytes) Encode kwargs into a Lightning message (bytes)
@ -147,20 +368,12 @@ class LNSerializer:
if row[0] == "msgtype": if row[0] == "msgtype":
pass pass
elif row[0] == "msgdata": elif row[0] == "msgdata":
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
field_name = row[2] field_name = row[2]
field_type = row[3] field_type = row[3]
field_count_str = row[4] field_count_str = row[4]
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}") #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
if field_count_str == "": field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
field_count = 1
else:
try:
field_count = int(field_count_str)
except ValueError:
field_count = kwargs[field_count_str]
if isinstance(field_count, (bytes, bytearray)):
field_count = int.from_bytes(field_count, byteorder="big")
assert isinstance(field_count, int)
try: try:
field_value = kwargs[field_name] field_value = kwargs[field_name]
except KeyError: except KeyError:
@ -205,14 +418,7 @@ class LNSerializer:
field_name = row[2] field_name = row[2]
field_type = row[3] field_type = row[3]
field_count_str = row[4] field_count_str = row[4]
if field_count_str == "": field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
field_count = 1
else:
try:
field_count = int(field_count_str)
except ValueError:
field_count = parsed[field_count_str]
assert isinstance(field_count, int)
#print(f">> count={field_count}. parsed={parsed}") #print(f">> count={field_count}. parsed={parsed}")
try: try:
parsed[field_name] = _read_field(fd=fd, parsed[field_name] = _read_field(fd=fd,

Loading…
Cancel
Save