Browse Source

lnmsg: handle "..." as field count

hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
542e33fd86
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 71
      electrum/lnmsg.py

71
electrum/lnmsg.py

@ -69,16 +69,25 @@ def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
raise Exception()
def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]:
# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
# if field_type is a numeric, we could return a list of ints?
def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]:
if not fd: raise Exception()
assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int"
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
pass
else:
raise Exception(f"unexpected field count: {count!r}")
if count == 0:
return b""
type_len = None
if field_type == 'byte':
type_len = 1
elif field_type in ('u16', 'u32', 'u64'):
if field_type == 'u16':
elif field_type in ('u8', 'u16', 'u32', 'u64'):
if field_type == 'u8':
type_len = 1
elif field_type == 'u16':
type_len = 2
elif field_type == 'u32':
type_len = 4
@ -119,22 +128,33 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes,
type_len = 33
elif field_type == 'short_channel_id':
type_len = 8
if type_len is None:
raise UnknownMsgFieldType(f"unexpected field type: {field_type!r}")
total_len = count * type_len
_assert_can_read_at_least_n_bytes(fd, total_len)
if count == "...":
total_len = -1 # read all
else:
if type_len is None:
raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
total_len = count * type_len
_assert_can_read_at_least_n_bytes(fd, total_len)
return fd.read(total_len)
def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
# TODO: maybe for "value" we could accept a list with len "count" of appropriate items
def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
value: Union[bytes, int]) -> None:
if not fd: raise Exception()
assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int"
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
pass
else:
raise Exception(f"unexpected field count: {count!r}")
if count == 0:
return
type_len = None
if field_type == 'byte':
type_len = 1
elif field_type == 'u8':
type_len = 1
elif field_type == 'u16':
type_len = 2
elif field_type == 'u32':
@ -182,14 +202,16 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
type_len = 33
elif field_type == 'short_channel_id':
type_len = 8
if type_len is None:
raise UnknownMsgFieldType(f"unexpected fundamental type: {field_type!r}")
total_len = count * type_len
if isinstance(value, int) and (count == 1 or field_type == 'byte'):
value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
total_len = -1
if count != "...":
if type_len is None:
raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
total_len = count * type_len
if isinstance(value, int) and (count == 1 or field_type == 'byte'):
value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
if not isinstance(value, (bytes, bytearray)):
raise Exception(f"can only write bytes into fd. got: {value!r}")
if total_len != len(value):
if count != "..." and total_len != len(value):
raise Exception(f"unexpected field size. expected: {total_len}, got {len(value)}")
nbytes_written = fd.write(value)
if nbytes_written != len(value):
@ -212,11 +234,16 @@ def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
_write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
def _resolve_field_count(field_count_str: str, *, vars_dict: dict) -> int:
def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]:
"""Returns an evaluated field count, typically an int.
If allow_any is True, the return value can be a str with value=="...".
"""
if field_count_str == "":
field_count = 1
elif field_count_str == "...":
raise NotImplementedError() # TODO...
if not allow_any:
raise Exception("field count is '...' but allow_any is False")
return field_count_str
else:
try:
field_count = int(field_count_str)
@ -301,7 +328,9 @@ class LNSerializer:
field_name = row[3]
field_type = row[4]
field_count_str = row[5]
field_count = _resolve_field_count(field_count_str, vars_dict=kwargs[tlv_record_name])
field_count = _resolve_field_count(field_count_str,
vars_dict=kwargs[tlv_record_name],
allow_any=True)
field_value = kwargs[tlv_record_name][field_name]
_write_field(fd=tlv_record_fd,
field_type=field_type,
@ -343,7 +372,9 @@ class LNSerializer:
field_name = row[3]
field_type = row[4]
field_count_str = row[5]
field_count = _resolve_field_count(field_count_str, vars_dict=parsed[tlv_record_name])
field_count = _resolve_field_count(field_count_str,
vars_dict=parsed[tlv_record_name],
allow_any=True)
#print(f">> count={field_count}. parsed={parsed}")
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
field_type=field_type,

Loading…
Cancel
Save