diff --git a/tools/generate-wire.py b/tools/generate-wire.py index c494c3dc4..8e8a6add7 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -217,6 +217,25 @@ towire_impl_templ = """u8 *towire_{name}(const tal_t *ctx{args}) return memcheck(p, tal_count(p)); }} """ + +printwire_header_templ = """void printwire_{name}(const u8 *cursor); +""" +printwire_impl_templ = """void printwire_{name}(const u8 *cursor) +{{ + size_t plen = tal_len(cursor); + + if (fromwire_u16(&cursor, &plen) != {enum.name}) {{ + printf("WRONG TYPE?!\\n"); + return; + }} + +{subcalls} + + if (plen != 0) + printf("EXTRA: %s\\n", tal_hexstr(NULL, cursor, plen)); +}} +""" + class Message(object): def __init__(self,name,enum,comments): self.name = name @@ -391,6 +410,89 @@ class Message(object): subcalls='\n'.join(subcalls), ) + def add_truncate_check(self, subcalls, indent='\t'): + # Report if truncated, otherwise print. + subcalls.append(indent + 'if (!cursor) {') + subcalls.append(indent + '\tprintf("**TRUNCATED**\\n");') + subcalls.append(indent + '\treturn;') + subcalls.append(indent + '}') + + def print_printwire_array(self, subcalls, basetype, f, num_elems): + if f.has_array_helper(): + subcalls.append('\tprintwire_{}_array(&cursor, &plen, {});' + .format(basetype, num_elems)) + else: + subcalls.append('\tprintf("[");') + subcalls.append('\tfor (size_t i = 0; i < {}; i++) {{' + .format(num_elems)) + subcalls.append('\t\t{} v;'.format(f.fieldtype.name)); + if f.fieldtype.is_assignable(): + subcalls.append('\t\tv = fromwire_{}(&cursor, plen);' + .format(name,basetype)) + else: + # We don't handle this yet! + assert not basetype in varlen_structs + + subcalls.append('\t\tfromwire_{}(&cursor, &plen, &v);' + .format(basetype)) + + self.add_truncate_check(subcalls, indent='\t\t') + + subcalls.append('\t\tprintwire_{}(&v);'.format(basetype)) + subcalls.append('\t}') + subcalls.append('\tprintf("]");') + + def print_printwire(self,is_header): + template = printwire_header_templ if is_header else printwire_impl_templ + fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var] + + subcalls = [] + for f in self.fields: + basetype=f.basetype() + + for c in f.comments: + subcalls.append('\t/*{} */'.format(c)) + + if f.is_len_var: + subcalls.append('\t{} {} = fromwire_{}(&cursor, &plen);' + .format(f.fieldtype.name, f.name, basetype)) + self.add_truncate_check(subcalls) + continue + + subcalls.append('\tprintf("{}=");'.format(f.name)) + if f.is_padding(): + subcalls.append('\tprintwire_pad(&cursor, &plen, {});' + .format(f.num_elems)) + self.add_truncate_check(subcalls) + elif f.is_array(): + self.print_printwire_array(subcalls, basetype, f, f.num_elems) + self.add_truncate_check(subcalls) + elif f.is_variable_size(): + self.print_printwire_array(subcalls, basetype, f, f.lenvar) + self.add_truncate_check(subcalls) + else: + if f.is_assignable(): + subcalls.append('\t{} {} = fromwire_{}(&cursor, &plen);' + .format(f.fieldtype.name, f.name, basetype)) + else: + # Don't handle these yet. + assert not basetype in varlen_structs + subcalls.append('\t{} {};'. + format(f.fieldtype.name, f.name)); + subcalls.append('\tfromwire_{}(&cursor, &plen, &{});' + .format(basetype, f.name)) + + self.add_truncate_check(subcalls) + subcalls.append('\tprintwire_{}(&{});' + .format(basetype, f.name)) + + return template.format( + name=self.name, + fields=''.join(fields), + enum=self.enum, + subcalls='\n'.join(subcalls) + ) + def find_message(messages, name): for m in messages: if m.name == name: @@ -416,6 +518,7 @@ def find_message_with_option(messages, optional_messages, name, option): parser = argparse.ArgumentParser(description='Generate C from CSV') parser.add_argument('--header', action='store_true', help="Create wire header") parser.add_argument('--bolt', action='store_true', help="Generate wire-format for BOLT") +parser.add_argument('--printwire', action='store_true', help="Create print routines") parser.add_argument('headerfilename', help='The filename of the header') parser.add_argument('enumname', help='The name of the enum to produce') parser.add_argument('files', nargs='*', help='Files to read in (or stdin)') @@ -509,8 +612,50 @@ const char *{enumname}_name(int e) {func_decls} """ +print_header_template = """/* This file was generated by generate-wire.py */ +/* Do not modify this file! Modify the _csv file it was generated from. */ +#ifndef LIGHTNING_{idem} +#define LIGHTNING_{idem} +#include +#include +{includes} + +void print_message(const u8 *msg); + +{func_decls} +#endif /* LIGHTNING_{idem} */ +""" + +print_template = """/* This file was generated by generate-wire.py */ +/* Do not modify this file! Modify the _csv file it was generated from. */ +#include "{headerfilename}" +#include +#include +#include +#include + +void print_message(const u8 *msg) +{{ + switch ((enum {enumname})fromwire_peektype(msg)) {{ + {printcases} + }} + + printf("UNKNOWN: %s\\n", tal_hex(msg, msg)); +}} + +{func_decls} +""" + idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper()) -template = header_template if options.header else impl_template +if options.printwire: + if options.header: + template = print_header_template + else: + template = print_template +elif options.header: + template = header_template +else: + template = impl_template # Dump out enum, sorted by value order. enums = "" @@ -520,16 +665,22 @@ for m in messages: enums += '\t{} = {},\n'.format(m.enum.name, m.enum.value) includes = '\n'.join(includes) cases = ['case {enum.name}: return "{enum.name}";'.format(enum=m.enum) for m in messages] +printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}(msg); return;'.format(enum=m.enum,name=m.name) for m in messages] -fromwire_decls = [m.print_fromwire(options.header) for m in messages + messages_with_option] -towire_decls = [m.print_towire(options.header) for m in messages + messages_with_option] +if options.printwire: + decls = [m.print_printwire(options.header) for m in messages + messages_with_option] +else: + fromwire_decls = [m.print_fromwire(options.header) for m in messages + messages_with_option] + towire_decls = towire_decls = [m.print_towire(options.header) for m in messages + messages_with_option] + decls = fromwire_decls + towire_decls print(template.format( headerfilename=options.headerfilename, cases='\n\t'.join(cases), + printcases='\n\t'.join(printcases), idem=idem, includes=includes, enumname=options.enumname, enums=enums, - func_decls='\n'.join(fromwire_decls + towire_decls), + func_decls='\n'.join(decls), ))