diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index ae6ded83c..0eb0019ac 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -25,6 +25,8 @@ def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int: def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None: + # note: it's faster to read n bytes and then check if we read n, than + # to assert we can read at least n and then read n bytes. nremaining = _num_remaining_bytes_to_read(fd) if nremaining < n: raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left") @@ -50,20 +52,26 @@ def read_bigsize_int(fd: io.BytesIO) -> Optional[int]: if first < 0xfd: return first elif first == 0xfd: - _assert_can_read_at_least_n_bytes(fd, 2) - val = int.from_bytes(fd.read(2), byteorder="big", signed=False) + buf = fd.read(2) + if len(buf) != 2: + raise UnexpectedEndOfStream() + val = int.from_bytes(buf, byteorder="big", signed=False) if not (0xfd <= val < 0x1_0000): raise FieldEncodingNotMinimal() return val elif first == 0xfe: - _assert_can_read_at_least_n_bytes(fd, 4) - val = int.from_bytes(fd.read(4), byteorder="big", signed=False) + buf = fd.read(4) + if len(buf) != 4: + raise UnexpectedEndOfStream() + val = int.from_bytes(buf, byteorder="big", signed=False) if not (0x1_0000 <= val < 0x1_0000_0000): raise FieldEncodingNotMinimal() return val elif first == 0xff: - _assert_can_read_at_least_n_bytes(fd, 8) - val = int.from_bytes(fd.read(8), byteorder="big", signed=False) + buf = fd.read(8) + if len(buf) != 8: + raise UnexpectedEndOfStream() + val = int.from_bytes(buf, byteorder="big", signed=False) if not (0x1_0000_0000 <= val): raise FieldEncodingNotMinimal() return val @@ -96,8 +104,10 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U assert field_type == 'u64' type_len = 8 assert count == 1, count - _assert_can_read_at_least_n_bytes(fd, type_len) - return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) + buf = fd.read(type_len) + if len(buf) != type_len: + raise UnexpectedEndOfStream() + return int.from_bytes(buf, byteorder="big", signed=False) elif field_type in ('tu16', 'tu32', 'tu64'): if field_type == 'tu16': type_len = 2 @@ -129,14 +139,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U type_len = 33 elif field_type == 'short_channel_id': type_len = 8 + if count == "...": total_len = -1 # read all else: if type_len is None: raise UnknownMsgFieldType(f"unknown field type: {field_type!r}") total_len = count * type_len - _assert_can_read_at_least_n_bytes(fd, total_len) - return fd.read(total_len) + + buf = fd.read(total_len) + if total_len >= 0 and len(buf) != total_len: + raise UnexpectedEndOfStream() + return buf # TODO: maybe for "value" we could accept a list with len "count" of appropriate items