diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 394935293..ee0a5ce95 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -126,6 +126,16 @@ class Field(object): self.lenvar = None self.num_elems = 1 self.optional = False + self.is_tlv = False + + # field name appended with '+' means this field contains a tlv + if name.endswith('+'): + self.is_tlv = True + self.name = name[:-1] + if self.name not in tlv_fields: + # FIXME: use the rest of this + tlv_includes, tlv_messages, tlv_comments = parse_tlv_file(self.name) + tlv_fields[self.name] = tlv_messages # ? means optional field (not supported for arrays) if size.startswith('?'): @@ -630,6 +640,43 @@ def find_message_with_option(messages, optional_messages, name, option): return m +def get_directory_prefix(): + # FIXME: use prefix of filename + return "wire/" + + +def get_tlv_filename(field_name): + return 'gen_{}_csv'.format(field_name) + + +def parse_tlv_file(tlv_field_name): + tlv_includes = [] + tlv_messages = [] + tlv_comments = [] + with open(get_directory_prefix() + get_tlv_filename(tlv_field_name)) as f: + for line in f: + # #include gets inserted into header + if line.startswith('#include '): + tlv_includes.append(line) + continue + + by_comments = line.rstrip().split('#') + + # Emit a comment if they included one + if by_comments[1:]: + tlv_comments.append(' '.join(by_comments[1:])) + + parts = by_comments[0].split(',') + if parts == ['']: + continue + + if len(parts) == 2: + # eg commit_sig,132 + tlv_messages.append(Message(parts[0], Enumtype("TLV_" + parts[0].upper(), parts[1]), tlv_comments)) + tlv_comments = [] + return tlv_includes, tlv_messages, tlv_comments + + parser = argparse.ArgumentParser(description='Generate C from CSV') parser.add_argument('--header', action='store_true', help="Create wire header") parser.add_argument('--bolt', action='store_true', help="Generate wire-format for BOLT") @@ -644,6 +691,7 @@ messages = [] messages_with_option = [] comments = [] includes = [] +tlv_fields = {} prevfield = None # Read csv lines. Single comma is the message values, more is offset/len. @@ -690,6 +738,38 @@ for line in fileinput.input(options.files): prevfield = parts[2] comments = [] + +def construct_enums(msgs): + enums = "" + for m in msgs: + for c in m.comments: + enums += '\t/*{} */\n'.format(c) + enums += '\t{} = {},\n'.format(m.enum.name, m.enum.value) + return enums + + +def enum_header(enums, enumname): + return enum_header_template.format( + enums=enums, + enumname=enumname) + + +def build_enums(toplevel_enumname, toplevel_enums, tlv_fields): + enum_set = "" + enum_set += enum_header(toplevel_enums, toplevel_enumname) + for field_name, messages in tlv_fields.items(): + enum_set += "\n" + enums = construct_enums(messages) + enum_set += enum_header(enums, field_name + '_type') + return enum_set + + +enum_header_template = """enum {enumname} {{ +{enums} +}}; +const char *{enumname}_name(int e); +""" + header_template = """/* This file was generated by generate-wire.py */ /* Do not modify this file! Modify the _csv file it was generated from. */ #ifndef LIGHTNING_{idem} @@ -697,10 +777,7 @@ header_template = """/* This file was generated by generate-wire.py */ #include #include {includes} -enum {enumname} {{ -{enums}}}; -const char *{enumname}_name(int e); - +{formatted_enums} {func_decls} #endif /* LIGHTNING_{idem} */ """ @@ -773,11 +850,8 @@ else: template = impl_template # Dump out enum, sorted by value order. -enums = "" -for m in messages: - for c in m.comments: - enums += '\t/*{} */\n'.format(c) - enums += '\t{} = {},\n'.format(m.enum.name, m.enum.value) +enums = construct_enums(messages) +built_enums = build_enums(options.enumname, enums, tlv_fields) includes = '\n'.join(includes) cases = ['case {enum.name}: return "{enum.name}";'.format(enum=m.enum) for m in messages] printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}("{name}", msg); return;'.format(enum=m.enum, name=m.name) for m in messages] @@ -797,4 +871,5 @@ print(template.format( includes=includes, enumname=options.enumname, enums=enums, + formatted_enums=built_enums, func_decls='\n'.join(decls)))