@ -2,6 +2,7 @@ import os
import csv
import io
from typing import Callable , Tuple , Any , Dict , List , Sequence , Union , Optional
from collections import OrderedDict
class MalformedMsg ( Exception ) :
@ -16,12 +17,56 @@ class UnexpectedEndOfStream(MalformedMsg):
pass
def _assert_can_read_at_least_n_bytes ( fd : io . BytesIO , n : int ) - > None :
class FieldEncodingNotMinimal ( MalformedMsg ) :
pass
class UnknownMandatoryTLVRecordType ( MalformedMsg ) :
pass
def _num_remaining_bytes_to_read ( fd : io . BytesIO ) - > int :
cur_pos = fd . tell ( )
end_pos = fd . seek ( 0 , io . SEEK_END )
fd . seek ( cur_pos )
if end_pos - cur_pos < n :
raise UnexpectedEndOfStream ( f " cur_pos= { cur_pos } . end_pos= { end_pos } . wants to read: { n } " )
return end_pos - cur_pos
def _assert_can_read_at_least_n_bytes ( fd : io . BytesIO , n : int ) - > None :
nremaining = _num_remaining_bytes_to_read ( fd )
if nremaining < n :
raise UnexpectedEndOfStream ( f " wants to read { n } bytes but only { nremaining } bytes left " )
def bigsize_from_int ( i : int ) - > bytes :
assert i > = 0 , i
if i < 0xfd :
return int . to_bytes ( i , length = 1 , byteorder = " big " , signed = False )
elif i < 0x1_0000 :
return b " \xfd " + int . to_bytes ( i , length = 2 , byteorder = " big " , signed = False )
elif i < 0x1_0000_0000 :
return b " \xfe " + int . to_bytes ( i , length = 4 , byteorder = " big " , signed = False )
else :
return b " \xff " + int . to_bytes ( i , length = 8 , byteorder = " big " , signed = False )
def read_int_from_bigsize ( fd : io . BytesIO ) - > Optional [ int ] :
try :
first = fd . read ( 1 ) [ 0 ]
except IndexError :
return None # end of file
if first < 0xfd :
return first
elif first == 0xfd :
_assert_can_read_at_least_n_bytes ( fd , 2 )
return int . from_bytes ( fd . read ( 2 ) , byteorder = " big " , signed = False )
elif first == 0xfe :
_assert_can_read_at_least_n_bytes ( fd , 4 )
return int . from_bytes ( fd . read ( 4 ) , byteorder = " big " , signed = False )
elif first == 0xff :
_assert_can_read_at_least_n_bytes ( fd , 8 )
return int . from_bytes ( fd . read ( 8 ) , byteorder = " big " , signed = False )
raise Exception ( )
def _read_field ( * , fd : io . BytesIO , field_type : str , count : int ) - > Union [ bytes , int ] :
@ -32,22 +77,36 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes,
type_len = None
if field_type == ' byte ' :
type_len = 1
elif field_type == ' u16 ' :
type_len = 2
elif field_type in ( ' u16 ' , ' u32 ' , ' u64 ' ) :
if field_type == ' u16 ' :
type_len = 2
elif field_type == ' u32 ' :
type_len = 4
else :
assert field_type == ' u64 '
type_len = 8
assert count == 1 , count
_assert_can_read_at_least_n_bytes ( fd , type_len )
return int . from_bytes ( fd . read ( type_len ) , byteorder = " big " , signed = False )
elif field_type == ' u32 ' :
type_len = 4
elif field_type in ( ' tu16 ' , ' tu32 ' , ' tu64 ' ) :
if field_type == ' tu16 ' :
type_len = 2
elif field_type == ' tu32 ' :
type_len = 4
else :
assert field_type == ' tu64 '
type_len = 8
assert count == 1 , count
_assert_can_read_at_least_n_bytes ( fd , type_len )
return int . from_bytes ( fd . read ( type_len ) , byteorder = " big " , signed = False )
elif field_type == ' u64 ' :
type_len = 8
raw = fd . read ( type_len )
if len ( raw ) > 0 and raw [ 0 ] == 0x00 :
raise FieldEncodingNotMinimal ( )
return int . from_bytes ( raw , byteorder = " big " , signed = False )
elif field_type == ' varint ' :
assert count == 1 , count
_assert_can_read_at_least_n_bytes ( fd , type_len )
return int . from_bytes ( fd . read ( type_len ) , byteorder = " big " , signed = False )
# TODO tu16/tu32/tu64
val = read_int_from_bigsize ( fd )
if val is None :
raise UnexpectedEndOfStream ( )
return val
elif field_type == ' chain_hash ' :
type_len = 32
elif field_type == ' channel_id ' :
@ -82,7 +141,35 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
type_len = 4
elif field_type == ' u64 ' :
type_len = 8
# TODO tu16/tu32/tu64
elif field_type in ( ' tu16 ' , ' tu32 ' , ' tu64 ' ) :
if field_type == ' tu16 ' :
type_len = 2
elif field_type == ' tu32 ' :
type_len = 4
else :
assert field_type == ' tu64 '
type_len = 8
assert count == 1 , count
if isinstance ( value , int ) :
value = int . to_bytes ( value , length = type_len , byteorder = " big " , signed = False )
if not isinstance ( value , ( bytes , bytearray ) ) :
raise Exception ( f " can only write bytes into fd. got: { value !r} " )
while len ( value ) > 0 and value [ 0 ] == 0x00 :
value = value [ 1 : ]
nbytes_written = fd . write ( value )
if nbytes_written != len ( value ) :
raise Exception ( f " tried to write { len ( value ) } bytes, but only wrote { nbytes_written } !? " )
return
elif field_type == ' varint ' :
assert count == 1 , count
if isinstance ( value , int ) :
value = bigsize_from_int ( value )
if not isinstance ( value , ( bytes , bytearray ) ) :
raise Exception ( f " can only write bytes into fd. got: { value !r} " )
nbytes_written = fd . write ( value )
if nbytes_written != len ( value ) :
raise Exception ( f " tried to write { len ( value ) } bytes, but only wrote { nbytes_written } !? " )
return
elif field_type == ' chain_hash ' :
type_len = 32
elif field_type == ' channel_id ' :
@ -109,16 +196,55 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
raise Exception ( f " tried to write { len ( value ) } bytes, but only wrote { nbytes_written } !? " )
def _read_tlv_record ( * , fd : io . BytesIO ) - > Tuple [ int , bytes ] :
if not fd : raise Exception ( )
tlv_type = _read_field ( fd = fd , field_type = " varint " , count = 1 )
tlv_len = _read_field ( fd = fd , field_type = " varint " , count = 1 )
tlv_val = _read_field ( fd = fd , field_type = " byte " , count = tlv_len )
return tlv_type , tlv_val
def _write_tlv_record ( * , fd : io . BytesIO , tlv_type : int , tlv_val : bytes ) - > None :
if not fd : raise Exception ( )
tlv_len = len ( tlv_val )
_write_field ( fd = fd , field_type = " varint " , count = 1 , value = tlv_type )
_write_field ( fd = fd , field_type = " varint " , count = 1 , value = tlv_len )
_write_field ( fd = fd , field_type = " byte " , count = tlv_len , value = tlv_val )
def _resolve_field_count ( field_count_str : str , * , vars_dict : dict ) - > int :
if field_count_str == " " :
field_count = 1
elif field_count_str == " ... " :
raise NotImplementedError ( ) # TODO...
else :
try :
field_count = int ( field_count_str )
except ValueError :
field_count = vars_dict [ field_count_str ]
if isinstance ( field_count , ( bytes , bytearray ) ) :
field_count = int . from_bytes ( field_count , byteorder = " big " )
assert isinstance ( field_count , int )
return field_count
class LNSerializer :
def __init__ ( self ) :
# TODO msg_type could be 'int' everywhere...
self . msg_scheme_from_type = { } # type: Dict[bytes, List[Sequence[str]]]
self . msg_type_from_name = { } # type: Dict[str, bytes]
self . in_tlv_stream_get_tlv_record_scheme_from_type = { } # type: Dict[str, Dict[int, List[Sequence[str]]]]
self . in_tlv_stream_get_record_type_from_name = { } # type: Dict[str, Dict[str, int]]
self . in_tlv_stream_get_record_name_from_type = { } # type: Dict[str, Dict[int, str]]
path = os . path . join ( os . path . dirname ( __file__ ) , " lnwire " , " peer_wire.csv " )
with open ( path , newline = ' ' ) as f :
csvreader = csv . reader ( f )
for row in csvreader :
#print(f">>> {row!r}")
if row [ 0 ] == " msgtype " :
# msgtype,<msgname>,<value>[,<option>]
msg_type_name = row [ 1 ]
msg_type_int = int ( row [ 2 ] )
msg_type_bytes = msg_type_int . to_bytes ( 2 , ' big ' )
@ -128,11 +254,106 @@ class LNSerializer:
self . msg_scheme_from_type [ msg_type_bytes ] = [ tuple ( row ) ]
self . msg_type_from_name [ msg_type_name ] = msg_type_bytes
elif row [ 0 ] == " msgdata " :
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
assert msg_type_name == row [ 1 ]
self . msg_scheme_from_type [ msg_type_bytes ] . append ( tuple ( row ) )
elif row [ 0 ] == " tlvtype " :
# tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
tlv_stream_name = row [ 1 ]
tlv_record_name = row [ 2 ]
tlv_record_type = int ( row [ 3 ] )
row [ 3 ] = tlv_record_type
if tlv_stream_name not in self . in_tlv_stream_get_tlv_record_scheme_from_type :
self . in_tlv_stream_get_tlv_record_scheme_from_type [ tlv_stream_name ] = OrderedDict ( )
self . in_tlv_stream_get_record_type_from_name [ tlv_stream_name ] = { }
self . in_tlv_stream_get_record_name_from_type [ tlv_stream_name ] = { }
assert tlv_record_type not in self . in_tlv_stream_get_tlv_record_scheme_from_type [ tlv_stream_name ] , f " type collision? for { tlv_stream_name } / { tlv_record_name } "
assert tlv_record_name not in self . in_tlv_stream_get_record_type_from_name [ tlv_stream_name ] , f " type collision? for { tlv_stream_name } / { tlv_record_name } "
assert tlv_record_type not in self . in_tlv_stream_get_record_type_from_name [ tlv_stream_name ] , f " type collision? for { tlv_stream_name } / { tlv_record_name } "
self . in_tlv_stream_get_tlv_record_scheme_from_type [ tlv_stream_name ] [ tlv_record_type ] = [ tuple ( row ) ]
self . in_tlv_stream_get_record_type_from_name [ tlv_stream_name ] [ tlv_record_name ] = tlv_record_type
self . in_tlv_stream_get_record_name_from_type [ tlv_stream_name ] [ tlv_record_type ] = tlv_record_name
if max ( self . in_tlv_stream_get_tlv_record_scheme_from_type [ tlv_stream_name ] . keys ( ) ) > tlv_record_type :
raise Exception ( f " tlv record types must be listed in monotonically increasing order for stream. "
f " stream= { tlv_stream_name } " )
elif row [ 0 ] == " tlvdata " :
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
assert tlv_stream_name == row [ 1 ]
assert tlv_record_name == row [ 2 ]
self . in_tlv_stream_get_tlv_record_scheme_from_type [ tlv_stream_name ] [ tlv_record_type ] . append ( tuple ( row ) )
else :
pass # TODO
def write_tlv_stream ( self , * , fd : io . BytesIO , tlv_stream_name : str , * * kwargs ) - > None :
scheme_map = self . in_tlv_stream_get_tlv_record_scheme_from_type [ tlv_stream_name ]
for tlv_record_type , scheme in scheme_map . items ( ) : # note: tlv_record_type is monotonically increasing
tlv_record_name = self . in_tlv_stream_get_record_name_from_type [ tlv_stream_name ] [ tlv_record_type ]
if tlv_record_name not in kwargs :
continue
with io . BytesIO ( ) as tlv_record_fd :
for row in scheme :
if row [ 0 ] == " tlvtype " :
pass
elif row [ 0 ] == " tlvdata " :
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
assert tlv_stream_name == row [ 1 ]
assert tlv_record_name == row [ 2 ]
field_name = row [ 3 ]
field_type = row [ 4 ]
field_count_str = row [ 5 ]
field_count = _resolve_field_count ( field_count_str , vars_dict = kwargs [ tlv_record_name ] )
field_value = kwargs [ tlv_record_name ] [ field_name ]
_write_field ( fd = tlv_record_fd ,
field_type = field_type ,
count = field_count ,
value = field_value )
else :
pass # TODO
_write_tlv_record ( fd = fd , tlv_type = tlv_record_type , tlv_val = tlv_record_fd . getvalue ( ) )
def read_tlv_stream ( self , * , fd : io . BytesIO , tlv_stream_name : str ) - > Dict [ str , Dict [ str , Any ] ] :
parsed = { } # type: Dict[str, Dict[str, Any]]
scheme_map = self . in_tlv_stream_get_tlv_record_scheme_from_type [ tlv_stream_name ]
last_seen_tlv_record_type = - 1 # type: int
while _num_remaining_bytes_to_read ( fd ) > 0 :
tlv_record_type , tlv_record_val = _read_tlv_record ( fd = fd )
if not ( tlv_record_type > last_seen_tlv_record_type ) :
raise MalformedMsg ( " TLV records must be monotonically increasing by type " )
last_seen_tlv_record_type = tlv_record_type
try :
scheme = scheme_map [ tlv_record_type ]
except KeyError :
if tlv_record_type % 2 == 0 :
# unknown "even" type: hard fail
raise UnknownMandatoryTLVRecordType ( f " { tlv_stream_name } / { tlv_record_type } " ) from None
else :
# unknown "odd" type: skip it
continue
tlv_record_name = self . in_tlv_stream_get_record_name_from_type [ tlv_stream_name ] [ tlv_record_type ]
parsed [ tlv_record_name ] = { }
with io . BytesIO ( tlv_record_val ) as tlv_record_fd :
for row in scheme :
#print(f"row: {row!r}")
if row [ 0 ] == " tlvtype " :
pass
elif row [ 0 ] == " tlvdata " :
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
assert tlv_stream_name == row [ 1 ]
assert tlv_record_name == row [ 2 ]
field_name = row [ 3 ]
field_type = row [ 4 ]
field_count_str = row [ 5 ]
field_count = _resolve_field_count ( field_count_str , vars_dict = parsed [ tlv_record_name ] )
#print(f">> count={field_count}. parsed={parsed}")
parsed [ tlv_record_name ] [ field_name ] = _read_field ( fd = tlv_record_fd ,
field_type = field_type ,
count = field_count )
else :
pass # TODO
if _num_remaining_bytes_to_read ( tlv_record_fd ) > 0 :
raise MalformedMsg ( f " TLV record ( { tlv_stream_name } / { tlv_record_name } ) has extra trailing garbage " )
return parsed
def encode_msg ( self , msg_type : str , * * kwargs ) - > bytes :
"""
Encode kwargs into a Lightning message ( bytes )
@ -147,20 +368,12 @@ class LNSerializer:
if row [ 0 ] == " msgtype " :
pass
elif row [ 0 ] == " msgdata " :
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
field_name = row [ 2 ]
field_type = row [ 3 ]
field_count_str = row [ 4 ]
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
if field_count_str == " " :
field_count = 1
else :
try :
field_count = int ( field_count_str )
except ValueError :
field_count = kwargs [ field_count_str ]
if isinstance ( field_count , ( bytes , bytearray ) ) :
field_count = int . from_bytes ( field_count , byteorder = " big " )
assert isinstance ( field_count , int )
field_count = _resolve_field_count ( field_count_str , vars_dict = kwargs )
try :
field_value = kwargs [ field_name ]
except KeyError :
@ -205,14 +418,7 @@ class LNSerializer:
field_name = row [ 2 ]
field_type = row [ 3 ]
field_count_str = row [ 4 ]
if field_count_str == " " :
field_count = 1
else :
try :
field_count = int ( field_count_str )
except ValueError :
field_count = parsed [ field_count_str ]
assert isinstance ( field_count , int )
field_count = _resolve_field_count ( field_count_str , vars_dict = parsed )
#print(f">> count={field_count}. parsed={parsed}")
try :
parsed [ field_name ] = _read_field ( fd = fd ,