diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index a2c4e0196..d46b0f43b 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -44,12 +44,12 @@ def _eval_exp_with_ctx(exp, ctx: dict) -> int: return result return sum(_eval_length_term(x, ctx) for x in exp.split("+")) -def _make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]: +def _make_handler(msg_name: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]: """ Generate a message handler function (taking bytes) - for message type `k` with specification `v` + for message type `msg_name` with specification `v` - Check lib/lightning.json, `k` could be 'init', + Check lib/lightning.json, `msg_name` could be 'init', and `v` could be { type: 16, payload: { 'gflen': ..., ... }, ... } @@ -57,8 +57,8 @@ def _make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]: Returns function taking bytes """ def handler(data: bytes) -> Tuple[str, dict]: - nonlocal k, v - ma = {} + nonlocal msg_name, v + ma = {} # map of field name -> field data; after parsing msg pos = 0 for fieldname in v["payload"]: poslenMap = v["payload"][fieldname] @@ -69,8 +69,9 @@ def _make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]: length = _eval_exp_with_ctx(length, ma) ma[fieldname] = data[pos:pos+length] pos += length - assert pos == len(data), (k, pos, len(data)) - return k, ma + # BOLT-01: "MUST ignore any additional data within a message beyond the length that it expects for that type." + assert pos <= len(data), (msg_name, pos, len(data)) + return msg_name, ma return handler class LNSerializer: @@ -80,12 +81,12 @@ class LNSerializer: with open(path) as f: structured = json.loads(f.read(), object_pairs_hook=OrderedDict) - for k in structured: - v = structured[k] + for msg_name in structured: + v = structured[msg_name] # these message types are skipped since their types collide # (for example with pong, which also uses type=19) # we don't need them yet - if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]: + if msg_name in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]: continue if len(v["payload"]) == 0: continue @@ -95,11 +96,11 @@ class LNSerializer: #print("skipping", k) continue byts = num.to_bytes(2, 'big') - assert byts not in message_types, (byts, message_types[byts].__name__, k) + assert byts not in message_types, (byts, message_types[byts].__name__, msg_name) names = [x.__name__ for x in message_types.values()] - assert k + "_handler" not in names, (k, names) - message_types[byts] = _make_handler(k, v) - message_types[byts].__name__ = k + "_handler" + assert msg_name + "_handler" not in names, (msg_name, names) + message_types[byts] = _make_handler(msg_name, v) + message_types[byts].__name__ = msg_name + "_handler" assert message_types[b"\x00\x10"].__name__ == "init_handler" self.structured = structured