Browse Source

generate-wire.py: allow optional typename in csv file.

For our internal CSV files, we can specify the type explicitly rather
than trying to guess (eg. bool).

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
ppa-0.6.1
Rusty Russell 8 years ago
parent
commit
937a62100f
  1. 171
      tools/generate-wire.py

171
tools/generate-wire.py

@ -8,31 +8,70 @@ import re
Enumtype = namedtuple('Enumtype', ['name', 'value'])
class FieldType(object):
def __init__(self,name):
self.name = name
self.tsize = FieldType._typesize(name)
def is_assignable(self):
return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64'
# Returns typename and base size
@staticmethod
def _typesize(typename):
if typename == 'pad':
return 1
elif typename == 'struct channel_id':
return 8
elif typename == 'struct ipv6':
return 16
elif typename == 'struct signature':
return 64
elif typename == 'struct pubkey':
return 33
elif typename == 'struct sha256':
return 32
elif typename == 'u64':
return 8
elif typename == 'u32':
return 4
elif typename == 'u16':
return 2
elif typename == 'u8':
return 1
else:
raise ValueError('Unknown typename {}'.format(typename))
class Field(object):
def __init__(self,message,name,size,comments):
def __init__(self,message,name,size,comments,typename=None):
self.message = message
self.comments = comments
self.name = name.replace('-', '_')
self.is_len_var = False
(self.typename, self.basesize) = Field._guess_type(message,self.name,size)
self.lenvar = None
# Size could be a literal number (eg. 33), or a field (eg 'len'), or
# a multiplier of a field (eg. num-htlc-timeouts*64).
try:
if int(size) % self.basesize != 0:
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize))
self.num_elems = int(int(size) / self.basesize)
base_size = int(size)
except ValueError:
self.num_elems = 0
# If it's a multiplicitive expression, must end in basesize.
if '*' in size:
tail='*' + str(self.basesize)
if not size.endswith(tail):
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize))
size = size[:-len(tail)]
base_size = int(size.split('*')[1])
self.lenvar = size.split('*')[0]
else:
if self.basesize != 1:
raise ValueError('Invalid size {} for {}.{} not expressed as a multiple of {}'.format(size,self.message,self.name,self.basesize))
base_size = 0
self.lenvar = size
self.lenvar = self.lenvar.replace('-','_')
self.lenvar = size.replace('-','_')
if typename is None:
self.fieldtype = Field._guess_type(message,self.name,base_size)
else:
self.fieldtype = FieldType(typename)
if base_size % self.fieldtype.tsize != 0:
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(base_size,self.message,self.name,self.fieldtype.tsize))
self.num_elems = int(base_size / self.fieldtype.tsize)
def is_padding(self):
return self.name.startswith('pad')
@ -42,68 +81,67 @@ class Field(object):
return self.num_elems > 1 or self.is_padding()
def is_variable_size(self):
return self.num_elems == 0
return self.lenvar is not None
def is_assignable(self):
if self.is_array() or self.is_variable_size():
return False
return self.typename == 'u8' or self.typename == 'u16' or self.typename == 'u32' or self.typename == 'u64'
return self.fieldtype.is_assignable()
# Returns typename and base size
# Returns FieldType
@staticmethod
def _guess_type(message, fieldname, sizestr):
def _guess_type(message, fieldname, base_size):
if fieldname.startswith('pad'):
return ('pad',1)
return FieldType('pad')
if fieldname.endswith('channel_id'):
return ('struct channel_id',8)
return FieldType('struct channel_id')
if message == 'node_announcement' and fieldname == 'ipv6':
return ('struct ipv6',16)
return FieldType('struct ipv6')
if message == 'node_announcement' and fieldname == 'alias':
return ('u8',1)
return FieldType('u8')
if fieldname.endswith('features'):
return ('u8',1)
if fieldname == 'addresses':
return ('u8', 1)
return FieldType('u8')
# We translate signatures and pubkeys.
if 'signature' in fieldname:
return ('struct signature',64)
# The remainder should be fixed sizes.
if sizestr == '33':
return ('struct pubkey',33)
if sizestr == '32':
return ('struct sha256',32)
if sizestr == '8':
return ('u64',8)
if sizestr == '4':
return ('u32',4)
if sizestr == '2':
return ('u16',2)
if sizestr == '1':
return ('u8',1)
return FieldType('struct signature')
# We whitelist specific things here, otherwise we'd treat everything
# as a u8 array.
if message == 'update_fail_htlc' and fieldname == 'reason':
return ('u8', 1)
return FieldType('u8')
if message == 'update_add_htlc' and fieldname == 'onion_routing_packet':
return ('u8', 1)
return FieldType('u8')
if message == 'node_announcement' and fieldname == 'alias':
return ('u8',1)
return FieldType('u8')
if message == 'error' and fieldname == 'data':
return ('u8',1)
return FieldType('u8')
if message == 'shutdown' and fieldname == 'scriptpubkey':
return ('u8',1)
return FieldType('u8')
if message == 'node_announcement' and fieldname == 'rgb_color':
return ('u8',1)
return FieldType('u8')
if message == 'node_announcement' and fieldname == 'addresses':
return FieldType('u8')
raise ValueError('Unknown size {} for {}'.format(sizestr,fieldname))
# The remainder should be fixed sizes.
if base_size == 33:
return FieldType('struct pubkey')
if base_size == 32:
return FieldType('struct sha256')
if base_size == 8:
return FieldType('u64')
if base_size == 4:
return FieldType('u32')
if base_size == 2:
return FieldType('u16')
if base_size == 1:
return FieldType('u8')
raise ValueError('Unknown size {} for {}'.format(base_size,fieldname))
class Message(object):
def __init__(self,name,enum,comments):
@ -116,7 +154,7 @@ class Message(object):
def checkLenField(self,field):
for f in self.fields:
if f.name == field.lenvar:
if f.typename != 'u16':
if f.fieldtype.name != 'u16':
raise ValueError('Field {} has non-u16 length variable {}'
.format(field.name, field.lenvar))
@ -151,11 +189,11 @@ class Message(object):
if f.is_padding():
continue
if f.is_array():
print(', {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='')
print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
elif f.is_variable_size():
print(', {} **{}'.format(f.typename, f.name), end='')
print(', {} **{}'.format(f.fieldtype.name, f.name), end='')
else:
print(', {} *{}'.format(f.typename, f.name), end='')
print(', {} *{}'.format(f.fieldtype.name, f.name), end='')
if is_header:
print(');')
@ -166,7 +204,7 @@ class Message(object):
for f in self.fields:
if f.is_len_var:
print('\t{} {};'.format(f.typename, f.name));
print('\t{} {};'.format(f.fieldtype.name, f.name));
print('\tconst u8 *cursor = p;\n'
'\tsize_t tmp_len;\n'
@ -180,9 +218,9 @@ class Message(object):
.format(self.enum.name))
for f in self.fields:
basetype=f.typename
if f.typename.startswith('struct '):
basetype=f.typename[7:]
basetype=f.fieldtype.name
if f.fieldtype.name.startswith('struct '):
basetype=f.fieldtype.name[7:]
for c in f.comments:
print('\t/*{} */'.format(c))
@ -197,7 +235,7 @@ class Message(object):
elif f.is_variable_size():
print("\t//2th case", f.name)
print('\t*{} = tal_arr(ctx, {}, {});'
.format(f.name, f.typename, f.lenvar))
.format(f.name, f.fieldtype.name, f.lenvar))
print('\tfromwire_{}_array(&cursor, plen, *{}, {});'
.format(basetype, f.name, f.lenvar))
elif f.is_assignable():
@ -225,11 +263,11 @@ class Message(object):
if f.is_padding():
continue
if f.is_array():
print(', const {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='')
print(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
elif f.is_assignable():
print(', {} {}'.format(f.typename, f.name), end='')
print(', {} {}'.format(f.fieldtype.name, f.name), end='')
else:
print(', const {} *{}'.format(f.typename, f.name), end='')
print(', const {} *{}'.format(f.fieldtype.name, f.name), end='')
if is_header:
print(');')
@ -242,9 +280,9 @@ class Message(object):
'\ttowire_u16(&p, {});'.format(self.enum.name))
for f in self.fields:
basetype=f.typename
if f.typename.startswith('struct '):
basetype=f.typename[7:]
basetype=f.fieldtype.name
if f.fieldtype.name.startswith('struct '):
basetype=f.fieldtype.name[7:]
for c in f.comments:
print('\t/*{} */'.format(c))
@ -311,10 +349,15 @@ for line in fileinput.input(args[2:]):
messages.append(Message(parts[0],Enumtype("WIRE_" + parts[0].upper(), int(parts[1],0)),comments))
comments=[]
else:
# eg commit_sig,0,channel-id,8
# eg commit_sig,0,channel-id,8 OR
# commit_sig,0,channel-id,8,u64
for m in messages:
if m.name == parts[0]:
m.addField(Field(parts[0], parts[2], parts[3], comments))
if len(parts) == 4:
m.addField(Field(parts[0], parts[2], parts[3], comments))
else:
m.addField(Field(parts[0], parts[2], parts[3], comments,
parts[4]))
break
comments=[]

Loading…
Cancel
Save