diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 4692815c1..4e5321ca4 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -16,7 +16,7 @@ class FieldType(object): def is_assignable(self): return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64' or self.name == 'bool' - # Returns typename and base size + # Returns base size @staticmethod def _typesize(typename): if typename == 'pad': @@ -42,6 +42,10 @@ class FieldType(object): elif typename == 'bool': return 1 else: + # We allow unknown structures, for extensiblity (can only happen + # if explicitly specified in csv) + if typename.startswith('struct '): + return 0 raise ValueError('Unknown typename {}'.format(typename)) class Field(object): @@ -50,6 +54,7 @@ class Field(object): self.comments = comments self.name = name.replace('-', '_') self.is_len_var = False + self.is_unknown = False self.lenvar = None # Size could be a literal number (eg. 33), or a field (eg 'len'), or @@ -71,6 +76,11 @@ class Field(object): else: self.fieldtype = FieldType(typename) + # Unknown types are assumed to have base_size: div by 0 if that's unknown. + if self.fieldtype.tsize == 0: + self.is_unknown = True + self.fieldtype.tsize = base_size + if base_size % self.fieldtype.tsize != 0: raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(base_size,self.message,self.name,self.fieldtype.tsize)) self.num_elems = int(base_size / self.fieldtype.tsize) @@ -174,6 +184,8 @@ class Message(object): if field.is_variable_size(): self.checkLenField(field) self.has_variable_fields = True + elif field.is_unknown: + self.has_variable_fields = True self.fields.append(field) def print_fromwire(self,is_header): @@ -192,7 +204,7 @@ class Message(object): continue if f.is_array(): print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='') - elif f.is_variable_size(): + elif f.is_variable_size() or f.is_unknown: print(', {} **{}'.format(f.fieldtype.name, f.name), end='') else: print(', {} *{}'.format(f.fieldtype.name, f.name), end='') @@ -227,7 +239,10 @@ class Message(object): for c in f.comments: print('\t/*{} */'.format(c)) - if f.is_padding(): + if f.is_unknown: + print('\t*{} = fromwire_{}(ctx, &cursor, plen);' + .format(f.name, basetype)) + elif f.is_padding(): print('\tfromwire_pad(&cursor, plen, {});' .format(f.num_elems)) elif f.is_array(): @@ -322,8 +337,7 @@ if options.output_header: print('#ifndef LIGHTNING_{0}\n' '#define LIGHTNING_{0}\n' '#include \n' - '#include \n' - ''.format(idem)) + '#include '.format(idem)) else: print('#include <{}>\n' '#include \n' @@ -333,9 +347,15 @@ else: # Maps message names to messages messages = [] comments = [] +includes = [] # Read csv lines. Single comma is the message values, more is offset/len. for line in fileinput.input(args[2:]): + # #include gets inserted into header + if line.startswith('#include '): + includes.append(line) + continue + by_comments = line.rstrip().split('#') # Emit a comment if they included one @@ -364,6 +384,11 @@ for line in fileinput.input(args[2:]): comments=[] if options.output_header: + for i in includes: + print(i, end='') + + print('') + # Dump out enum, sorted by value order. print('enum {} {{'.format(args[1])) for m in messages: