Browse Source

lnmsg: small speed-up: read first, check length after

this saves around ~13% wall clock time in ChannelDB.load_data
hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
eecdd056b3
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 34
      electrum/lnmsg.py

34
electrum/lnmsg.py

@ -25,6 +25,8 @@ def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
# note: it's faster to read n bytes and then check if we read n, than
# to assert we can read at least n and then read n bytes.
nremaining = _num_remaining_bytes_to_read(fd)
if nremaining < n:
raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
@ -50,20 +52,26 @@ def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
if first < 0xfd:
return first
elif first == 0xfd:
_assert_can_read_at_least_n_bytes(fd, 2)
val = int.from_bytes(fd.read(2), byteorder="big", signed=False)
buf = fd.read(2)
if len(buf) != 2:
raise UnexpectedEndOfStream()
val = int.from_bytes(buf, byteorder="big", signed=False)
if not (0xfd <= val < 0x1_0000):
raise FieldEncodingNotMinimal()
return val
elif first == 0xfe:
_assert_can_read_at_least_n_bytes(fd, 4)
val = int.from_bytes(fd.read(4), byteorder="big", signed=False)
buf = fd.read(4)
if len(buf) != 4:
raise UnexpectedEndOfStream()
val = int.from_bytes(buf, byteorder="big", signed=False)
if not (0x1_0000 <= val < 0x1_0000_0000):
raise FieldEncodingNotMinimal()
return val
elif first == 0xff:
_assert_can_read_at_least_n_bytes(fd, 8)
val = int.from_bytes(fd.read(8), byteorder="big", signed=False)
buf = fd.read(8)
if len(buf) != 8:
raise UnexpectedEndOfStream()
val = int.from_bytes(buf, byteorder="big", signed=False)
if not (0x1_0000_0000 <= val):
raise FieldEncodingNotMinimal()
return val
@ -96,8 +104,10 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
assert field_type == 'u64'
type_len = 8
assert count == 1, count
_assert_can_read_at_least_n_bytes(fd, type_len)
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
buf = fd.read(type_len)
if len(buf) != type_len:
raise UnexpectedEndOfStream()
return int.from_bytes(buf, byteorder="big", signed=False)
elif field_type in ('tu16', 'tu32', 'tu64'):
if field_type == 'tu16':
type_len = 2
@ -129,14 +139,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
type_len = 33
elif field_type == 'short_channel_id':
type_len = 8
if count == "...":
total_len = -1 # read all
else:
if type_len is None:
raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
total_len = count * type_len
_assert_can_read_at_least_n_bytes(fd, total_len)
return fd.read(total_len)
buf = fd.read(total_len)
if total_len >= 0 and len(buf) != total_len:
raise UnexpectedEndOfStream()
return buf
# TODO: maybe for "value" we could accept a list with len "count" of appropriate items

Loading…
Cancel
Save