diff --git a/tools/generate-wire.py b/tools/generate-wire.py index e9fc2e548..36f74554c 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -8,31 +8,70 @@ import re Enumtype = namedtuple('Enumtype', ['name', 'value']) +class FieldType(object): + def __init__(self,name): + self.name = name + self.tsize = FieldType._typesize(name) + + def is_assignable(self): + return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64' + + # Returns typename and base size + @staticmethod + def _typesize(typename): + if typename == 'pad': + return 1 + elif typename == 'struct channel_id': + return 8 + elif typename == 'struct ipv6': + return 16 + elif typename == 'struct signature': + return 64 + elif typename == 'struct pubkey': + return 33 + elif typename == 'struct sha256': + return 32 + elif typename == 'u64': + return 8 + elif typename == 'u32': + return 4 + elif typename == 'u16': + return 2 + elif typename == 'u8': + return 1 + else: + raise ValueError('Unknown typename {}'.format(typename)) + class Field(object): - def __init__(self,message,name,size,comments): + def __init__(self,message,name,size,comments,typename=None): self.message = message self.comments = comments self.name = name.replace('-', '_') self.is_len_var = False - (self.typename, self.basesize) = Field._guess_type(message,self.name,size) + self.lenvar = None + # Size could be a literal number (eg. 33), or a field (eg 'len'), or + # a multiplier of a field (eg. num-htlc-timeouts*64). try: - if int(size) % self.basesize != 0: - raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize)) - self.num_elems = int(int(size) / self.basesize) + base_size = int(size) except ValueError: - self.num_elems = 0 # If it's a multiplicitive expression, must end in basesize. if '*' in size: - tail='*' + str(self.basesize) - if not size.endswith(tail): - raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize)) - size = size[:-len(tail)] + base_size = int(size.split('*')[1]) + self.lenvar = size.split('*')[0] else: - if self.basesize != 1: - raise ValueError('Invalid size {} for {}.{} not expressed as a multiple of {}'.format(size,self.message,self.name,self.basesize)) + base_size = 0 + self.lenvar = size + self.lenvar = self.lenvar.replace('-','_') - self.lenvar = size.replace('-','_') + if typename is None: + self.fieldtype = Field._guess_type(message,self.name,base_size) + else: + self.fieldtype = FieldType(typename) + + if base_size % self.fieldtype.tsize != 0: + raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(base_size,self.message,self.name,self.fieldtype.tsize)) + self.num_elems = int(base_size / self.fieldtype.tsize) def is_padding(self): return self.name.startswith('pad') @@ -42,68 +81,67 @@ class Field(object): return self.num_elems > 1 or self.is_padding() def is_variable_size(self): - return self.num_elems == 0 + return self.lenvar is not None def is_assignable(self): if self.is_array() or self.is_variable_size(): return False - return self.typename == 'u8' or self.typename == 'u16' or self.typename == 'u32' or self.typename == 'u64' + return self.fieldtype.is_assignable() - # Returns typename and base size + # Returns FieldType @staticmethod - def _guess_type(message, fieldname, sizestr): + def _guess_type(message, fieldname, base_size): if fieldname.startswith('pad'): - return ('pad',1) + return FieldType('pad') if fieldname.endswith('channel_id'): - return ('struct channel_id',8) + return FieldType('struct channel_id') if message == 'node_announcement' and fieldname == 'ipv6': - return ('struct ipv6',16) + return FieldType('struct ipv6') if message == 'node_announcement' and fieldname == 'alias': - return ('u8',1) + return FieldType('u8') if fieldname.endswith('features'): - return ('u8',1) - - if fieldname == 'addresses': - return ('u8', 1) + return FieldType('u8') # We translate signatures and pubkeys. if 'signature' in fieldname: - return ('struct signature',64) - - # The remainder should be fixed sizes. - if sizestr == '33': - return ('struct pubkey',33) - if sizestr == '32': - return ('struct sha256',32) - if sizestr == '8': - return ('u64',8) - if sizestr == '4': - return ('u32',4) - if sizestr == '2': - return ('u16',2) - if sizestr == '1': - return ('u8',1) + return FieldType('struct signature') # We whitelist specific things here, otherwise we'd treat everything # as a u8 array. if message == 'update_fail_htlc' and fieldname == 'reason': - return ('u8', 1) + return FieldType('u8') if message == 'update_add_htlc' and fieldname == 'onion_routing_packet': - return ('u8', 1) + return FieldType('u8') if message == 'node_announcement' and fieldname == 'alias': - return ('u8',1) + return FieldType('u8') if message == 'error' and fieldname == 'data': - return ('u8',1) + return FieldType('u8') if message == 'shutdown' and fieldname == 'scriptpubkey': - return ('u8',1) + return FieldType('u8') if message == 'node_announcement' and fieldname == 'rgb_color': - return ('u8',1) + return FieldType('u8') + if message == 'node_announcement' and fieldname == 'addresses': + return FieldType('u8') - raise ValueError('Unknown size {} for {}'.format(sizestr,fieldname)) + # The remainder should be fixed sizes. + if base_size == 33: + return FieldType('struct pubkey') + if base_size == 32: + return FieldType('struct sha256') + if base_size == 8: + return FieldType('u64') + if base_size == 4: + return FieldType('u32') + if base_size == 2: + return FieldType('u16') + if base_size == 1: + return FieldType('u8') + + raise ValueError('Unknown size {} for {}'.format(base_size,fieldname)) class Message(object): def __init__(self,name,enum,comments): @@ -116,7 +154,7 @@ class Message(object): def checkLenField(self,field): for f in self.fields: if f.name == field.lenvar: - if f.typename != 'u16': + if f.fieldtype.name != 'u16': raise ValueError('Field {} has non-u16 length variable {}' .format(field.name, field.lenvar)) @@ -151,11 +189,11 @@ class Message(object): if f.is_padding(): continue if f.is_array(): - print(', {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='') + print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='') elif f.is_variable_size(): - print(', {} **{}'.format(f.typename, f.name), end='') + print(', {} **{}'.format(f.fieldtype.name, f.name), end='') else: - print(', {} *{}'.format(f.typename, f.name), end='') + print(', {} *{}'.format(f.fieldtype.name, f.name), end='') if is_header: print(');') @@ -166,7 +204,7 @@ class Message(object): for f in self.fields: if f.is_len_var: - print('\t{} {};'.format(f.typename, f.name)); + print('\t{} {};'.format(f.fieldtype.name, f.name)); print('\tconst u8 *cursor = p;\n' '\tsize_t tmp_len;\n' @@ -180,9 +218,9 @@ class Message(object): .format(self.enum.name)) for f in self.fields: - basetype=f.typename - if f.typename.startswith('struct '): - basetype=f.typename[7:] + basetype=f.fieldtype.name + if f.fieldtype.name.startswith('struct '): + basetype=f.fieldtype.name[7:] for c in f.comments: print('\t/*{} */'.format(c)) @@ -197,7 +235,7 @@ class Message(object): elif f.is_variable_size(): print("\t//2th case", f.name) print('\t*{} = tal_arr(ctx, {}, {});' - .format(f.name, f.typename, f.lenvar)) + .format(f.name, f.fieldtype.name, f.lenvar)) print('\tfromwire_{}_array(&cursor, plen, *{}, {});' .format(basetype, f.name, f.lenvar)) elif f.is_assignable(): @@ -225,11 +263,11 @@ class Message(object): if f.is_padding(): continue if f.is_array(): - print(', const {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='') + print(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='') elif f.is_assignable(): - print(', {} {}'.format(f.typename, f.name), end='') + print(', {} {}'.format(f.fieldtype.name, f.name), end='') else: - print(', const {} *{}'.format(f.typename, f.name), end='') + print(', const {} *{}'.format(f.fieldtype.name, f.name), end='') if is_header: print(');') @@ -242,9 +280,9 @@ class Message(object): '\ttowire_u16(&p, {});'.format(self.enum.name)) for f in self.fields: - basetype=f.typename - if f.typename.startswith('struct '): - basetype=f.typename[7:] + basetype=f.fieldtype.name + if f.fieldtype.name.startswith('struct '): + basetype=f.fieldtype.name[7:] for c in f.comments: print('\t/*{} */'.format(c)) @@ -311,10 +349,15 @@ for line in fileinput.input(args[2:]): messages.append(Message(parts[0],Enumtype("WIRE_" + parts[0].upper(), int(parts[1],0)),comments)) comments=[] else: - # eg commit_sig,0,channel-id,8 + # eg commit_sig,0,channel-id,8 OR + # commit_sig,0,channel-id,8,u64 for m in messages: if m.name == parts[0]: - m.addField(Field(parts[0], parts[2], parts[3], comments)) + if len(parts) == 4: + m.addField(Field(parts[0], parts[2], parts[3], comments)) + else: + m.addField(Field(parts[0], parts[2], parts[3], comments, + parts[4])) break comments=[]