|
|
@ -25,6 +25,8 @@ def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int: |
|
|
|
|
|
|
|
|
|
|
|
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None: |
|
|
|
# note: it's faster to read n bytes and then check if we read n, than |
|
|
|
# to assert we can read at least n and then read n bytes. |
|
|
|
nremaining = _num_remaining_bytes_to_read(fd) |
|
|
|
if nremaining < n: |
|
|
|
raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left") |
|
|
@ -50,20 +52,26 @@ def read_bigsize_int(fd: io.BytesIO) -> Optional[int]: |
|
|
|
if first < 0xfd: |
|
|
|
return first |
|
|
|
elif first == 0xfd: |
|
|
|
_assert_can_read_at_least_n_bytes(fd, 2) |
|
|
|
val = int.from_bytes(fd.read(2), byteorder="big", signed=False) |
|
|
|
buf = fd.read(2) |
|
|
|
if len(buf) != 2: |
|
|
|
raise UnexpectedEndOfStream() |
|
|
|
val = int.from_bytes(buf, byteorder="big", signed=False) |
|
|
|
if not (0xfd <= val < 0x1_0000): |
|
|
|
raise FieldEncodingNotMinimal() |
|
|
|
return val |
|
|
|
elif first == 0xfe: |
|
|
|
_assert_can_read_at_least_n_bytes(fd, 4) |
|
|
|
val = int.from_bytes(fd.read(4), byteorder="big", signed=False) |
|
|
|
buf = fd.read(4) |
|
|
|
if len(buf) != 4: |
|
|
|
raise UnexpectedEndOfStream() |
|
|
|
val = int.from_bytes(buf, byteorder="big", signed=False) |
|
|
|
if not (0x1_0000 <= val < 0x1_0000_0000): |
|
|
|
raise FieldEncodingNotMinimal() |
|
|
|
return val |
|
|
|
elif first == 0xff: |
|
|
|
_assert_can_read_at_least_n_bytes(fd, 8) |
|
|
|
val = int.from_bytes(fd.read(8), byteorder="big", signed=False) |
|
|
|
buf = fd.read(8) |
|
|
|
if len(buf) != 8: |
|
|
|
raise UnexpectedEndOfStream() |
|
|
|
val = int.from_bytes(buf, byteorder="big", signed=False) |
|
|
|
if not (0x1_0000_0000 <= val): |
|
|
|
raise FieldEncodingNotMinimal() |
|
|
|
return val |
|
|
@ -96,8 +104,10 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U |
|
|
|
assert field_type == 'u64' |
|
|
|
type_len = 8 |
|
|
|
assert count == 1, count |
|
|
|
_assert_can_read_at_least_n_bytes(fd, type_len) |
|
|
|
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) |
|
|
|
buf = fd.read(type_len) |
|
|
|
if len(buf) != type_len: |
|
|
|
raise UnexpectedEndOfStream() |
|
|
|
return int.from_bytes(buf, byteorder="big", signed=False) |
|
|
|
elif field_type in ('tu16', 'tu32', 'tu64'): |
|
|
|
if field_type == 'tu16': |
|
|
|
type_len = 2 |
|
|
@ -129,14 +139,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U |
|
|
|
type_len = 33 |
|
|
|
elif field_type == 'short_channel_id': |
|
|
|
type_len = 8 |
|
|
|
|
|
|
|
if count == "...": |
|
|
|
total_len = -1 # read all |
|
|
|
else: |
|
|
|
if type_len is None: |
|
|
|
raise UnknownMsgFieldType(f"unknown field type: {field_type!r}") |
|
|
|
total_len = count * type_len |
|
|
|
_assert_can_read_at_least_n_bytes(fd, total_len) |
|
|
|
return fd.read(total_len) |
|
|
|
|
|
|
|
buf = fd.read(total_len) |
|
|
|
if total_len >= 0 and len(buf) != total_len: |
|
|
|
raise UnexpectedEndOfStream() |
|
|
|
return buf |
|
|
|
|
|
|
|
|
|
|
|
# TODO: maybe for "value" we could accept a list with len "count" of appropriate items |
|
|
|