@ -1,10 +1,10 @@
import struct
import struct
from io import BufferedIOBase , BytesIO
from io import BufferedIOBase , BytesIO
from . fundamental_types import fundamental_types , BigSizeType , split_field , try_unpack , FieldType
from . fundamental_types import fundamental_types , BigSizeType , split_field , try_unpack , FieldType , IntegerType
from . array_types import (
from . array_types import (
SizedArrayType , DynamicArrayType , LengthFieldType , EllipsisArrayType
SizedArrayType , DynamicArrayType , LengthFieldType , EllipsisArrayType
)
)
from typing import Dict , List , Optional , Tuple , Any , Union , cast
from typing import Dict , List , Optional , Tuple , Any , Union , Callable , cast
class MessageNamespace ( object ) :
class MessageNamespace ( object ) :
@ -12,7 +12,7 @@ class MessageNamespace(object):
domain , such as within a given BOLT """
domain , such as within a given BOLT """
def __init__ ( self , csv_lines : List [ str ] = [ ] ) :
def __init__ ( self , csv_lines : List [ str ] = [ ] ) :
self . subtypes : Dict [ str , SubtypeType ] = { }
self . subtypes : Dict [ str , SubtypeType ] = { }
self . fundamentaltypes : Dict [ str , Subtype Type] = { }
self . fundamentaltypes : Dict [ str , Field Type] = { }
self . tlvtypes : Dict [ str , TlvStreamType ] = { }
self . tlvtypes : Dict [ str , TlvStreamType ] = { }
self . messagetypes : Dict [ str , MessageType ] = { }
self . messagetypes : Dict [ str , MessageType ] = { }
@ -28,27 +28,35 @@ domain, such as within a given BOLT"""
for v in other . subtypes . values ( ) :
for v in other . subtypes . values ( ) :
ret . add_subtype ( v )
ret . add_subtype ( v )
ret . tlvtypes = self . tlvtypes . copy ( )
ret . tlvtypes = self . tlvtypes . copy ( )
for v in other . tlvtypes . values ( ) :
for tl v in other . tlvtypes . values ( ) :
ret . add_tlvtype ( v )
ret . add_tlvtype ( tl v)
ret . messagetypes = self . messagetypes . copy ( )
ret . messagetypes = self . messagetypes . copy ( )
for v in other . messagetypes . values ( ) :
for v in other . messagetypes . values ( ) :
ret . add_messagetype ( v )
ret . add_messagetype ( v )
return ret
return ret
def _check_unique ( self , name : str ) - > None :
""" Raise an exception if name already used """
funtype = self . get_fundamentaltype ( name )
if funtype :
raise ValueError ( ' Already have {} ' . format ( funtype ) )
subtype = self . get_subtype ( name )
if subtype :
raise ValueError ( ' Already have {} ' . format ( subtype ) )
tlvtype = self . get_tlvtype ( name )
if tlvtype :
raise ValueError ( ' Already have {} ' . format ( tlvtype ) )
def add_subtype ( self , t : ' SubtypeType ' ) - > None :
def add_subtype ( self , t : ' SubtypeType ' ) - > None :
prev = self . get_type ( t . name )
self . _check_unique ( t . name )
if prev :
raise ValueError ( ' Already have {} ' . format ( prev ) )
self . subtypes [ t . name ] = t
self . subtypes [ t . name ] = t
def add_fundamentaltype ( self , t : ' SubtypeType ' ) - > None :
def add_fundamentaltype ( self , t : FieldType ) - > None :
assert not self . get_type ( t . name )
self . _check_uniqu e( t . name )
self . fundamentaltypes [ t . name ] = t
self . fundamentaltypes [ t . name ] = t
def add_tlvtype ( self , t : ' TlvStreamType ' ) - > None :
def add_tlvtype ( self , t : ' TlvStreamType ' ) - > None :
prev = self . get_type ( t . name )
self . _check_unique ( t . name )
if prev :
raise ValueError ( ' Already have {} ' . format ( prev ) )
self . tlvtypes [ t . name ] = t
self . tlvtypes [ t . name ] = t
def add_messagetype ( self , m : ' MessageType ' ) - > None :
def add_messagetype ( self , m : ' MessageType ' ) - > None :
@ -70,7 +78,7 @@ domain, such as within a given BOLT"""
return m
return m
return None
return None
def get_fundamentaltype ( self , name : str ) - > Optional [ ' SubtypeType ' ] :
def get_fundamentaltype ( self , name : str ) - > Optional [ FieldType ] :
if name in self . fundamentaltypes :
if name in self . fundamentaltypes :
return self . fundamentaltypes [ name ]
return self . fundamentaltypes [ name ]
return None
return None
@ -85,14 +93,6 @@ domain, such as within a given BOLT"""
return self . tlvtypes [ name ]
return self . tlvtypes [ name ]
return None
return None
def get_type ( self , name : str ) - > Optional [ ' SubtypeType ' ] :
t = self . get_fundamentaltype ( name )
if t is None :
t = self . get_subtype ( name )
if t is None :
t = self . get_tlvtype ( name )
return t
def load_csv ( self , lines : List [ str ] ) - > None :
def load_csv ( self , lines : List [ str ] ) - > None :
""" Load a series of comma-separate-value lines into the namespace """
""" Load a series of comma-separate-value lines into the namespace """
vals : Dict [ str , List [ List [ str ] ] ] = { ' msgtype ' : [ ] ,
vals : Dict [ str , List [ List [ str ] ] ] = { ' msgtype ' : [ ] ,
@ -152,23 +152,22 @@ class MessageTypeField(object):
return self . full_name
return self . full_name
class SubtypeType ( object ) :
class SubtypeType ( FieldType ) :
""" This defines a ' subtype ' in BOLT-speak. It consists of fields of
""" This defines a ' subtype ' in BOLT-speak. It consists of fields of
other types . Since ' msgtype ' and ' tlvtype ' are almost identical , they
other types . Since ' msgtype ' is almost identical , it inherits from this too .
inherit from this too .
"""
"""
def __init__ ( self , name : str ) :
def __init__ ( self , name : str ) :
self . name = name
super ( ) . __init__ ( name )
self . fields : List [ FieldType ] = [ ]
self . fields : List [ MessageType Field] = [ ]
def find_field ( self , fieldname : str ) :
def find_field ( self , fieldname : str ) - > Optional [ MessageTypeField ] :
for f in self . fields :
for f in self . fields :
if f . name == fieldname :
if f . name == fieldname :
return f
return f
return None
return None
def add_field ( self , field : FieldType ) :
def add_field ( self , field : MessageTypeField ) - > None :
if self . find_field ( field . name ) :
if self . find_field ( field . name ) :
raise ValueError ( " {} : duplicate field {} " . format ( self , field ) )
raise ValueError ( " {} : duplicate field {} " . format ( self , field ) )
self . fields . append ( field )
self . fields . append ( field )
@ -192,12 +191,16 @@ inherit from this too.
. format ( parts ) )
. format ( parts ) )
return SubtypeType ( parts [ 0 ] )
return SubtypeType ( parts [ 0 ] )
def _field_from_csv ( self , namespace : MessageNamespace , parts : List [ str ] , ellipsisok = False , option : str = None ) - > MessageTypeField :
def _field_from_csv ( self , namespace : MessageNamespace , parts : List [ str ] , option : str = None ) - > MessageTypeField :
""" Takes msgdata/subtypedata after first two fields
""" Takes msgdata/subtypedata after first two fields
e . g . [ . . . ] timestamp_node_id_1 , u32 ,
e . g . [ . . . ] timestamp_node_id_1 , u32 ,
"""
"""
basetype = namespace . get_type ( parts [ 1 ] )
basetype = namespace . get_fundamentaltype ( parts [ 1 ] )
if basetype is None :
basetype = namespace . get_subtype ( parts [ 1 ] )
if basetype is None :
basetype = namespace . get_tlvtype ( parts [ 1 ] )
if basetype is None :
if basetype is None :
raise ValueError ( ' Unknown type {} ' . format ( parts [ 1 ] ) )
raise ValueError ( ' Unknown type {} ' . format ( parts [ 1 ] ) )
@ -206,7 +209,8 @@ inherit from this too.
lenfield = self . find_field ( parts [ 2 ] )
lenfield = self . find_field ( parts [ 2 ] )
if lenfield is not None :
if lenfield is not None :
# If we didn't know that field was a length, we do now!
# If we didn't know that field was a length, we do now!
if type ( lenfield . fieldtype ) is not LengthFieldType :
if not isinstance ( lenfield . fieldtype , LengthFieldType ) :
assert isinstance ( lenfield . fieldtype , IntegerType )
lenfield . fieldtype = LengthFieldType ( lenfield . fieldtype )
lenfield . fieldtype = LengthFieldType ( lenfield . fieldtype )
field = MessageTypeField ( self . name , parts [ 0 ] ,
field = MessageTypeField ( self . name , parts [ 0 ] ,
DynamicArrayType ( self ,
DynamicArrayType ( self ,
@ -215,7 +219,9 @@ inherit from this too.
lenfield ) ,
lenfield ) ,
option )
option )
lenfield . fieldtype . add_length_for ( field )
lenfield . fieldtype . add_length_for ( field )
elif ellipsisok and parts [ 2 ] == ' ... ' :
elif parts [ 2 ] == ' ... ' :
# ... is only valid for a TLV.
assert isinstance ( self , TlvMessageType )
field = MessageTypeField ( self . name , parts [ 0 ] ,
field = MessageTypeField ( self . name , parts [ 0 ] ,
EllipsisArrayType ( self ,
EllipsisArrayType ( self ,
parts [ 0 ] , basetype ) ,
parts [ 0 ] , basetype ) ,
@ -264,8 +270,10 @@ inherit from this too.
raise ValueError ( " Unknown fields specified: {} " . format ( unknown ) )
raise ValueError ( " Unknown fields specified: {} " . format ( unknown ) )
for f in defined . difference ( have ) :
for f in defined . difference ( have ) :
if not f . fieldtype . is_optional ( ) :
field = self . find_field ( f )
raise ValueError ( " Missing value for {} " . format ( f ) )
assert field
if not field . fieldtype . is_optional ( ) :
raise ValueError ( " Missing value for {} " . format ( field ) )
def val_to_str ( self , v : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > str :
def val_to_str ( self , v : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > str :
self . _raise_if_badvals ( v )
self . _raise_if_badvals ( v )
@ -273,6 +281,7 @@ inherit from this too.
sep = ' '
sep = ' '
for fname , val in v . items ( ) :
for fname , val in v . items ( ) :
field = self . find_field ( fname )
field = self . find_field ( fname )
assert field
s + = sep + fname + ' = ' + field . fieldtype . val_to_str ( val , otherfields )
s + = sep + fname + ' = ' + field . fieldtype . val_to_str ( val , otherfields )
sep = ' , '
sep = ' , '
@ -281,16 +290,19 @@ inherit from this too.
def val_to_py ( self , val : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > Dict [ str , Any ] :
def val_to_py ( self , val : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > Dict [ str , Any ] :
ret : Dict [ str , Any ] = { }
ret : Dict [ str , Any ] = { }
for k , v in val . items ( ) :
for k , v in val . items ( ) :
ret [ k ] = self . find_field ( k ) . fieldtype . val_to_py ( v , val )
field = self . find_field ( k )
assert field
ret [ k ] = field . fieldtype . val_to_py ( v , val )
return ret
return ret
def write ( self , io_out : BufferedIOBase , v : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > None :
def write ( self , io_out : BufferedIOBase , v : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > None :
self . _raise_if_badvals ( v )
self . _raise_if_badvals ( v )
for fname , val in v . items ( ) :
for fname , val in v . items ( ) :
field = self . find_field ( fname )
field = self . find_field ( fname )
assert field
field . fieldtype . write ( io_out , val , otherfields )
field . fieldtype . write ( io_out , val , otherfields )
def read ( self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ] ) - > Dict [ str , Any ] :
def read ( self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ] ) - > Optional [ Dict [ str , Any ] ] :
vals = { }
vals = { }
for field in self . fields :
for field in self . fields :
val = field . fieldtype . read ( io_in , otherfields )
val = field . fieldtype . read ( io_in , otherfields )
@ -383,25 +395,46 @@ class MessageType(SubtypeType):
messagetype . add_field ( field )
messagetype . add_field ( field )
class TlvStreamType ( SubtypeType ) :
class TlvMessageType ( MessageType ) :
""" A TlvStreamType is just a Subtype, but its fields are
""" A ' tlvtype ' in BOLT-speak """
TlvMessageTypes . In the CSV format these are created implicitly , when
a tlvtype line ( which defines a TlvMessageType within the TlvType ,
def __init__ ( self , name : str , value : str ) :
confusingly ) refers to them .
super ( ) . __init__ ( name , value )
def __str__ ( self ) :
return " tlvmsgtype- {} " . format ( self . name )
class TlvStreamType ( FieldType ) :
""" A TlvStreamType ' s fields are TlvMessageTypes. In the CSV format
these are created implicitly , when a tlvtype line ( which defines a
TlvMessageType within the TlvType , confusingly ) refers to them .
"""
"""
def __init__ ( self , name ) :
def __init__ ( self , name ) :
super ( ) . __init__ ( name )
super ( ) . __init__ ( name )
self . fields : List [ TlvMessageType ] = [ ]
def __str__ ( self ) :
def __str__ ( self ) :
return " tlvstreamtype- {} " . format ( self . name )
return " tlvstreamtype- {} " . format ( self . name )
def find_field_by_number ( self , num : int ) - > Optional [ ' TlvMessageType ' ] :
def find_field ( self , fieldname : str ) - > Optional [ TlvMessageType ] :
for f in self . fields :
if f . name == fieldname :
return f
return None
def find_field_by_number ( self , num : int ) - > Optional [ TlvMessageType ] :
for f in self . fields :
for f in self . fields :
if f . number == num :
if f . number == num :
return f
return f
return None
return None
def add_field ( self , field : TlvMessageType ) - > None :
if self . find_field ( field . name ) :
raise ValueError ( " {} : duplicate field {} " . format ( self , field ) )
self . fields . append ( field )
def is_optional ( self ) - > bool :
def is_optional ( self ) - > bool :
""" You can omit a tlvstream= altogether """
""" You can omit a tlvstream= altogether """
return True
return True
@ -438,7 +471,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
raise ValueError ( " Unknown tlv field {} . {} "
raise ValueError ( " Unknown tlv field {} . {} "
. format ( tlvstream , parts [ 1 ] ) )
. format ( tlvstream , parts [ 1 ] ) )
subfield = field . _field_from_csv ( namespace , parts [ 2 : ] , ellipsisok = True )
subfield = field . _field_from_csv ( namespace , parts [ 2 : ] )
field . add_field ( subfield )
field . add_field ( subfield )
def val_from_str ( self , s : str ) - > Tuple [ Dict [ str , Any ] , str ] :
def val_from_str ( self , s : str ) - > Tuple [ Dict [ str , Any ] , str ] :
@ -480,7 +513,9 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
def val_to_py ( self , val : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > Dict [ str , Any ] :
def val_to_py ( self , val : Dict [ str , Any ] , otherfields : Dict [ str , Any ] ) - > Dict [ str , Any ] :
ret : Dict [ str , Any ] = { }
ret : Dict [ str , Any ] = { }
for k , v in val . items ( ) :
for k , v in val . items ( ) :
ret [ k ] = self . find_field ( k ) . val_to_py ( v , val )
field = self . find_field ( k )
assert field
ret [ k ] = field . val_to_py ( v , val )
return ret
return ret
def write ( self , io_out : BufferedIOBase , v : Optional [ Dict [ str , Any ] ] , otherfields : Dict [ str , Any ] ) - > None :
def write ( self , io_out : BufferedIOBase , v : Optional [ Dict [ str , Any ] ] , otherfields : Dict [ str , Any ] ) - > None :
@ -490,14 +525,16 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# ascending order as TLV spec requires.
# ascending order as TLV spec requires.
def write_raw_val ( iobuf , val , otherfields : Dict [ str , Any ] ) :
def write_raw_val ( iobuf : BufferedIOBase , val : Any , otherfields : Dict [ str , Any ] ) - > None :
iobuf . write ( val )
iobuf . write ( val )
def get_value ( tup ) :
def get_value ( tup ) :
""" Get value from num, fun, val tuple """
""" Get value from num, fun, val tuple """
return tup [ 0 ]
return tup [ 0 ]
ordered = [ ]
ordered : List [ Tuple [ int ,
Callable [ [ BufferedIOBase , Any , Dict [ str , Any ] ] , None ] ,
Any ] ] = [ ]
for fieldname in v :
for fieldname in v :
f = self . find_field ( fieldname )
f = self . find_field ( fieldname )
if f is None :
if f is None :
@ -510,13 +547,13 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
for typenum , writefunc , val in ordered :
for typenum , writefunc , val in ordered :
buf = BytesIO ( )
buf = BytesIO ( )
writefunc ( buf , val , otherfields )
writefunc ( cast ( BufferedIOBase , buf ) , val , otherfields )
BigSizeType . write ( io_out , typenum )
BigSizeType . write ( io_out , typenum )
BigSizeType . write ( io_out , len ( buf . getvalue ( ) ) )
BigSizeType . write ( io_out , len ( buf . getvalue ( ) ) )
io_out . write ( buf . getvalue ( ) )
io_out . write ( buf . getvalue ( ) )
def read ( self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ] ) - > Dict [ str , Any ] :
def read ( self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ] ) - > Dict [ Union [ str , int ] , Any ] :
vals : Dict [ str , Any ] = { }
vals : Dict [ Union [ str , int ] , Any ] = { }
while True :
while True :
tlv_type = BigSizeType . read ( io_in )
tlv_type = BigSizeType . read ( io_in )
@ -543,16 +580,6 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
return " {} = {} " . format ( name , self . val_to_str ( v , { } ) )
return " {} = {} " . format ( name , self . val_to_str ( v , { } ) )
class TlvMessageType ( MessageType ) :
""" A ' tlvtype ' in BOLT-speak """
def __init__ ( self , name : str , value : str ) :
super ( ) . __init__ ( name , value )
def __str__ ( self ) :
return " tlvmsgtype- {} " . format ( self . name )
class Message ( object ) :
class Message ( object ) :
""" A particular message instance """
""" A particular message instance """
def __init__ ( self , messagetype : MessageType , * * kwargs ) :
def __init__ ( self , messagetype : MessageType , * * kwargs ) :
@ -679,7 +706,8 @@ Must not have missing fields.
""" Convert to a Python native object: dicts, lists, strings, ints """
""" Convert to a Python native object: dicts, lists, strings, ints """
ret : Dict [ str , Union [ Dict [ str , Any ] , List [ Any ] , str , int ] ] = { }
ret : Dict [ str , Union [ Dict [ str , Any ] , List [ Any ] , str , int ] ] = { }
for f , v in self . fields . items ( ) :
for f , v in self . fields . items ( ) :
fieldtype = self . messagetype . find_field ( f ) . fieldtype
field = self . messagetype . find_field ( f )
ret [ f ] = fieldtype . val_to_py ( v , self . fields )
assert field
ret [ f ] = field . fieldtype . val_to_py ( v , self . fields )
return ret
return ret