@ -8,31 +8,70 @@ import re
Enumtype = namedtuple ( ' Enumtype ' , [ ' name ' , ' value ' ] )
class FieldType ( object ) :
def __init__ ( self , name ) :
self . name = name
self . tsize = FieldType . _typesize ( name )
def is_assignable ( self ) :
return self . name == ' u8 ' or self . name == ' u16 ' or self . name == ' u32 ' or self . name == ' u64 '
# Returns typename and base size
@staticmethod
def _typesize ( typename ) :
if typename == ' pad ' :
return 1
elif typename == ' struct channel_id ' :
return 8
elif typename == ' struct ipv6 ' :
return 16
elif typename == ' struct signature ' :
return 64
elif typename == ' struct pubkey ' :
return 33
elif typename == ' struct sha256 ' :
return 32
elif typename == ' u64 ' :
return 8
elif typename == ' u32 ' :
return 4
elif typename == ' u16 ' :
return 2
elif typename == ' u8 ' :
return 1
else :
raise ValueError ( ' Unknown typename {} ' . format ( typename ) )
class Field ( object ) :
def __init__ ( self , message , name , size , comments ) :
def __init__ ( self , message , name , size , comments , typename = None ) :
self . message = message
self . comments = comments
self . name = name . replace ( ' - ' , ' _ ' )
self . is_len_var = False
( self . typename , self . basesize ) = Field . _guess_type ( message , self . name , size )
self . lenvar = None
# Size could be a literal number (eg. 33), or a field (eg 'len'), or
# a multiplier of a field (eg. num-htlc-timeouts*64).
try :
if int ( size ) % self . basesize != 0 :
raise ValueError ( ' Invalid size {} for {} . {} not a multiple of {} ' . format ( size , self . message , self . name , self . basesize ) )
self . num_elems = int ( int ( size ) / self . basesize )
base_size = int ( size )
except ValueError :
self . num_elems = 0
# If it's a multiplicitive expression, must end in basesize.
if ' * ' in size :
tail = ' * ' + str ( self . basesize )
if not size . endswith ( tail ) :
raise ValueError ( ' Invalid size {} for {} . {} not a multiple of {} ' . format ( size , self . message , self . name , self . basesize ) )
size = size [ : - len ( tail ) ]
base_size = int ( size . split ( ' * ' ) [ 1 ] )
self . lenvar = size . split ( ' * ' ) [ 0 ]
else :
if self . basesize != 1 :
raise ValueError ( ' Invalid size {} for {} . {} not expressed as a multiple of {} ' . format ( size , self . message , self . name , self . basesize ) )
base_size = 0
self . lenvar = size
self . lenvar = self . lenvar . replace ( ' - ' , ' _ ' )
self . lenvar = size . replace ( ' - ' , ' _ ' )
if typename is None :
self . fieldtype = Field . _guess_type ( message , self . name , base_size )
else :
self . fieldtype = FieldType ( typename )
if base_size % self . fieldtype . tsize != 0 :
raise ValueError ( ' Invalid size {} for {} . {} not a multiple of {} ' . format ( base_size , self . message , self . name , self . fieldtype . tsize ) )
self . num_elems = int ( base_size / self . fieldtype . tsize )
def is_padding ( self ) :
return self . name . startswith ( ' pad ' )
@ -42,68 +81,67 @@ class Field(object):
return self . num_elems > 1 or self . is_padding ( )
def is_variable_size ( self ) :
return self . num_elems == 0
return self . lenvar is not None
def is_assignable ( self ) :
if self . is_array ( ) or self . is_variable_size ( ) :
return False
return self . typename == ' u8 ' or self . typename == ' u16 ' or self . typename == ' u32 ' or self . typename == ' u64 '
return self . fieldtype . is_assignable ( )
# Returns typename and base siz e
# Returns FieldTyp e
@staticmethod
def _guess_type ( message , fieldname , sizestr ) :
def _guess_type ( message , fieldname , base_ size) :
if fieldname . startswith ( ' pad ' ) :
return ( ' pad ' , 1 )
return FieldType ( ' pad ' )
if fieldname . endswith ( ' channel_id ' ) :
return ( ' struct channel_id ' , 8 )
return FieldType ( ' struct channel_id ' )
if message == ' node_announcement ' and fieldname == ' ipv6 ' :
return ( ' struct ipv6 ' , 16 )
return FieldType ( ' struct ipv6 ' )
if message == ' node_announcement ' and fieldname == ' alias ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
if fieldname . endswith ( ' features ' ) :
return ( ' u8 ' , 1 )
if fieldname == ' addresses ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
# We translate signatures and pubkeys.
if ' signature ' in fieldname :
return ( ' struct signature ' , 64 )
# The remainder should be fixed sizes.
if sizestr == ' 33 ' :
return ( ' struct pubkey ' , 33 )
if sizestr == ' 32 ' :
return ( ' struct sha256 ' , 32 )
if sizestr == ' 8 ' :
return ( ' u64 ' , 8 )
if sizestr == ' 4 ' :
return ( ' u32 ' , 4 )
if sizestr == ' 2 ' :
return ( ' u16 ' , 2 )
if sizestr == ' 1 ' :
return ( ' u8 ' , 1 )
return FieldType ( ' struct signature ' )
# We whitelist specific things here, otherwise we'd treat everything
# as a u8 array.
if message == ' update_fail_htlc ' and fieldname == ' reason ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
if message == ' update_add_htlc ' and fieldname == ' onion_routing_packet ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
if message == ' node_announcement ' and fieldname == ' alias ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
if message == ' error ' and fieldname == ' data ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
if message == ' shutdown ' and fieldname == ' scriptpubkey ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
if message == ' node_announcement ' and fieldname == ' rgb_color ' :
return ( ' u8 ' , 1 )
return FieldType ( ' u8 ' )
if message == ' node_announcement ' and fieldname == ' addresses ' :
return FieldType ( ' u8 ' )
raise ValueError ( ' Unknown size {} for {} ' . format ( sizestr , fieldname ) )
# The remainder should be fixed sizes.
if base_size == 33 :
return FieldType ( ' struct pubkey ' )
if base_size == 32 :
return FieldType ( ' struct sha256 ' )
if base_size == 8 :
return FieldType ( ' u64 ' )
if base_size == 4 :
return FieldType ( ' u32 ' )
if base_size == 2 :
return FieldType ( ' u16 ' )
if base_size == 1 :
return FieldType ( ' u8 ' )
raise ValueError ( ' Unknown size {} for {} ' . format ( base_size , fieldname ) )
class Message ( object ) :
def __init__ ( self , name , enum , comments ) :
@ -116,7 +154,7 @@ class Message(object):
def checkLenField ( self , field ) :
for f in self . fields :
if f . name == field . lenvar :
if f . typename != ' u16 ' :
if f . field type. name != ' u16 ' :
raise ValueError ( ' Field {} has non-u16 length variable {} '
. format ( field . name , field . lenvar ) )
@ -151,11 +189,11 @@ class Message(object):
if f . is_padding ( ) :
continue
if f . is_array ( ) :
print ( ' , {} {} [ {} ] ' . format ( f . typename , f . name , f . num_elems ) , end = ' ' )
print ( ' , {} {} [ {} ] ' . format ( f . field type. name , f . name , f . num_elems ) , end = ' ' )
elif f . is_variable_size ( ) :
print ( ' , {} ** {} ' . format ( f . typename , f . name ) , end = ' ' )
print ( ' , {} ** {} ' . format ( f . field type. name , f . name ) , end = ' ' )
else :
print ( ' , {} * {} ' . format ( f . typename , f . name ) , end = ' ' )
print ( ' , {} * {} ' . format ( f . field type. name , f . name ) , end = ' ' )
if is_header :
print ( ' ); ' )
@ -166,7 +204,7 @@ class Message(object):
for f in self . fields :
if f . is_len_var :
print ( ' \t {} {} ; ' . format ( f . typename , f . name ) ) ;
print ( ' \t {} {} ; ' . format ( f . field type. name , f . name ) ) ;
print ( ' \t const u8 *cursor = p; \n '
' \t size_t tmp_len; \n '
@ -180,9 +218,9 @@ class Message(object):
. format ( self . enum . name ) )
for f in self . fields :
basetype = f . typename
if f . typename . startswith ( ' struct ' ) :
basetype = f . typename [ 7 : ]
basetype = f . field type. name
if f . field type. name . startswith ( ' struct ' ) :
basetype = f . field type. name [ 7 : ]
for c in f . comments :
print ( ' \t /* {} */ ' . format ( c ) )
@ -197,7 +235,7 @@ class Message(object):
elif f . is_variable_size ( ) :
print ( " \t //2th case " , f . name )
print ( ' \t * {} = tal_arr(ctx, {} , {} ); '
. format ( f . name , f . typename , f . lenvar ) )
. format ( f . name , f . field type. name , f . lenvar ) )
print ( ' \t fromwire_ {} _array(&cursor, plen, * {} , {} ); '
. format ( basetype , f . name , f . lenvar ) )
elif f . is_assignable ( ) :
@ -225,11 +263,11 @@ class Message(object):
if f . is_padding ( ) :
continue
if f . is_array ( ) :
print ( ' , const {} {} [ {} ] ' . format ( f . typename , f . name , f . num_elems ) , end = ' ' )
print ( ' , const {} {} [ {} ] ' . format ( f . field type. name , f . name , f . num_elems ) , end = ' ' )
elif f . is_assignable ( ) :
print ( ' , {} {} ' . format ( f . typename , f . name ) , end = ' ' )
print ( ' , {} {} ' . format ( f . field type. name , f . name ) , end = ' ' )
else :
print ( ' , const {} * {} ' . format ( f . typename , f . name ) , end = ' ' )
print ( ' , const {} * {} ' . format ( f . field type. name , f . name ) , end = ' ' )
if is_header :
print ( ' ); ' )
@ -242,9 +280,9 @@ class Message(object):
' \t towire_u16(&p, {} ); ' . format ( self . enum . name ) )
for f in self . fields :
basetype = f . typename
if f . typename . startswith ( ' struct ' ) :
basetype = f . typename [ 7 : ]
basetype = f . field type. name
if f . field type. name . startswith ( ' struct ' ) :
basetype = f . field type. name [ 7 : ]
for c in f . comments :
print ( ' \t /* {} */ ' . format ( c ) )
@ -311,10 +349,15 @@ for line in fileinput.input(args[2:]):
messages . append ( Message ( parts [ 0 ] , Enumtype ( " WIRE_ " + parts [ 0 ] . upper ( ) , int ( parts [ 1 ] , 0 ) ) , comments ) )
comments = [ ]
else :
# eg commit_sig,0,channel-id,8
# eg commit_sig,0,channel-id,8 OR
# commit_sig,0,channel-id,8,u64
for m in messages :
if m . name == parts [ 0 ] :
m . addField ( Field ( parts [ 0 ] , parts [ 2 ] , parts [ 3 ] , comments ) )
if len ( parts ) == 4 :
m . addField ( Field ( parts [ 0 ] , parts [ 2 ] , parts [ 3 ] , comments ) )
else :
m . addField ( Field ( parts [ 0 ] , parts [ 2 ] , parts [ 3 ] , comments ,
parts [ 4 ] ) )
break
comments = [ ]