From 316edb39a434cec6c56d8d04a10bff649cf708e9 Mon Sep 17 00:00:00 2001 From: lisa neigut Date: Tue, 23 Jul 2019 16:36:55 -0500 Subject: [PATCH] bolt-gen: for wire messages, print out optional fields (if present) optional fields should be printed, if they exist. so let's print them! --- tools/gen/print_impl_template | 5 +++++ tools/generate-bolts.py | 37 +++++++++++++++++++++++------------ tools/test/test_cases | 1 + 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/tools/gen/print_impl_template b/tools/gen/print_impl_template index cd8b0c8e9..d4b50e904 100644 --- a/tools/gen/print_impl_template +++ b/tools/gen/print_impl_template @@ -33,6 +33,11 @@ void print${options.enum_name}_message(const u8 *msg) ## definition for printing field sets <%def name="print_fieldset(fields, nested, cursor, plen)"> % for f in fields: + % if f.is_extension(): + if (plen <= 0) + return; + printf("(${','.join(f.extension_names)}):"); + % endif % if f.len_field_of: ${f.type_obj.type_name()} ${f.name} = fromwire_${f.type_obj.name}(${cursor}, ${plen});${truncate_check(nested)} <% continue %> \ % endif diff --git a/tools/generate-bolts.py b/tools/generate-bolts.py index 6ed9c5a1c..f48a61991 100755 --- a/tools/generate-bolts.py +++ b/tools/generate-bolts.py @@ -38,7 +38,7 @@ def next_line(args, lines): # Class definitions, to keep things classy class Field(object): - def __init__(self, name, type_obj, extension=False, + def __init__(self, name, type_obj, extensions=[], field_comments=[], optional=False): self.name = name self.type_obj = type_obj @@ -46,7 +46,7 @@ class Field(object): self.len_field_of = None self.len_field = None - self.is_extension = extension + self.extension_names = extensions self.is_optional = optional self.field_comments = field_comments @@ -79,7 +79,7 @@ class Field(object): return self.is_optional def is_extension(self): - return self.is_extension + return bool(self.extension_names) def size(self): if self.count: @@ -119,15 +119,11 @@ class Field(object): class FieldSet(object): def __init__(self): self.fields = OrderedDict() - self.extension_fields = False self.len_fields = {} def add_data_field(self, field_name, type_obj, count=1, - is_extension=[], comments=[], optional=False): - if is_extension: - self.extension_fields = True - - field = Field(field_name, type_obj, extension=bool(is_extension), + extensions=[], comments=[], optional=False): + field = Field(field_name, type_obj, extensions=extensions, field_comments=comments, optional=optional) if bool(count): try: @@ -248,9 +244,10 @@ class Type(FieldSet): return name, False def add_data_field(self, field_name, type_obj, count=1, - is_extension=[], comments=[], optional=False): + extensions=[], comments=[], optional=False): FieldSet.add_data_field(self, field_name, type_obj, count, - is_extension, comments=comments, optional=optional) + extensions=extensions, + comments=comments, optional=optional) if type_obj.name not in self.depends_on: self.depends_on[type_obj.name] = type_obj @@ -459,7 +456,13 @@ class Master(object): subtypes = self.get_ordered_subtypes() stuff['structs'] = subtypes + self.tlv_messages() stuff['tlvs'] = self.tlvs - stuff['messages'] = list(self.messages.values()) + list(self.extension_msgs.values()) + + # We leave out extension messages in the printing pages. Any extension + # fields will get printed under the 'original' message, if present + if options.print_wire: + stuff['messages'] = list(self.messages.values()) + else: + stuff['messages'] = list(self.messages.values()) + list(self.extension_msgs.values()) stuff['subtypes'] = subtypes print(template.render(**stuff), file=output) @@ -550,6 +553,9 @@ def main(options, args=None, output=sys.stdout, lines=None): # we'll refer to 'optional' message fields as 'extensions') # if bool(tokens[5:]): # is an extension field + if optional: + raise ValueError("Extension fields cannot be optional. {}:{}" + .format(ln, line)) extension_name = "{}_{}".format(tokens[1], tokens[5]) orig_msg = msg msg = master.find_message(extension_name) @@ -558,6 +564,13 @@ def main(options, args=None, output=sys.stdout, lines=None): msg.enumname = msg.name msg.name = extension_name master.add_extension_msg(msg.name, msg) + # If this is a print_wire page, add the extension fields to the + # original message, so we can print them if present. + if options.print_wire: + orig_msg.add_data_field(tokens[2], type_obj, count=count, + extensions=tokens[5:], + comments=list(comment_set), + optional=optional) if collapse: count = 1 diff --git a/tools/test/test_cases b/tools/test/test_cases index 212900e34..dae68bb14 100644 --- a/tools/test/test_cases +++ b/tools/test/test_cases @@ -43,6 +43,7 @@ msgdata,test_msg,test_sbt_varlen_varsize,subtype_varlen_varsize, msgdata,test_msg,test_sbt_arrays,subtype_arrays, # test extension fields msgdata,test_msg,extension_1,test_features,,option_short_id +msgdata,test_msg,extension_2,test_short_id,,option_one,option_two msgtype,test_tlv1,2 msgdata,test_tlv1,test_struct,test_short_id,