diff --git a/tools/generate-wire.py b/tools/generate-wire.py index d849886ad..eee5d8ea4 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -232,6 +232,21 @@ fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p{args}) }} """ +fromwire_tlv_impl_templ = """static bool _fromwire_{tlv_name}_{name}({ctx}{args}) +{{ + +\tsize_t start_len, plen; +{fields} +\tconst u8 *cursor = p; +\tplen = tal_count(p); +\tif (plen < len) +\t\treturn false; +\tstart_len = plen; +{subcalls} +\treturn cursor != NULL && (start_len - plen == len); +}} +""" + fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p{args}); """ @@ -384,7 +399,72 @@ class Message(object): subcalls.append('fromwire_{}(&cursor, &plen, {} + i);' .format(basetype, name)) - def print_fromwire(self, is_header): + def print_tlv_fromwire(self, tlv_name): + """ prints fromwire function definition for a TLV message. + these are significantly different in that they take in a struct + to populate, instead of fields, as well as a length to read in + """ + ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else '' + args = 'const void *p, const u16 len, struct _tlv_msg_{name} *{name}'.format(name=self.name) + fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var] + subcalls = CCode() + for f in self.fields: + basetype = f.basetype() + if f.is_tlv: + raise TypeError('Nested TLVs arent allowed!!') + elif f.optional: + raise TypeError('Optional fields on TLV messages not currently supported') + + for c in f.comments: + subcalls.append('/*{} */'.format(c)) + + if f.is_padding(): + subcalls.append('fromwire_pad(&cursor, &plen, {});' + .format(f.num_elems)) + elif f.is_array(): + name = '*{}->{}'.format(self.name, f.name) + self.print_fromwire_array('ctx', subcalls, basetype, f, name, + f.num_elems) + elif f.is_variable_size(): + subcalls.append("// 2nd case {name}".format(name=f.name)) + typename = f.fieldtype.name + # If structs are varlen, need array of ptrs to them. + if basetype in varlen_structs: + typename += ' *' + subcalls.append('{}->{} = {} ? tal_arr(ctx, {}, {}) : NULL;' + .format(self.name, f.name, f.lenvar, typename, f.lenvar)) + + name = '{}->{}'.format(self.name, f.name) + # Allocate these off the array itself, if they need alloc. + self.print_fromwire_array('*' + f.name, subcalls, basetype, f, + name, f.lenvar) + else: + if f.is_assignable(): + if f.is_len_var: + s = '{} = fromwire_{}(&cursor, &plen);'.format(f.name, basetype) + else: + s = '{}->{} = fromwire_{}(&cursor, &plen);'.format( + self.name, f.name, basetype) + else: + s = 'fromwire_{}(&cursor, &plen, *{}->{});'.format( + basetype, self.name, f.name) + subcalls.append(s) + + return fromwire_tlv_impl_templ.format( + tlv_name=tlv_name, + name=self.name, + ctx=ctx_arg, + args=''.join(args), + fields=''.join(fields), + subcalls=str(subcalls) + ) + + def print_fromwire(self, is_header, tlv_name): + if self.is_tlv: + if is_header: + return '' + return self.print_tlv_fromwire(tlv_name) + ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else '' args = [] @@ -394,6 +474,8 @@ class Message(object): continue elif f.is_array(): args.append(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems)) + elif f.is_tlv: + args.append(', struct _{} *{}'.format(f.name, f.name)) else: ptrs = '*' # If we're handing a variable array, we need a ptr-to-ptr. @@ -421,6 +503,12 @@ class Message(object): elif f.is_array(): self.print_fromwire_array('ctx', subcalls, basetype, f, f.name, f.num_elems) + elif f.is_tlv: + if not f.is_variable_size(): + raise TypeError('TLV {} not variable size'.format(f.name)) + subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &{tlv_len}, {tlv_name}))' + .format(tlv_name=f.name, tlv_len=f.lenvar)) + subcalls.append('return false;') elif f.is_variable_size(): subcalls.append("//2nd case {name}".format(name=f.name)) typename = f.fieldtype.name @@ -472,21 +560,67 @@ class Message(object): subcalls=str(subcalls) ) - def print_towire_array(self, subcalls, basetype, f, num_elems): + def print_towire_array(self, subcalls, basetype, f, num_elems, is_tlv=False): + p_ref = '' if is_tlv else '&' + msg_name = self.name + '->' if is_tlv else '' if f.has_array_helper(): - subcalls.append('towire_{}_array(&p, {}, {});' - .format(basetype, f.name, num_elems)) + subcalls.append('towire_{}_array({}p, {}{}, {});' + .format(basetype, p_ref, msg_name, f.name, num_elems)) else: subcalls.append('for (size_t i = 0; i < {}; i++)' .format(num_elems)) if f.fieldtype.is_assignable() or basetype in varlen_structs: - subcalls.append('towire_{}(&p, {}[i]);' - .format(basetype, f.name)) + subcalls.append('towire_{}({}p, {}{}[i]);' + .format(basetype, p_ref, msg_name, f.name)) else: - subcalls.append('towire_{}(&p, {} + i);' - .format(basetype, f.name)) + subcalls.append('towire_{}({}p, {}{} + i);' + .format(basetype, p_ref, msg_name, f.name)) - def print_towire(self, is_header): + def print_tlv_towire(self, tlv_name): + """ prints towire function definition for a TLV message.""" + field_decls = [] + for f in self.fields: + if f.is_tlv: + raise TypeError("Nested TLVs aren't allowed!! {}->{}".format(tlv_name, f.name)) + elif f.optional: + raise TypeError("Optional fields on TLV messages not currently supported. {}->{}".format(tlv_name, f.name)) + if f.is_len_var: + field_decls.append('\t{0} {1} = tal_count(&{2}->{3});'.format( + f.fieldtype.name, f.name, self.name, f.lenvar_for.name + )) + + subcalls = CCode() + for f in self.fields: + basetype = f.fieldtype.name + if basetype.startswith('struct '): + basetype = basetype[7:] + elif basetype.startswith('enum '): + basetype = basetype[5:] + + for c in f.comments: + subcalls.append('/*{} */'.format(c)) + + if f.is_padding(): + subcalls.append('towire_pad(p, {});'.format(f.num_elems)) + elif f.is_array(): + self.print_towire_array(subcalls, basetype, f, f.num_elems, is_tlv=True) + elif f.is_variable_size(): + self.print_towire_array(subcalls, basetype, f, f.lenvar, is_tlv=True) + elif f.is_len_var: + subcalls.append('towire_{}(p, {});'.format(basetype, f.name)) + else: + subcalls.append('towire_{}(p, {}->{});'.format(basetype, self.name, f.name)) + return tlv_message_towire_stub.format( + tlv_name=tlv_name, + name=self.name, + field_decls='\n'.join(field_decls), + subcalls=str(subcalls)) + + def print_towire(self, is_header, tlv_name): + if self.is_tlv: + if is_header: + return '' + return self.print_tlv_towire(tlv_name) template = towire_header_templ if is_header else towire_impl_templ args = [] for f in self.fields: @@ -494,6 +628,8 @@ class Message(object): continue if f.is_array(): args.append(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems)) + elif f.is_tlv: + args.append(', const struct _{} *{}'.format(f.name, f.name)) elif f.is_assignable(): args.append(', {} {}'.format(f.fieldtype.name, f.name)) elif f.is_variable_size() and f.basetype() in varlen_structs: @@ -504,9 +640,14 @@ class Message(object): field_decls = [] for f in self.fields: if f.is_len_var: - field_decls.append('\t{0} {1} = tal_count({2});'.format( - f.fieldtype.name, f.name, f.lenvar_for.name - )) + if f.lenvar_for.is_tlv: + field_decls.append('\t{0} {1} = sizeof({2});'.format( + f.fieldtype.name, f.name, f.lenvar_for.name + )) + else: + field_decls.append('\t{0} {1} = tal_count({2});'.format( + f.fieldtype.name, f.name, f.lenvar_for.name + )) subcalls = CCode() for f in self.fields: @@ -524,6 +665,11 @@ class Message(object): .format(f.num_elems)) elif f.is_array(): self.print_towire_array(subcalls, basetype, f, f.num_elems) + elif f.is_tlv: + if not f.is_variable_size(): + raise TypeError('TLV {} not variable size'.format(f.name)) + subcalls.append('towire__{tlv_name}(&p, {tlv_name});'.format( + tlv_name=f.name)) elif f.is_variable_size(): self.print_towire_array(subcalls, basetype, f, f.lenvar) else: @@ -664,6 +810,13 @@ class Message(object): fields=str(fmt_fields)) +tlv_message_towire_stub = """static void _towire_{tlv_name}_{name}(u8 **p, struct _tlv_msg_{name} *{name}) {{ +{field_decls} +{subcalls} +}} +""" + + tlv_msg_struct_template = """ struct _tlv_msg_{msg_name} {{ {fields} @@ -676,6 +829,87 @@ struct _{tlv_name} {{ }}; """ +tlv__type_impl_towire_fields = """\tif ({tlv_name}->{name}) {{ +\t\ttowire_u16(p, {enum}); +\t\ttowire_u16(p, sizeof(*{tlv_name}->{name})); +\t\t_towire_{tlv_name}_{name}(p, {tlv_name}->{name}); +\t}} +""" + +tlv__type_impl_towire_template = """static void towire__{tlv_name}(u8 **p, const struct _{tlv_name} *{tlv_name}) {{ +{fields}}} +""" + +tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, const u16 *len, struct _{tlv_name} *{tlv_name}) {{ +\tu16 msg_type, msg_len; +\tconst u8 *cursor = *p; +\tsize_t plen = tal_count(p); +\tif (plen != *len) +\t\treturn false; + +\twhile (cursor && plen) {{ +\t\tmsg_type = fromwire_u16(&cursor, &plen); +\t\tmsg_len = fromwire_u16(&cursor, &plen); +\t\tif (plen < msg_len) {{ +\t\t\tfromwire_fail(&cursor, &plen); +\t\t\tbreak; +\t\t}} +\t\tswitch((enum {tlv_name}_type)msg_type) {{ +{cases}\t\tdefault: +\t\t\t// FIXME: print a warning / message? +\t\t\tcursor += msg_len; +\t\t\tplen -= msg_len; +\t\t}} +\t}} +\treturn cursor != NULL; +}} +""" + +case_tmpl = """\t\tcase {tlv_msg_enum}: +\t\t\tif (!_fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}cursor, msg_len, {tlv_name}->{tlv_msg_name})) +\t\t\t\treturn false; +\t\t\tbreak; +""" + + +def build_tlv_fromwires(tlv_fields): + fromwires = [] + for field_name, messages in tlv_fields.items(): + fromwires.append(print_tlv_fromwire(field_name, messages)) + return fromwires + + +def build_tlv_towires(tlv_fields): + towires = [] + for field_name, messages in tlv_fields.items(): + towires.append(print_tlv_towire(field_name, messages)) + return towires + + +def print_tlv_towire(tlv_field_name, messages): + fields = "" + for m in messages: + fields += tlv__type_impl_towire_fields.format( + tlv_name=tlv_field_name, + enum=m.enum.name, + name=m.name) + return tlv__type_impl_towire_template.format( + tlv_name=tlv_field_name, + fields=fields) + + +def print_tlv_fromwire(tlv_field_name, messages): + cases = "" + for m in messages: + ctx_arg = 'ctx, ' if m.has_variable_fields else '' + cases += case_tmpl.format(ctx_arg=ctx_arg, + tlv_msg_enum=m.enum.name, + tlv_name=tlv_field_name, + tlv_msg_name=m.name) + return tlv__type_impl_fromwire_template.format( + tlv_name=tlv_field_name, + cases=cases) + def build_tlv_type_struct(name, messages): inner_structs = CCode() @@ -752,7 +986,6 @@ def parse_tlv_file(tlv_field_name): # eg commit_sig,132 tlv_msg = Message(parts[0], Enumtype("TLV_" + parts[0].upper(), parts[1]), tlv_comments, True) tlv_messages.append(tlv_msg) - messages.append(tlv_msg) tlv_comments = [] tlv_prevfield = None @@ -994,8 +1227,22 @@ printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}("{n if options.printwire: decls = [m.print_printwire(options.header) for m in messages + messages_with_option] else: - fromwire_decls = [m.print_fromwire(options.header) for m in messages + messages_with_option] - towire_decls = towire_decls = [m.print_towire(options.header) for m in messages + messages_with_option] + towire_decls = [] + fromwire_decls = [] + + for tlv_field, tlv_messages in tlv_fields.items(): + for m in tlv_messages: + towire_decls.append(m.print_towire(options.header, tlv_field)) + fromwire_decls.append(m.print_fromwire(options.header, tlv_field)) + + if not options.header: + tlv_towires = build_tlv_towires(tlv_fields) + tlv_fromwires = build_tlv_fromwires(tlv_fields) + towire_decls += tlv_towires + fromwire_decls += tlv_fromwires + + towire_decls += [m.print_towire(options.header, '') for m in messages + messages_with_option] + fromwire_decls += [m.print_fromwire(options.header, '') for m in messages + messages_with_option] decls = fromwire_decls + towire_decls print(template.format(