Browse Source

lnmsg: encode/decode TLVs as part of messages

hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
f353e6d55c
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 33
      electrum/lnmsg.py

33
electrum/lnmsg.py

@ -337,7 +337,7 @@ class LNSerializer:
count=field_count,
value=field_value)
else:
pass # TODO
raise Exception(f"unexpected row in scheme: {row!r}")
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
@ -380,7 +380,7 @@ class LNSerializer:
field_type=field_type,
count=field_count)
else:
pass # TODO
raise Exception(f"unexpected row in scheme: {row!r}")
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
return parsed
@ -405,6 +405,11 @@ class LNSerializer:
field_count_str = row[4]
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
if field_name == "tlvs":
tlv_stream_name = field_type
if tlv_stream_name in kwargs:
self.write_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name, **(kwargs[tlv_stream_name]))
continue
try:
field_value = kwargs[field_name]
except KeyError:
@ -413,16 +418,13 @@ class LNSerializer:
else:
field_value = 0 # default mandatory fields to zero
#print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
try:
_write_field(fd=fd,
field_type=field_type,
count=field_count,
value=field_value)
#print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
except UnknownMsgFieldType as e:
pass # TODO
_write_field(fd=fd,
field_type=field_type,
count=field_count,
value=field_value)
#print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
else:
pass # TODO
raise Exception(f"unexpected row in scheme: {row!r}")
return fd.getvalue()
def decode_msg(self, data: bytes) -> Tuple[str, dict]:
@ -450,20 +452,23 @@ class LNSerializer:
field_type = row[3]
field_count_str = row[4]
field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
if field_name == "tlvs":
tlv_stream_name = field_type
d = self.read_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name)
parsed[tlv_stream_name] = d
continue
#print(f">> count={field_count}. parsed={parsed}")
try:
parsed[field_name] = _read_field(fd=fd,
field_type=field_type,
count=field_count)
except UnknownMsgFieldType as e:
pass # TODO
except UnexpectedEndOfStream as e:
if len(row) > 5:
break # optional feature field not present
else:
raise
else:
pass # TODO
raise Exception(f"unexpected row in scheme: {row!r}")
return msg_type_name, parsed

Loading…
Cancel
Save