From f353e6d55cbad2ef41da760a2e48b869b0fb4cd3 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Fri, 13 Mar 2020 22:45:37 +0100 Subject: [PATCH] lnmsg: encode/decode TLVs as part of messages --- electrum/lnmsg.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index 824339857..82df84b7c 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -337,7 +337,7 @@ class LNSerializer: count=field_count, value=field_value) else: - pass # TODO + raise Exception(f"unexpected row in scheme: {row!r}") _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue()) def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]: @@ -380,7 +380,7 @@ class LNSerializer: field_type=field_type, count=field_count) else: - pass # TODO + raise Exception(f"unexpected row in scheme: {row!r}") if _num_remaining_bytes_to_read(tlv_record_fd) > 0: raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage") return parsed @@ -405,6 +405,11 @@ class LNSerializer: field_count_str = row[4] #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}") field_count = _resolve_field_count(field_count_str, vars_dict=kwargs) + if field_name == "tlvs": + tlv_stream_name = field_type + if tlv_stream_name in kwargs: + self.write_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name, **(kwargs[tlv_stream_name])) + continue try: field_value = kwargs[field_name] except KeyError: @@ -413,16 +418,13 @@ class LNSerializer: else: field_value = 0 # default mandatory fields to zero #print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}") - try: - _write_field(fd=fd, - field_type=field_type, - count=field_count, - value=field_value) - #print(f">>> encode_msg. so far: {fd.getvalue().hex()}") - except UnknownMsgFieldType as e: - pass # TODO + _write_field(fd=fd, + field_type=field_type, + count=field_count, + value=field_value) + #print(f">>> encode_msg. so far: {fd.getvalue().hex()}") else: - pass # TODO + raise Exception(f"unexpected row in scheme: {row!r}") return fd.getvalue() def decode_msg(self, data: bytes) -> Tuple[str, dict]: @@ -450,20 +452,23 @@ class LNSerializer: field_type = row[3] field_count_str = row[4] field_count = _resolve_field_count(field_count_str, vars_dict=parsed) + if field_name == "tlvs": + tlv_stream_name = field_type + d = self.read_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name) + parsed[tlv_stream_name] = d + continue #print(f">> count={field_count}. parsed={parsed}") try: parsed[field_name] = _read_field(fd=fd, field_type=field_type, count=field_count) - except UnknownMsgFieldType as e: - pass # TODO except UnexpectedEndOfStream as e: if len(row) > 5: break # optional feature field not present else: raise else: - pass # TODO + raise Exception(f"unexpected row in scheme: {row!r}") return msg_type_name, parsed