diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 6cf264ab6..4b6926b42 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -219,13 +219,14 @@ class Type(FieldSet): ] # Some BOLT types are re-typed based on their field name - # ('fieldname partial', 'original type'): ('true type', 'collapse array?') + # ('fieldname partial', 'original type', 'outer type'): ('true type', 'collapse array?') name_field_map = { ('txid', 'sha256'): ('bitcoin_txid', False), ('amt', 'u64'): ('amount_msat', False), ('msat', 'u64'): ('amount_msat', False), ('satoshis', 'u64'): ('amount_sat', False), - ('node_id', 'pubkey'): ('node_id', False), + ('node_id', 'pubkey', 'channel_announcement'): ('node_id', False), + ('node_id', 'pubkey', 'node_announcement'): ('node_id', False), ('temporary_channel_id', 'u8'): ('channel_id', True), ('secret', 'u8'): ('secret', True), ('preimage', 'u8'): ('preimage', True), @@ -242,7 +243,7 @@ class Type(FieldSet): } @staticmethod - def true_type(type_name, field_name=None): + def true_type(type_name, field_name=None, outer_name=None): """ Returns 'true' type of a given type and a flag if we've remapped a variable size/array type to a single struct (an example of this is 'temporary_channel_id' which is specified @@ -252,9 +253,10 @@ class Type(FieldSet): type_name = Type.name_remap[type_name] if field_name: - for (partial, t), true_type in Type.name_field_map.items(): - if partial in field_name and t == type_name: - return true_type + for t, true_type in Type.name_field_map.items(): + if t[0] in field_name and t[1] == type_name: + if len(t) == 2 or outer_name == t[2]: + return true_type return (type_name, False) def __init__(self, name): @@ -407,13 +409,14 @@ class Master(object): def add_extension_msg(self, name, msg): self.extension_msgs[name] = msg - def add_type(self, type_name, field_name=None): + def add_type(self, type_name, field_name=None, outer_name=None): optional = False if type_name.startswith('?'): type_name = type_name[1:] optional = True # Check for special type name re-mapping - type_name, collapse_original = Type.true_type(type_name, field_name) + type_name, collapse_original = Type.true_type(type_name, field_name, + outer_name) if type_name not in self.types: self.types[type_name] = Type(type_name) @@ -522,7 +525,7 @@ def main(options, args=None, output=sys.stdout, lines=None): if not subtype: raise ValueError('Unknown subtype {} for data.\nat {}:{}' .format(tokens[1], ln, line)) - type_obj, collapse, optional = master.add_type(tokens[3], tokens[2]) + type_obj, collapse, optional = master.add_type(tokens[3], tokens[2], tokens[1]) if optional: raise ValueError('Subtypes cannot have optional fields {}.{}\n at {}:{}' .format(subtype.name, tokens[2], ln, line)) @@ -540,7 +543,7 @@ def main(options, args=None, output=sys.stdout, lines=None): comment_set = [] elif token_type == 'tlvdata': - type_obj, collapse, optional = master.add_type(tokens[4], tokens[3]) + type_obj, collapse, optional = master.add_type(tokens[4], tokens[3], tokens[1]) if optional: raise ValueError('TLV messages cannot have optional fields {}.{}\n at {}:{}' .format(tokens[2], tokens[3], ln, line)) @@ -568,7 +571,7 @@ def main(options, args=None, output=sys.stdout, lines=None): msg = master.find_message(tokens[1]) if not msg: raise ValueError('Unknown message type {}. {}:{}'.format(tokens[1], ln, line)) - type_obj, collapse, optional = master.add_type(tokens[3], tokens[2]) + type_obj, collapse, optional = master.add_type(tokens[3], tokens[2], tokens[1]) # if this is an 'extension' field*, we want to add a new 'message' type # in the future, extensions will be handled as TLV's