diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 09d183a36..deef23433 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -8,50 +8,77 @@ import re Enumtype = namedtuple('Enumtype', ['name', 'value']) +type2size = { + 'pad': 1, + 'struct channel_id': 32, + 'struct short_channel_id': 8, + 'struct ipv6': 16, + 'secp256k1_ecdsa_signature': 64, + 'struct pubkey': 33, + 'struct sha256': 32, + 'u64': 8, + 'u32': 4, + 'u16': 2, + 'u8': 1, + 'bool': 1 +} + class FieldType(object): def __init__(self,name): self.name = name self.tsize = FieldType._typesize(name) def is_assignable(self): - return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64' or self.name == 'bool' + return self.name in ['u8', 'u16', 'u32', 'u64', 'bool'] # Returns base size @staticmethod def _typesize(typename): - if typename == 'pad': - return 1 - elif typename == 'struct short_channel_id': - return 8 - elif typename == 'struct channel_id': - return 32 - elif typename == 'struct ipv6': - return 16 - elif typename == 'secp256k1_ecdsa_signature': - return 64 - elif typename == 'struct pubkey': - return 33 - elif typename == 'struct sha256': - return 32 - elif typename == 'u64': - return 8 - elif typename == 'u32': - return 4 - elif typename == 'u16': - return 2 - elif typename == 'u8': - return 1 - elif typename == 'bool': - return 1 - else: + if typename in type2size: + return type2size[typename] + elif typename.startswith('struct '): # We allow unknown structures, for extensiblity (can only happen # if explicitly specified in csv) - if typename.startswith('struct '): - return 0 + return 0 + else: raise ValueError('Unknown typename {}'.format(typename)) +# Full (message, fieldname)-mappings +typemap = { + ('update_fail_htlc', 'reason'): FieldType('u8'), + ('node_announcement', 'alias'): FieldType('u8'), + ('update_add_htlc', 'onion_routing_packet'): FieldType('u8'), + ('error', 'data'): FieldType('u8'), + ('shutdown', 'scriptpubkey'): FieldType('u8'), + ('node_announcement', 'rgb_color'): FieldType('u8'), + ('node_announcement', 'addresses'): FieldType('u8'), + ('node_announcement', 'ipv6'): FieldType('struct ipv6'), + ('node_announcement', 'alias'): FieldType('u8'), + ('announcement_signatures', 'short_channel_id'): FieldType('struct short_channel_id'), + ('channel_announcement', 'short_channel_id'): FieldType('struct short_channel_id'), + ('channel_update', 'short_channel_id'): FieldType('struct short_channel_id') +} + +# Partial names that map to a datatype +partialtypemap = { + 'signature': FieldType('secp256k1_ecdsa_signature'), + 'features': FieldType('u8'), + 'channel_id': FieldType('struct channel_id'), + 'pad': FieldType('pad'), +} + +# Size to typename match +sizetypemap = { + 33: FieldType('struct pubkey'), + 32: FieldType('struct sha256'), + 8: FieldType('u64'), + 4: FieldType('u32'), + 2: FieldType('u16'), + 1: FieldType('u8') +} + class Field(object): - def __init__(self,message,name,size,comments,typename=None): + def __init__(self, message, name, size, comments, typename=None): self.message = message self.comments = comments self.name = name.replace('-', '_') @@ -103,61 +130,53 @@ class Field(object): # Returns FieldType @staticmethod def _guess_type(message, fieldname, base_size): - if fieldname.startswith('pad'): - return FieldType('pad') - - if fieldname.endswith('short_channel_id'): - return FieldType('struct short_channel_id') - - if fieldname.endswith('channel_id'): - return FieldType('struct channel_id') - - if message == 'node_announcement' and fieldname == 'ipv6': - return FieldType('struct ipv6') - - if message == 'node_announcement' and fieldname == 'alias': - return FieldType('u8') - - if fieldname.endswith('features'): - return FieldType('u8') - - # We translate signatures and pubkeys. - if 'signature' in fieldname: - return FieldType('secp256k1_ecdsa_signature') - - # We whitelist specific things here, otherwise we'd treat everything - # as a u8 array. - if message == 'update_fail_htlc' and fieldname == 'reason': - return FieldType('u8') - if message == 'update_add_htlc' and fieldname == 'onion_routing_packet': - return FieldType('u8') - if message == 'node_announcement' and fieldname == 'alias': - return FieldType('u8') - if message == 'error' and fieldname == 'data': - return FieldType('u8') - if message == 'shutdown' and fieldname == 'scriptpubkey': - return FieldType('u8') - if message == 'node_announcement' and fieldname == 'rgb_color': - return FieldType('u8') - if message == 'node_announcement' and fieldname == 'addresses': - return FieldType('u8') - - # The remainder should be fixed sizes. - if base_size == 33: - return FieldType('struct pubkey') - if base_size == 32: - return FieldType('struct sha256') - if base_size == 8: - return FieldType('u64') - if base_size == 4: - return FieldType('u32') - if base_size == 2: - return FieldType('u16') - if base_size == 1: - return FieldType('u8') + # Check for full (message, fieldname)-matches + if (message, fieldname) in typemap: + return typemap[(message, fieldname)] + + # Check for partial field names + for k, v in partialtypemap.items(): + if k in fieldname: + return v + + # Check for size matches + if base_size in sizetypemap: + return sizetypemap[base_size] raise ValueError('Unknown size {} for {}'.format(base_size,fieldname)) +fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p, size_t *plen{args}) +{{ +{fields} + const u8 *cursor = p; + size_t tmp_len; + + if (!plen) {{ + tmp_len = tal_count(p); + plen = &tmp_len; + }} + if (fromwire_u16(&cursor, plen) != {enum.name}) + return false; +{subcalls} + return cursor != NULL; +}} +""" + +fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p, size_t *plen{args}); +""" + +towire_header_templ = """u8 *towire_{name}(const tal_t *ctx{args}); +""" +towire_impl_templ = """u8 *towire_{name}(const tal_t *ctx{args}) +{{ +{field_decls} + u8 *p = tal_arr(ctx, u8, 0); + towire_u16(&p, {enumname}); +{subcalls} + + return memcheck(p, tal_count(p)); +}} +""" class Message(object): def __init__(self,name,enum,comments): self.name = name @@ -166,7 +185,7 @@ class Message(object): self.fields = [] self.has_variable_fields = False - def checkLenField(self,field): + def checkLenField(self, field): for f in self.fields: if f.name == field.lenvar: if f.fieldtype.name != 'u16': @@ -191,141 +210,117 @@ class Message(object): self.fields.append(field) def print_fromwire(self,is_header): - if self.has_variable_fields: - ctx_arg = 'const tal_t *ctx, ' - else: - ctx_arg = '' - - print('bool fromwire_{}({}const void *p, size_t *plen' - .format(self.name, ctx_arg), end='') + ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else '' + args = [] + for f in self.fields: - if f.is_len_var: + if f.is_len_var or f.is_padding(): continue - if f.is_padding(): - continue - if f.is_array(): - print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='') + elif f.is_array(): + args.append(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems)) elif f.is_variable_size(): - print(', {} **{}'.format(f.fieldtype.name, f.name), end='') + args.append(', {} **{}'.format(f.fieldtype.name, f.name)) else: - print(', {} *{}'.format(f.fieldtype.name, f.name), end='') - - if is_header: - print(');') - return + args.append(', {} *{}'.format(f.fieldtype.name, f.name)) - print(')\n' - '{') - - for f in self.fields: - if f.is_len_var: - print('\t{} {};'.format(f.fieldtype.name, f.name)); - - print('\tconst u8 *cursor = p;\n' - '\tsize_t tmp_len;\n' - '\n' - '\tif (!plen) {{\n' - '\t\ttmp_len = tal_count(p);\n' - '\t\tplen = &tmp_len;\n' - '\t}}\n' - '\tif (fromwire_u16(&cursor, plen) != {})\n' - '\t\treturn false;' - .format(self.enum.name)) + template = fromwire_header_templ if is_header else fromwire_impl_templ + fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var] + subcalls = [] for f in self.fields: basetype=f.fieldtype.name if f.fieldtype.name.startswith('struct '): basetype=f.fieldtype.name[7:] for c in f.comments: - print('\t/*{} */'.format(c)) + subcalls.append('\t/*{} */'.format(c)) if f.is_padding(): - print('\tfromwire_pad(&cursor, plen, {});' - .format(f.num_elems)) + subcalls.append('\tfromwire_pad(&cursor, plen, {});' + .format(f.num_elems)) elif f.is_array(): - print("\t//1th case", f.name) - print('\tfromwire_{}_array(&cursor, plen, {}, {});' - .format(basetype, f.name, f.num_elems)) + subcalls.append("\t//1th case {name}".format(name=f.name)) + subcalls.append('\tfromwire_{}_array(&cursor, plen, {}, {});' + .format(basetype, f.name, f.num_elems)) elif f.is_variable_size(): - print("\t//2th case", f.name) - print('\t*{} = tal_arr(ctx, {}, {});' - .format(f.name, f.fieldtype.name, f.lenvar)) - print('\tfromwire_{}_array(&cursor, plen, *{}, {});' - .format(basetype, f.name, f.lenvar)) + subcalls.append("\t//2th case {name}".format(name=f.name)) + subcalls.append('\t*{} = tal_arr(ctx, {}, {});' + .format(f.name, f.fieldtype.name, f.lenvar)) + subcalls.append('\tfromwire_{}_array(&cursor, plen, *{}, {});' + .format(basetype, f.name, f.lenvar)) elif f.is_assignable(): - print("\t//3th case", f.name) + subcalls.append("\t//3th case {name}".format(name=f.name)) if f.is_len_var: - print('\t{} = fromwire_{}(&cursor, plen);' - .format(f.name, basetype)) + subcalls.append('\t{} = fromwire_{}(&cursor, plen);' + .format(f.name, basetype)) else: - print('\t*{} = fromwire_{}(&cursor, plen);' - .format(f.name, basetype)) + subcalls.append('\t*{} = fromwire_{}(&cursor, plen);' + .format(f.name, basetype)) else: - print("\t//4th case", f.name) - print('\tfromwire_{}(&cursor, plen, {});' - .format(basetype, f.name)) - - print('\n' - '\treturn cursor != NULL;\n' - '}\n') + subcalls.append("\t//4th case {name}".format(name=f.name)) + subcalls.append('\tfromwire_{}(&cursor, plen, {});' + .format(basetype, f.name)) + + return template.format( + name=self.name, + ctx=ctx_arg, + args=''.join(args), + fields=''.join(fields), + enum=self.enum, + subcalls='\n'.join(subcalls) + ) def print_towire(self,is_header): - print('u8 *towire_{}(const tal_t *ctx' - .format(self.name), end='') - + template = towire_header_templ if is_header else towire_impl_templ + args = [] for f in self.fields: if f.is_padding() or f.is_len_var: continue if f.is_array(): - print(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='') + args.append(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems)) elif f.is_assignable(): - print(', {} {}'.format(f.fieldtype.name, f.name), end='') + args.append(', {} {}'.format(f.fieldtype.name, f.name)) else: - print(', const {} *{}'.format(f.fieldtype.name, f.name), end='') + args.append(', const {} *{}'.format(f.fieldtype.name, f.name)) - if is_header: - print(');') - return - - print(')\n' - '{\n') + field_decls = [] for f in self.fields: if f.is_len_var: - print('\t{0} {1} = {2} ? tal_count({2}) : 0;' - .format(f.fieldtype.name, f.name, f.lenvar_for.name)); - - print('\tu8 *p = tal_arr(ctx, u8, 0);\n' - '' - '\ttowire_u16(&p, {});'.format(self.enum.name)) + field_decls.append('\t{0} {1} = {2} ? tal_count({2}) : 0;'.format( + f.fieldtype.name, f.name, f.lenvar_for.name + )); + subcalls = [] for f in self.fields: basetype=f.fieldtype.name - if f.fieldtype.name.startswith('struct '): - basetype=f.fieldtype.name[7:] + if basetype.startswith('struct '): + basetype=basetype[7:] for c in f.comments: - print('\t/*{} */'.format(c)) + subcalls.append('\t/*{} */'.format(c)) if f.is_padding(): - print('\ttowire_pad(&p, {});' + subcalls.append('\ttowire_pad(&p, {});' .format(f.num_elems)) elif f.is_array(): - print('\ttowire_{}_array(&p, {}, {});' + subcalls.append('\ttowire_{}_array(&p, {}, {});' .format(basetype, f.name, f.num_elems)) elif f.is_variable_size(): - print('\ttowire_{}_array(&p, {}, {});' + subcalls.append('\ttowire_{}_array(&p, {}, {});' .format(basetype, f.name, f.lenvar)) else: - print('\ttowire_{}(&p, {});' + subcalls.append('\ttowire_{}(&p, {});' .format(basetype, f.name)) - # Make sure we haven't encoded any uninitialzied fields! - print('\n' - '\treturn memcheck(p, tal_count(p));\n' - '}\n') - + return template.format( + name=self.name, + args=''.join(args), + enumname=self.enum.name, + field_decls='\n'.join(field_decls), + subcalls='\n'.join(subcalls), + ) + parser = argparse.ArgumentParser(description='Generate C from from CSV') parser.add_argument('--header', action='store_true', help="Create wire header") parser.add_argument('headerfilename', help='The filename of the header') @@ -333,19 +328,6 @@ parser.add_argument('enumname', help='The name of the enum to produce') parser.add_argument('files', nargs='*', help='Files to read in (or stdin)') options = parser.parse_args() -if options.header: - idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper()) - print('#ifndef LIGHTNING_{0}\n' - '#define LIGHTNING_{0}\n' - '#include \n' - '#include '.format(idem)) -else: - print('#include <{}>\n' - '#include \n' - '#include \n' - '#include \n' - ''.format(options.headerfilename)) - # Maps message names to messages messages = [] comments = [] @@ -385,40 +367,60 @@ for line in fileinput.input(options.files): break comments=[] -if options.header: - for i in includes: - print(i, end='') - - print('') - - # Dump out enum, sorted by value order. - print('enum {} {{'.format(options.enumname)) - for m in messages: - for c in m.comments: - print('\t/*{} */'.format(c)) - print('\t{} = {},'.format(m.enum.name, m.enum.value)) - print('};') - print('const char *{}_name(int e);'.format(options.enumname)) -else: - print('const char *{}_name(int e)'.format(options.enumname)) - print('{{\n' - '\tstatic char invalidbuf[sizeof("INVALID ") + STR_MAX_CHARS(e)];\n' - '\n' - '\tswitch ((enum {})e) {{'.format(options.enumname)); - for m in messages: - print('\tcase {0}: return "{0}";'.format(m.enum.name)) - print('\t}\n' - '\n' - '\tsprintf(invalidbuf, "INVALID %i", e);\n' - '\treturn invalidbuf;\n' - '}\n' - '') +header_template = """#ifndef LIGHTNING_{idem} +#define LIGHTNING_{idem} +#include +#include +{includes} +enum {enumname} {{ +{enums}}}; +const char *{enumname}_name(int e); -for m in messages: - m.print_fromwire(options.header) +{func_decls} +#endif /* LIGHTNING_{idem} */ +""" + +impl_template = """#include <{headerfilename}> +#include +#include +#include + +const char *{enumname}_name(int e) +{{ + static char invalidbuf[sizeof("INVALID ") + STR_MAX_CHARS(e)]; + + switch ((enum {enumname})e) {{ + {cases} + }} + + sprintf(invalidbuf, "INVALID %i", e); + return invalidbuf; +}} + +{func_decls} +""" + +idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper()) +template = header_template if options.header else impl_template +# Dump out enum, sorted by value order. +enums = "" for m in messages: - m.print_towire(options.header) - -if options.header: - print('#endif /* LIGHTNING_{} */\n'.format(idem)) + for c in m.comments: + enums += '\t/*{} */\n'.format(c) + enums += '\t{} = {},\n'.format(m.enum.name, m.enum.value) +includes = '\n'.join(includes) +cases = ['case {enum.name}: return "{enum.name}";'.format(enum=m.enum) for m in messages] + +fromwire_decls = [m.print_fromwire(options.header) for m in messages] +towire_decls = [m.print_towire(options.header) for m in messages] + +print(template.format( + headerfilename=options.headerfilename, + cases='\n\t'.join(cases), + idem=idem, + includes=includes, + enumname=options.enumname, + enums=enums, + func_decls='\n'.join(fromwire_decls + towire_decls), +))