Browse Source

tlv: add fromwire_ methods for TLV structs

pr-2587
lisa neigut 6 years ago
committed by Rusty Russell
parent
commit
6f2e70a6ac
  1. 277
      tools/generate-wire.py

277
tools/generate-wire.py

@ -232,6 +232,21 @@ fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p{args})
}}
"""
fromwire_tlv_impl_templ = """static bool _fromwire_{tlv_name}_{name}({ctx}{args})
{{
\tsize_t start_len, plen;
{fields}
\tconst u8 *cursor = p;
\tplen = tal_count(p);
\tif (plen < len)
\t\treturn false;
\tstart_len = plen;
{subcalls}
\treturn cursor != NULL && (start_len - plen == len);
}}
"""
fromwire_header_templ = """bool fromwire_{name}({ctx}const void *p{args});
"""
@ -384,7 +399,72 @@ class Message(object):
subcalls.append('fromwire_{}(&cursor, &plen, {} + i);'
.format(basetype, name))
def print_fromwire(self, is_header):
def print_tlv_fromwire(self, tlv_name):
""" prints fromwire function definition for a TLV message.
these are significantly different in that they take in a struct
to populate, instead of fields, as well as a length to read in
"""
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
args = 'const void *p, const u16 len, struct _tlv_msg_{name} *{name}'.format(name=self.name)
fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var]
subcalls = CCode()
for f in self.fields:
basetype = f.basetype()
if f.is_tlv:
raise TypeError('Nested TLVs arent allowed!!')
elif f.optional:
raise TypeError('Optional fields on TLV messages not currently supported')
for c in f.comments:
subcalls.append('/*{} */'.format(c))
if f.is_padding():
subcalls.append('fromwire_pad(&cursor, &plen, {});'
.format(f.num_elems))
elif f.is_array():
name = '*{}->{}'.format(self.name, f.name)
self.print_fromwire_array('ctx', subcalls, basetype, f, name,
f.num_elems)
elif f.is_variable_size():
subcalls.append("// 2nd case {name}".format(name=f.name))
typename = f.fieldtype.name
# If structs are varlen, need array of ptrs to them.
if basetype in varlen_structs:
typename += ' *'
subcalls.append('{}->{} = {} ? tal_arr(ctx, {}, {}) : NULL;'
.format(self.name, f.name, f.lenvar, typename, f.lenvar))
name = '{}->{}'.format(self.name, f.name)
# Allocate these off the array itself, if they need alloc.
self.print_fromwire_array('*' + f.name, subcalls, basetype, f,
name, f.lenvar)
else:
if f.is_assignable():
if f.is_len_var:
s = '{} = fromwire_{}(&cursor, &plen);'.format(f.name, basetype)
else:
s = '{}->{} = fromwire_{}(&cursor, &plen);'.format(
self.name, f.name, basetype)
else:
s = 'fromwire_{}(&cursor, &plen, *{}->{});'.format(
basetype, self.name, f.name)
subcalls.append(s)
return fromwire_tlv_impl_templ.format(
tlv_name=tlv_name,
name=self.name,
ctx=ctx_arg,
args=''.join(args),
fields=''.join(fields),
subcalls=str(subcalls)
)
def print_fromwire(self, is_header, tlv_name):
if self.is_tlv:
if is_header:
return ''
return self.print_tlv_fromwire(tlv_name)
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
args = []
@ -394,6 +474,8 @@ class Message(object):
continue
elif f.is_array():
args.append(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
elif f.is_tlv:
args.append(', struct _{} *{}'.format(f.name, f.name))
else:
ptrs = '*'
# If we're handing a variable array, we need a ptr-to-ptr.
@ -421,6 +503,12 @@ class Message(object):
elif f.is_array():
self.print_fromwire_array('ctx', subcalls, basetype, f, f.name,
f.num_elems)
elif f.is_tlv:
if not f.is_variable_size():
raise TypeError('TLV {} not variable size'.format(f.name))
subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &{tlv_len}, {tlv_name}))'
.format(tlv_name=f.name, tlv_len=f.lenvar))
subcalls.append('return false;')
elif f.is_variable_size():
subcalls.append("//2nd case {name}".format(name=f.name))
typename = f.fieldtype.name
@ -472,21 +560,67 @@ class Message(object):
subcalls=str(subcalls)
)
def print_towire_array(self, subcalls, basetype, f, num_elems):
def print_towire_array(self, subcalls, basetype, f, num_elems, is_tlv=False):
p_ref = '' if is_tlv else '&'
msg_name = self.name + '->' if is_tlv else ''
if f.has_array_helper():
subcalls.append('towire_{}_array(&p, {}, {});'
.format(basetype, f.name, num_elems))
subcalls.append('towire_{}_array({}p, {}{}, {});'
.format(basetype, p_ref, msg_name, f.name, num_elems))
else:
subcalls.append('for (size_t i = 0; i < {}; i++)'
.format(num_elems))
if f.fieldtype.is_assignable() or basetype in varlen_structs:
subcalls.append('towire_{}(&p, {}[i]);'
.format(basetype, f.name))
subcalls.append('towire_{}({}p, {}{}[i]);'
.format(basetype, p_ref, msg_name, f.name))
else:
subcalls.append('towire_{}(&p, {} + i);'
.format(basetype, f.name))
subcalls.append('towire_{}({}p, {}{} + i);'
.format(basetype, p_ref, msg_name, f.name))
def print_towire(self, is_header):
def print_tlv_towire(self, tlv_name):
""" prints towire function definition for a TLV message."""
field_decls = []
for f in self.fields:
if f.is_tlv:
raise TypeError("Nested TLVs aren't allowed!! {}->{}".format(tlv_name, f.name))
elif f.optional:
raise TypeError("Optional fields on TLV messages not currently supported. {}->{}".format(tlv_name, f.name))
if f.is_len_var:
field_decls.append('\t{0} {1} = tal_count(&{2}->{3});'.format(
f.fieldtype.name, f.name, self.name, f.lenvar_for.name
))
subcalls = CCode()
for f in self.fields:
basetype = f.fieldtype.name
if basetype.startswith('struct '):
basetype = basetype[7:]
elif basetype.startswith('enum '):
basetype = basetype[5:]
for c in f.comments:
subcalls.append('/*{} */'.format(c))
if f.is_padding():
subcalls.append('towire_pad(p, {});'.format(f.num_elems))
elif f.is_array():
self.print_towire_array(subcalls, basetype, f, f.num_elems, is_tlv=True)
elif f.is_variable_size():
self.print_towire_array(subcalls, basetype, f, f.lenvar, is_tlv=True)
elif f.is_len_var:
subcalls.append('towire_{}(p, {});'.format(basetype, f.name))
else:
subcalls.append('towire_{}(p, {}->{});'.format(basetype, self.name, f.name))
return tlv_message_towire_stub.format(
tlv_name=tlv_name,
name=self.name,
field_decls='\n'.join(field_decls),
subcalls=str(subcalls))
def print_towire(self, is_header, tlv_name):
if self.is_tlv:
if is_header:
return ''
return self.print_tlv_towire(tlv_name)
template = towire_header_templ if is_header else towire_impl_templ
args = []
for f in self.fields:
@ -494,6 +628,8 @@ class Message(object):
continue
if f.is_array():
args.append(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems))
elif f.is_tlv:
args.append(', const struct _{} *{}'.format(f.name, f.name))
elif f.is_assignable():
args.append(', {} {}'.format(f.fieldtype.name, f.name))
elif f.is_variable_size() and f.basetype() in varlen_structs:
@ -504,9 +640,14 @@ class Message(object):
field_decls = []
for f in self.fields:
if f.is_len_var:
field_decls.append('\t{0} {1} = tal_count({2});'.format(
f.fieldtype.name, f.name, f.lenvar_for.name
))
if f.lenvar_for.is_tlv:
field_decls.append('\t{0} {1} = sizeof({2});'.format(
f.fieldtype.name, f.name, f.lenvar_for.name
))
else:
field_decls.append('\t{0} {1} = tal_count({2});'.format(
f.fieldtype.name, f.name, f.lenvar_for.name
))
subcalls = CCode()
for f in self.fields:
@ -524,6 +665,11 @@ class Message(object):
.format(f.num_elems))
elif f.is_array():
self.print_towire_array(subcalls, basetype, f, f.num_elems)
elif f.is_tlv:
if not f.is_variable_size():
raise TypeError('TLV {} not variable size'.format(f.name))
subcalls.append('towire__{tlv_name}(&p, {tlv_name});'.format(
tlv_name=f.name))
elif f.is_variable_size():
self.print_towire_array(subcalls, basetype, f, f.lenvar)
else:
@ -664,6 +810,13 @@ class Message(object):
fields=str(fmt_fields))
tlv_message_towire_stub = """static void _towire_{tlv_name}_{name}(u8 **p, struct _tlv_msg_{name} *{name}) {{
{field_decls}
{subcalls}
}}
"""
tlv_msg_struct_template = """
struct _tlv_msg_{msg_name} {{
{fields}
@ -676,6 +829,87 @@ struct _{tlv_name} {{
}};
"""
tlv__type_impl_towire_fields = """\tif ({tlv_name}->{name}) {{
\t\ttowire_u16(p, {enum});
\t\ttowire_u16(p, sizeof(*{tlv_name}->{name}));
\t\t_towire_{tlv_name}_{name}(p, {tlv_name}->{name});
\t}}
"""
tlv__type_impl_towire_template = """static void towire__{tlv_name}(u8 **p, const struct _{tlv_name} *{tlv_name}) {{
{fields}}}
"""
tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, const u16 *len, struct _{tlv_name} *{tlv_name}) {{
\tu16 msg_type, msg_len;
\tconst u8 *cursor = *p;
\tsize_t plen = tal_count(p);
\tif (plen != *len)
\t\treturn false;
\twhile (cursor && plen) {{
\t\tmsg_type = fromwire_u16(&cursor, &plen);
\t\tmsg_len = fromwire_u16(&cursor, &plen);
\t\tif (plen < msg_len) {{
\t\t\tfromwire_fail(&cursor, &plen);
\t\t\tbreak;
\t\t}}
\t\tswitch((enum {tlv_name}_type)msg_type) {{
{cases}\t\tdefault:
\t\t\t// FIXME: print a warning / message?
\t\t\tcursor += msg_len;
\t\t\tplen -= msg_len;
\t\t}}
\t}}
\treturn cursor != NULL;
}}
"""
case_tmpl = """\t\tcase {tlv_msg_enum}:
\t\t\tif (!_fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}cursor, msg_len, {tlv_name}->{tlv_msg_name}))
\t\t\t\treturn false;
\t\t\tbreak;
"""
def build_tlv_fromwires(tlv_fields):
fromwires = []
for field_name, messages in tlv_fields.items():
fromwires.append(print_tlv_fromwire(field_name, messages))
return fromwires
def build_tlv_towires(tlv_fields):
towires = []
for field_name, messages in tlv_fields.items():
towires.append(print_tlv_towire(field_name, messages))
return towires
def print_tlv_towire(tlv_field_name, messages):
fields = ""
for m in messages:
fields += tlv__type_impl_towire_fields.format(
tlv_name=tlv_field_name,
enum=m.enum.name,
name=m.name)
return tlv__type_impl_towire_template.format(
tlv_name=tlv_field_name,
fields=fields)
def print_tlv_fromwire(tlv_field_name, messages):
cases = ""
for m in messages:
ctx_arg = 'ctx, ' if m.has_variable_fields else ''
cases += case_tmpl.format(ctx_arg=ctx_arg,
tlv_msg_enum=m.enum.name,
tlv_name=tlv_field_name,
tlv_msg_name=m.name)
return tlv__type_impl_fromwire_template.format(
tlv_name=tlv_field_name,
cases=cases)
def build_tlv_type_struct(name, messages):
inner_structs = CCode()
@ -752,7 +986,6 @@ def parse_tlv_file(tlv_field_name):
# eg commit_sig,132
tlv_msg = Message(parts[0], Enumtype("TLV_" + parts[0].upper(), parts[1]), tlv_comments, True)
tlv_messages.append(tlv_msg)
messages.append(tlv_msg)
tlv_comments = []
tlv_prevfield = None
@ -994,8 +1227,22 @@ printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}("{n
if options.printwire:
decls = [m.print_printwire(options.header) for m in messages + messages_with_option]
else:
fromwire_decls = [m.print_fromwire(options.header) for m in messages + messages_with_option]
towire_decls = towire_decls = [m.print_towire(options.header) for m in messages + messages_with_option]
towire_decls = []
fromwire_decls = []
for tlv_field, tlv_messages in tlv_fields.items():
for m in tlv_messages:
towire_decls.append(m.print_towire(options.header, tlv_field))
fromwire_decls.append(m.print_fromwire(options.header, tlv_field))
if not options.header:
tlv_towires = build_tlv_towires(tlv_fields)
tlv_fromwires = build_tlv_fromwires(tlv_fields)
towire_decls += tlv_towires
fromwire_decls += tlv_fromwires
towire_decls += [m.print_towire(options.header, '') for m in messages + messages_with_option]
fromwire_decls += [m.print_fromwire(options.header, '') for m in messages + messages_with_option]
decls = fromwire_decls + towire_decls
print(template.format(

Loading…
Cancel
Save