
lnmsg: initial TLV implementation

SomberNight 5 years ago
parent commit 6949752263
1 changed file: electrum/lnmsg.py (272 lines changed)

@@ -2,6 +2,7 @@ import os
 import csv
 import io
 from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
+from collections import OrderedDict


 class MalformedMsg(Exception):
@@ -16,12 +17,56 @@ class UnexpectedEndOfStream(MalformedMsg):
     pass


-def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
+class FieldEncodingNotMinimal(MalformedMsg):
+    pass
+
+
+class UnknownMandatoryTLVRecordType(MalformedMsg):
+    pass
+
+
+def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
     cur_pos = fd.tell()
     end_pos = fd.seek(0, io.SEEK_END)
     fd.seek(cur_pos)
-    if end_pos - cur_pos < n:
-        raise UnexpectedEndOfStream(f"cur_pos={cur_pos}. end_pos={end_pos}. wants to read: {n}")
+    return end_pos - cur_pos
+
+
+def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
+    nremaining = _num_remaining_bytes_to_read(fd)
+    if nremaining < n:
+        raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
+
+
+def bigsize_from_int(i: int) -> bytes:
+    assert i >= 0, i
+    if i < 0xfd:
+        return int.to_bytes(i, length=1, byteorder="big", signed=False)
+    elif i < 0x1_0000:
+        return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False)
+    elif i < 0x1_0000_0000:
+        return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False)
+    else:
+        return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)
+
+
+def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
+    try:
+        first = fd.read(1)[0]
+    except IndexError:
+        return None  # end of file
+    if first < 0xfd:
+        return first
+    elif first == 0xfd:
+        _assert_can_read_at_least_n_bytes(fd, 2)
+        return int.from_bytes(fd.read(2), byteorder="big", signed=False)
+    elif first == 0xfe:
+        _assert_can_read_at_least_n_bytes(fd, 4)
+        return int.from_bytes(fd.read(4), byteorder="big", signed=False)
+    elif first == 0xff:
+        _assert_can_read_at_least_n_bytes(fd, 8)
+        return int.from_bytes(fd.read(8), byteorder="big", signed=False)
+    raise Exception()


 def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]:
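
A quick sanity check of the new BigSize helpers (example mine, not part of the commit; it assumes this module is importable as electrum.lnmsg):

    import io
    from electrum.lnmsg import bigsize_from_int, read_int_from_bigsize

    assert bigsize_from_int(252) == b"\xfc"                    # single-byte form
    assert bigsize_from_int(253) == b"\xfd\x00\xfd"            # 0xfd marker + u16
    assert bigsize_from_int(65536) == b"\xfe\x00\x01\x00\x00"  # 0xfe marker + u32
    for i in (0, 252, 253, 65535, 65536, 2**32):
        with io.BytesIO(bigsize_from_int(i)) as fd:
            assert read_int_from_bigsize(fd) == i  # round-trips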
@@ -32,22 +77,36 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes,
     type_len = None
     if field_type == 'byte':
         type_len = 1
-    elif field_type == 'u16':
-        type_len = 2
+    elif field_type in ('u16', 'u32', 'u64'):
+        if field_type == 'u16':
+            type_len = 2
+        elif field_type == 'u32':
+            type_len = 4
+        else:
+            assert field_type == 'u64'
+            type_len = 8
         assert count == 1, count
         _assert_can_read_at_least_n_bytes(fd, type_len)
         return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
-    elif field_type == 'u32':
-        type_len = 4
+    elif field_type in ('tu16', 'tu32', 'tu64'):
+        if field_type == 'tu16':
+            type_len = 2
+        elif field_type == 'tu32':
+            type_len = 4
+        else:
+            assert field_type == 'tu64'
+            type_len = 8
         assert count == 1, count
-        _assert_can_read_at_least_n_bytes(fd, type_len)
-        return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
-    elif field_type == 'u64':
-        type_len = 8
+        raw = fd.read(type_len)
+        if len(raw) > 0 and raw[0] == 0x00:
+            raise FieldEncodingNotMinimal()
+        return int.from_bytes(raw, byteorder="big", signed=False)
+    elif field_type == 'varint':
         assert count == 1, count
-        _assert_can_read_at_least_n_bytes(fd, type_len)
-        return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
-    # TODO tu16/tu32/tu64
+        val = read_int_from_bigsize(fd)
+        if val is None:
+            raise UnexpectedEndOfStream()
+        return val
     elif field_type == 'chain_hash':
         type_len = 32
     elif field_type == 'channel_id':
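
The truncated-int branch rejects non-minimal encodings: a leading zero byte raises the new FieldEncodingNotMinimal. An illustrative check (mine, poking the private _read_field directly):

    import io
    from electrum.lnmsg import FieldEncodingNotMinimal, _read_field

    with io.BytesIO(b"\x01") as fd:  # a tu32 value of 1 is the single byte 0x01
        assert _read_field(fd=fd, field_type="tu32", count=1) == 1
    with io.BytesIO(b"\x00\x01") as fd:  # leading zero: not minimally encoded
        try:
            _read_field(fd=fd, field_type="tu32", count=1)
        except FieldEncodingNotMinimal:
            pass
        else:
            raise AssertionError("expected FieldEncodingNotMinimal")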
@@ -82,7 +141,35 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
         type_len = 4
     elif field_type == 'u64':
         type_len = 8
-    # TODO tu16/tu32/tu64
+    elif field_type in ('tu16', 'tu32', 'tu64'):
+        if field_type == 'tu16':
+            type_len = 2
+        elif field_type == 'tu32':
+            type_len = 4
+        else:
+            assert field_type == 'tu64'
+            type_len = 8
+        assert count == 1, count
+        if isinstance(value, int):
+            value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
+        if not isinstance(value, (bytes, bytearray)):
+            raise Exception(f"can only write bytes into fd. got: {value!r}")
+        while len(value) > 0 and value[0] == 0x00:
+            value = value[1:]
+        nbytes_written = fd.write(value)
+        if nbytes_written != len(value):
+            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
+        return
+    elif field_type == 'varint':
+        assert count == 1, count
+        if isinstance(value, int):
+            value = bigsize_from_int(value)
+        if not isinstance(value, (bytes, bytearray)):
+            raise Exception(f"can only write bytes into fd. got: {value!r}")
+        nbytes_written = fd.write(value)
+        if nbytes_written != len(value):
+            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
+        return
     elif field_type == 'chain_hash':
         type_len = 32
     elif field_type == 'channel_id':
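
On the write side, integers are serialized at full width and leading zero bytes are then stripped, so truncated values round-trip. A sketch (mine, same caveats as above):

    import io
    from electrum.lnmsg import _write_field

    with io.BytesIO() as fd:
        _write_field(fd=fd, field_type="tu64", count=1, value=1)
        assert fd.getvalue() == b"\x01"  # seven leading zero bytes stripped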
@@ -109,16 +196,55 @@
         raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
+
+
+def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
+    if not fd: raise Exception()
+    tlv_type = _read_field(fd=fd, field_type="varint", count=1)
+    tlv_len = _read_field(fd=fd, field_type="varint", count=1)
+    tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len)
+    return tlv_type, tlv_val
+
+
+def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
+    if not fd: raise Exception()
+    tlv_len = len(tlv_val)
+    _write_field(fd=fd, field_type="varint", count=1, value=tlv_type)
+    _write_field(fd=fd, field_type="varint", count=1, value=tlv_len)
+    _write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
+
+
+def _resolve_field_count(field_count_str: str, *, vars_dict: dict) -> int:
+    if field_count_str == "":
+        field_count = 1
+    elif field_count_str == "...":
+        raise NotImplementedError()  # TODO...
+    else:
+        try:
+            field_count = int(field_count_str)
+        except ValueError:
+            field_count = vars_dict[field_count_str]
+        if isinstance(field_count, (bytes, bytearray)):
+            field_count = int.from_bytes(field_count, byteorder="big")
+    assert isinstance(field_count, int)
+    return field_count


 class LNSerializer:
+
     def __init__(self):
         # TODO msg_type could be 'int' everywhere...
         self.msg_scheme_from_type = {}  # type: Dict[bytes, List[Sequence[str]]]
         self.msg_type_from_name = {}  # type: Dict[str, bytes]
+        self.in_tlv_stream_get_tlv_record_scheme_from_type = {}  # type: Dict[str, Dict[int, List[Sequence[str]]]]
+        self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
+        self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]
         path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
         with open(path, newline='') as f:
             csvreader = csv.reader(f)
             for row in csvreader:
                 #print(f">>> {row!r}")
                 if row[0] == "msgtype":
+                    # msgtype,<msgname>,<value>[,<option>]
                     msg_type_name = row[1]
                     msg_type_int = int(row[2])
                     msg_type_bytes = msg_type_int.to_bytes(2, 'big')
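
Each TLV record is framed as a BigSize type, a BigSize length, then the raw value. A round-trip sketch (mine, via the private helpers):

    import io
    from electrum.lnmsg import _read_tlv_record, _write_tlv_record

    with io.BytesIO() as fd:
        _write_tlv_record(fd=fd, tlv_type=1, tlv_val=b"\xab\xcd")
        assert fd.getvalue() == b"\x01\x02\xab\xcd"  # type=1, len=2, value
        fd.seek(0)
        assert _read_tlv_record(fd=fd) == (1, b"\xab\xcd")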
@@ -128,11 +254,106 @@
                     self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
                     self.msg_type_from_name[msg_type_name] = msg_type_bytes
                 elif row[0] == "msgdata":
+                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                     assert msg_type_name == row[1]
                     self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
+                elif row[0] == "tlvtype":
+                    # tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
+                    tlv_stream_name = row[1]
+                    tlv_record_name = row[2]
+                    tlv_record_type = int(row[3])
+                    row[3] = tlv_record_type
+                    if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
+                        self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
+                        self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
+                        self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
+                    assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
+                    assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
+                    assert tlv_record_type not in self.in_tlv_stream_get_record_name_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
+                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
+                    self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
+                    self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
+                    if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
+                        raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
+                                        f"stream={tlv_stream_name}")
+                elif row[0] == "tlvdata":
+                    # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
+                    assert tlv_stream_name == row[1]
+                    assert tlv_record_name == row[2]
+                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
                 else:
                     pass  # TODO

+    def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
+        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
+        for tlv_record_type, scheme in scheme_map.items():  # note: tlv_record_type is monotonically increasing
+            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
+            if tlv_record_name not in kwargs:
+                continue
+            with io.BytesIO() as tlv_record_fd:
+                for row in scheme:
+                    if row[0] == "tlvtype":
+                        pass
+                    elif row[0] == "tlvdata":
+                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
+                        assert tlv_stream_name == row[1]
+                        assert tlv_record_name == row[2]
+                        field_name = row[3]
+                        field_type = row[4]
+                        field_count_str = row[5]
+                        field_count = _resolve_field_count(field_count_str, vars_dict=kwargs[tlv_record_name])
+                        field_value = kwargs[tlv_record_name][field_name]
+                        _write_field(fd=tlv_record_fd,
+                                     field_type=field_type,
+                                     count=field_count,
+                                     value=field_value)
+                    else:
+                        pass  # TODO
+                _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
+
+    def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
+        parsed = {}  # type: Dict[str, Dict[str, Any]]
+        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
+        last_seen_tlv_record_type = -1  # type: int
+        while _num_remaining_bytes_to_read(fd) > 0:
+            tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
+            if not (tlv_record_type > last_seen_tlv_record_type):
+                raise MalformedMsg("TLV records must be monotonically increasing by type")
+            last_seen_tlv_record_type = tlv_record_type
+            try:
+                scheme = scheme_map[tlv_record_type]
+            except KeyError:
+                if tlv_record_type % 2 == 0:
+                    # unknown "even" type: hard fail
+                    raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
+                else:
+                    # unknown "odd" type: skip it
+                    continue
+            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
+            parsed[tlv_record_name] = {}
+            with io.BytesIO(tlv_record_val) as tlv_record_fd:
+                for row in scheme:
+                    #print(f"row: {row!r}")
+                    if row[0] == "tlvtype":
+                        pass
+                    elif row[0] == "tlvdata":
+                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
+                        assert tlv_stream_name == row[1]
+                        assert tlv_record_name == row[2]
+                        field_name = row[3]
+                        field_type = row[4]
+                        field_count_str = row[5]
+                        field_count = _resolve_field_count(field_count_str, vars_dict=parsed[tlv_record_name])
+                        #print(f">> count={field_count}. parsed={parsed}")
+                        parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
+                                                                          field_type=field_type,
+                                                                          count=field_count)
+                    else:
+                        pass  # TODO
+                if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
+                    raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
+        return parsed
+
     def encode_msg(self, msg_type: str, **kwargs) -> bytes:
         """
         Encode kwargs into a Lightning message (bytes)
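
write_tlv_stream/read_tlv_stream drive the record helpers from CSV-defined schemes. A hypothetical round-trip (the "demo" stream and its "amount" record below are made up and registered by hand rather than coming from peer_wire.csv; assumes that CSV is present in the tree so the constructor runs):

    import io
    from electrum.lnmsg import LNSerializer

    ser = LNSerializer()
    ser.in_tlv_stream_get_tlv_record_scheme_from_type["demo"] = {
        2: [("tlvtype", "demo", "amount", 2),
            ("tlvdata", "demo", "amount", "amount_msat", "tu64", "")],
    }
    ser.in_tlv_stream_get_record_type_from_name["demo"] = {"amount": 2}
    ser.in_tlv_stream_get_record_name_from_type["demo"] = {2: "amount"}

    with io.BytesIO() as fd:
        ser.write_tlv_stream(fd=fd, tlv_stream_name="demo",
                             amount={"amount_msat": 1000})
        fd.seek(0)
        parsed = ser.read_tlv_stream(fd=fd, tlv_stream_name="demo")
    assert parsed == {"amount": {"amount_msat": 1000}}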
@@ -147,20 +368,12 @@
             if row[0] == "msgtype":
                 pass
             elif row[0] == "msgdata":
+                # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                 field_name = row[2]
                 field_type = row[3]
                 field_count_str = row[4]
                 #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
-                if field_count_str == "":
-                    field_count = 1
-                else:
-                    try:
-                        field_count = int(field_count_str)
-                    except ValueError:
-                        field_count = kwargs[field_count_str]
-                        if isinstance(field_count, (bytes, bytearray)):
-                            field_count = int.from_bytes(field_count, byteorder="big")
-                assert isinstance(field_count, int)
+                field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
                 try:
                     field_value = kwargs[field_name]
                 except KeyError:
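
encode_msg (above) and decode_msg (below) now share _resolve_field_count, which turns a CSV count cell into an int: empty means 1, a literal parses as an int, and anything else is looked up among the already-known fields. Illustratively (example mine):

    from electrum.lnmsg import _resolve_field_count

    assert _resolve_field_count("", vars_dict={}) == 1
    assert _resolve_field_count("13", vars_dict={}) == 13
    assert _resolve_field_count("num_htlcs", vars_dict={"num_htlcs": 3}) == 3
    assert _resolve_field_count("flen", vars_dict={"flen": b"\x02"}) == 2  # bytes coerced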
@@ -205,14 +418,7 @@
                 field_name = row[2]
                 field_type = row[3]
                 field_count_str = row[4]
-                if field_count_str == "":
-                    field_count = 1
-                else:
-                    try:
-                        field_count = int(field_count_str)
-                    except ValueError:
-                        field_count = parsed[field_count_str]
-                assert isinstance(field_count, int)
+                field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
                 #print(f">> count={field_count}. parsed={parsed}")
                 try:
                     parsed[field_name] = _read_field(fd=fd,
