From aba4e161ce6160c4f0e4f9a2484b19988b5587ec Mon Sep 17 00:00:00 2001 From: lisa neigut Date: Thu, 28 Mar 2019 14:25:19 -0700 Subject: [PATCH] tlv: calculate sizeof by measuring message length much better than statically calculating the sizeof --- tools/generate-wire.py | 110 +++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 49 deletions(-) diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 4bfbfed46..a0744328a 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -235,15 +235,12 @@ fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p{args}) fromwire_tlv_impl_templ = """static bool _fromwire_{tlv_name}_{name}({ctx}{args}) {{ -\tsize_t start_len, plen; +\tsize_t start_len = *plen; {fields} -\tconst u8 *cursor = p; -\tplen = tal_count(p); -\tif (plen < len) +\tif (start_len < len) \t\treturn false; -\tstart_len = plen; {subcalls} -\treturn cursor != NULL && (start_len - plen == len); +\treturn cursor != NULL && (start_len - *plen == len); }} """ @@ -382,22 +379,23 @@ class Message(object): self.has_variable_fields = True self.fields.append(field) - def print_fromwire_array(self, ctx, subcalls, basetype, f, name, num_elems): + def print_fromwire_array(self, ctx, subcalls, basetype, f, name, num_elems, is_tlv=False): + p_ref = '' if is_tlv else '&' if f.has_array_helper(): - subcalls.append('fromwire_{}_array(&cursor, &plen, {}, {});' - .format(basetype, name, num_elems)) + subcalls.append('fromwire_{}_array(&cursor, {}plen, {}, {});' + .format(basetype, p_ref, name, num_elems)) else: subcalls.append('for (size_t i = 0; i < {}; i++)' .format(num_elems)) if f.fieldtype.is_assignable(): - subcalls.append('({})[i] = fromwire_{}(&cursor, &plen);' - .format(name, basetype)) + subcalls.append('({})[i] = fromwire_{}(&cursor, {}plen);' + .format(name, basetype, p_ref)) elif basetype in varlen_structs: - subcalls.append('({})[i] = fromwire_{}({}, &cursor, &plen);' - .format(name, basetype, ctx)) + subcalls.append('({})[i] = fromwire_{}({}, &cursor, {}plen);' + .format(name, basetype, ctx, p_ref)) else: - subcalls.append('fromwire_{}(&cursor, &plen, {} + i);' - .format(basetype, name)) + subcalls.append('fromwire_{}(&cursor, {}plen, {} + i);' + .format(basetype, p_ref, name)) def print_tlv_fromwire(self, tlv_name): """ prints fromwire function definition for a TLV message. @@ -405,7 +403,7 @@ class Message(object): to populate, instead of fields, as well as a length to read in """ ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else '' - args = 'const void *p, const u16 len, struct _tlv_msg_{name} *{name}'.format(name=self.name) + args = 'const u8 *cursor, size_t *plen, const u16 len, struct _tlv_msg_{name} *{name}'.format(name=self.name) fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var] subcalls = CCode() for f in self.fields: @@ -419,12 +417,12 @@ class Message(object): subcalls.append('/*{} */'.format(c)) if f.is_padding(): - subcalls.append('fromwire_pad(&cursor, &plen, {});' + subcalls.append('fromwire_pad(&cursor, plen, {});' .format(f.num_elems)) elif f.is_array(): name = '*{}->{}'.format(self.name, f.name) self.print_fromwire_array('ctx', subcalls, basetype, f, name, - f.num_elems) + f.num_elems, is_tlv=True) elif f.is_variable_size(): subcalls.append("// 2nd case {name}".format(name=f.name)) typename = f.fieldtype.name @@ -437,16 +435,16 @@ class Message(object): name = '{}->{}'.format(self.name, f.name) # Allocate these off the array itself, if they need alloc. self.print_fromwire_array('*' + f.name, subcalls, basetype, f, - name, f.lenvar) + name, f.lenvar, is_tlv=True) else: if f.is_assignable(): if f.is_len_var: - s = '{} = fromwire_{}(&cursor, &plen);'.format(f.name, basetype) + s = '{} = fromwire_{}(&cursor, plen);'.format(f.name, basetype) else: - s = '{}->{} = fromwire_{}(&cursor, &plen);'.format( + s = '{}->{} = fromwire_{}(&cursor, plen);'.format( self.name, f.name, basetype) else: - s = 'fromwire_{}(&cursor, &plen, *{}->{});'.format( + s = 'fromwire_{}(&cursor, plen, *{}->{});'.format( basetype, self.name, f.name) subcalls.append(s) @@ -506,7 +504,7 @@ class Message(object): elif f.is_tlv: if not f.is_variable_size(): raise TypeError('TLV {} not variable size'.format(f.name)) - subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &{tlv_len}, {tlv_name}))' + subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &plen, &{tlv_len}, {tlv_name}))' .format(tlv_name=f.name, tlv_len=f.lenvar)) subcalls.append('return false;') elif f.is_variable_size(): @@ -585,7 +583,7 @@ class Message(object): elif f.optional: raise TypeError("Optional fields on TLV messages not currently supported. {}->{}".format(tlv_name, f.name)) if f.is_len_var: - field_decls.append('\t{0} {1} = tal_count(&{2}->{3});'.format( + field_decls.append('\t{0} {1} = tal_count({2}->{3});'.format( f.fieldtype.name, f.name, self.name, f.lenvar_for.name )) @@ -611,6 +609,9 @@ class Message(object): field_decls='\n'.join(field_decls), subcalls=str(subcalls)) + def find_tlv_lenvar_field(self, tlv_name): + return [f for f in self.fields if f.is_len_var and f.lenvar_for.is_tlv and f.lenvar_for.name == tlv_name][0] + def print_towire(self, is_header, tlv_name): if self.is_tlv: if is_header: @@ -636,9 +637,8 @@ class Message(object): for f in self.fields: if f.is_len_var: if f.lenvar_for.is_tlv: - field_decls.append('\t{0} {1} = sizeof({2});'.format( - f.fieldtype.name, f.name, f.lenvar_for.name - )) + # used below... + field_decls.append('\t{0} {1};'.format(f.fieldtype.name, f.name)) else: field_decls.append('\t{0} {1} = tal_count({2});'.format( f.fieldtype.name, f.name, f.lenvar_for.name @@ -656,11 +656,21 @@ class Message(object): .format(f.num_elems)) elif f.is_array(): self.print_towire_array(subcalls, basetype, f, f.num_elems) + elif f.is_len_var and f.lenvar_for.is_tlv: + continue # taken care of below elif f.is_tlv: if not f.is_variable_size(): - raise TypeError('TLV {} not variable size'.format(f.name)) - subcalls.append('towire__{tlv_name}(&p, {tlv_name});'.format( - tlv_name=f.name)) + raise ValueError('TLV {} not variable size'.format(f.name)) + lenvar_field = self.find_tlv_lenvar_field(f.name) + subcalls.append('/* ~~build TLV for {} ~~*/'.format(f.name)) + subcalls.append("u8 *{tlv_name}_buffer = tal_arr(ctx, u8, 0);\n" + "towire__{tlv_name}(ctx, &{tlv_name}_buffer, {tlv_name});\n" + "{lenvar_field} = tal_count({tlv_name}_buffer);\n" + "towire_{lenvar_fieldtype}(&p, {lenvar_field});\n" + "towire_u8_array(&p, {tlv_name}_buffer, {lenvar_field});\n".format( + tlv_name=f.name, + lenvar_field=lenvar_field.name, + lenvar_fieldtype=lenvar_field.fieldtype.name)) elif f.is_variable_size(): self.print_towire_array(subcalls, basetype, f, f.lenvar) else: @@ -821,43 +831,47 @@ struct _{tlv_name} {{ """ tlv__type_impl_towire_fields = """\tif ({tlv_name}->{name}) {{ +\t\ttlv_msg = tal_arr(ctx, u8, 0); \t\ttowire_u16(p, {enum}); -\t\ttowire_u16(p, sizeof(*{tlv_name}->{name})); -\t\t_towire_{tlv_name}_{name}(p, {tlv_name}->{name}); +\t\t_towire_{tlv_name}_{name}(&tlv_msg, {tlv_name}->{name}); +\t\tmsg_len = tal_count(tlv_msg); +\t\ttowire_u16(p, msg_len); +\t\ttowire_u8_array(p, tlv_msg, msg_len); \t}} """ -tlv__type_impl_towire_template = """static void towire__{tlv_name}(u8 **p, const struct _{tlv_name} *{tlv_name}) {{ +tlv__type_impl_towire_template = """static void towire__{tlv_name}(const tal_t *ctx, u8 **p, const struct _{tlv_name} *{tlv_name}) {{ +\tu16 msg_len; +\tu8 *tlv_msg; {fields}}} """ -tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, const u16 *len, struct _{tlv_name} *{tlv_name}) {{ +tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u16 *len, struct _{tlv_name} *{tlv_name}) {{ \tu16 msg_type, msg_len; -\tconst u8 *cursor = *p; -\tsize_t plen = tal_count(p); -\tif (plen != *len) +\tif (*plen < *len) \t\treturn false; -\twhile (cursor && plen) {{ -\t\tmsg_type = fromwire_u16(&cursor, &plen); -\t\tmsg_len = fromwire_u16(&cursor, &plen); -\t\tif (plen < msg_len) {{ -\t\t\tfromwire_fail(&cursor, &plen); +\twhile (*plen) {{ +\t\tmsg_type = fromwire_u16(p, plen); +\t\tmsg_len = fromwire_u16(p, plen); +\t\tif (*plen < msg_len) {{ +\t\t\tfromwire_fail(p, plen); \t\t\tbreak; \t\t}} \t\tswitch((enum {tlv_name}_type)msg_type) {{ {cases}\t\tdefault: \t\t\t// FIXME: print a warning / message? -\t\t\tcursor += msg_len; +\t\t\t*p+= msg_len; \t\t\tplen -= msg_len; \t\t}} \t}} -\treturn cursor != NULL; +\treturn *p != NULL; }} """ case_tmpl = """\t\tcase {tlv_msg_enum}: -\t\t\tif (!_fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}cursor, msg_len, {tlv_name}->{tlv_msg_name})) +\t\t\t{tlv_name}->{tlv_msg_name} = tal(ctx, struct _tlv_msg_{tlv_msg_name}); +\t\t\tif (!_fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}*p, plen, msg_len, {tlv_name}->{tlv_msg_name})) \t\t\t\treturn false; \t\t\tbreak; """ @@ -1227,10 +1241,8 @@ else: fromwire_decls.append(m.print_fromwire(options.header, tlv_field)) if not options.header: - tlv_towires = build_tlv_towires(tlv_fields) - tlv_fromwires = build_tlv_fromwires(tlv_fields) - towire_decls += tlv_towires - fromwire_decls += tlv_fromwires + towire_decls += build_tlv_towires(tlv_fields) + fromwire_decls += build_tlv_fromwires(tlv_fields) towire_decls += [m.print_towire(options.header, '') for m in messages + messages_with_option] fromwire_decls += [m.print_fromwire(options.header, '') for m in messages + messages_with_option]