Browse Source

pyln: add pyln.proto.message.

This supports infrasructure for creating messages.  In particular, it
can be fed CSV from the spec's `tools/extract-formats.py` and then convert
them all to and from strings and binary formats.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
Changelog-Added: pyln: new module pyln.proto.message
nifty/pset-pre
Rusty Russell 5 years ago
committed by Christian Decker
parent
commit
eb73a0dd8f
  1. 10
      contrib/pyln-proto/pyln/proto/message/__init__.py
  2. 187
      contrib/pyln-proto/pyln/proto/message/array_types.py
  3. 231
      contrib/pyln-proto/pyln/proto/message/fundamental_types.py
  4. 611
      contrib/pyln-proto/pyln/proto/message/message.py
  5. 2
      contrib/pyln-proto/setup.py
  6. 119
      contrib/pyln-proto/tests/test_array_types.py
  7. 74
      contrib/pyln-proto/tests/test_fundamental_types.py
  8. 169
      contrib/pyln-proto/tests/test_message.py

10
contrib/pyln-proto/pyln/proto/message/__init__.py

@ -0,0 +1,10 @@
from .message import MessageNamespace, MessageType, Message, SubtypeType
__version__ = '0.0.1'
__all__ = [
"MessageNamespace",
"MessageType",
"Message",
"SubtypeType",
]

187
contrib/pyln-proto/pyln/proto/message/array_types.py

@ -0,0 +1,187 @@
#! /usr/bin/python3
from .fundamental_types import FieldType, IntegerType, split_field
class ArrayType(FieldType):
"""Abstract class for the different kinds of arrays: these are not in
the namespace, but generated when a message says it wants an array of
some type.
"""
def __init__(self, outer, name, elemtype):
super().__init__("{}.{}".format(outer.name, name))
self.elemtype = elemtype
def val_from_str(self, s):
# Simple arrays of bytes don't need commas
if self.elemtype.name == 'byte':
a, b = split_field(s)
return [b for b in bytes.fromhex(a)], b
if not s.startswith('['):
raise ValueError("array of {} must be wrapped in '[]': bad {}"
.format(self.elemtype.name, s))
s = s[1:]
ret = []
while not s.startswith(']'):
val, s = self.elemtype.val_from_str(s)
ret.append(val)
if s[0] == ',':
s = s[1:]
return ret, s[1:]
def val_to_str(self, v, otherfields):
if self.elemtype.name == 'byte':
return bytes(v).hex()
s = ''
sep = ''
for i in v:
s += sep + self.elemtype.val_to_str(i, otherfields)
sep = ','
return '[' + s + ']'
def val_to_bin(self, v, otherfields):
b = bytes()
for i in v:
b += self.elemtype.val_to_bin(i, otherfields)
return b
def arr_from_bin(self, bytestream, otherfields, arraysize):
"""arraysize None means take rest of bytestream exactly"""
totsize = 0
vals = []
i = 0
while True:
if arraysize is None and totsize == len(bytestream):
return vals, totsize
elif i == arraysize:
return vals, totsize
val, size = self.elemtype.val_from_bin(bytestream[totsize:],
otherfields)
totsize += size
i += 1
vals.append(val)
class SizedArrayType(ArrayType):
"""A fixed-size array"""
def __init__(self, outer, name, elemtype, arraysize):
super().__init__(outer, name, elemtype)
self.arraysize = arraysize
def val_to_str(self, v, otherfields):
if len(v) != self.arraysize:
raise ValueError("Length of {} != {}", v, self.arraysize)
return super().val_to_str(v, otherfields)
def val_from_str(self, s):
a, b = super().val_from_str(s)
if len(a) != self.arraysize:
raise ValueError("Length of {} != {}", s, self.arraysize)
return a, b
def val_to_bin(self, v, otherfields):
if len(v) != self.arraysize:
raise ValueError("Length of {} != {}", v, self.arraysize)
return super().val_to_bin(v, otherfields)
def val_from_bin(self, bytestream, otherfields):
return super().arr_from_bin(bytestream, otherfields, self.arraysize)
class EllipsisArrayType(ArrayType):
"""This is used for ... fields at the end of a tlv: the array ends
when the tlv ends"""
def __init__(self, tlv, name, elemtype):
super().__init__(tlv, name, elemtype)
def val_from_bin(self, bytestream, otherfields):
"""Takes rest of bytestream"""
return super().arr_from_bin(bytestream, otherfields, None)
def only_at_tlv_end(self):
"""These only make sense at the end of a TLV"""
return True
class LengthFieldType(FieldType):
"""Special type to indicate this serves as a length field for others"""
def __init__(self, inttype):
if type(inttype) is not IntegerType:
raise ValueError("{} cannot be a length; not an integer!"
.format(self.name))
super().__init__(inttype.name)
self.underlying_type = inttype
# You can be length for more than one field!
self.len_for = []
def is_optional(self):
"""This field value is always implies, never specified directly"""
return True
def add_length_for(self, field):
assert isinstance(field.fieldtype, DynamicArrayType)
self.len_for.append(field)
def calc_value(self, otherfields):
"""Calculate length value from field(s) themselves"""
if self.len_fields_bad('', otherfields):
raise ValueError("Lengths of fields {} not equal!"
.format(self.len_for))
return len(otherfields[self.len_for[0].name])
def _maybe_calc_value(self, fieldname, otherfields):
# Perhaps we're just demarshalling from binary now, so we actually
# stored it. Remove, and we'll calc from now on.
if fieldname in otherfields:
v = otherfields[fieldname]
del otherfields[fieldname]
return v
return self.calc_value(otherfields)
def val_to_bin(self, _, otherfields):
return self.underlying_type.val_to_bin(self.calc_value(otherfields),
otherfields)
def val_to_str(self, _, otherfields):
return self.underlying_type.val_to_str(self.calc_value(otherfields),
otherfields)
def name_and_val(self, name, v):
"""We don't print out length fields when printing out messages:
they're implied by the length of other fields"""
return ''
def val_from_bin(self, bytestream, otherfields):
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
return self.underlying_type.val_from_bin(bytestream, otherfields)
def val_from_str(self, s):
raise ValueError('{} is implied, cannot be specified'.format(self))
def len_fields_bad(self, fieldname, otherfields):
"""fieldname is the name to return if this length is bad"""
mylen = None
for lens in self.len_for:
if mylen is not None:
if mylen != len(otherfields[lens.name]):
return [fieldname]
# Field might be missing!
if lens.name in otherfields:
mylen = len(otherfields[lens.name])
return []
class DynamicArrayType(ArrayType):
"""This is used for arrays where another field controls the size"""
def __init__(self, outer, name, elemtype, lenfield):
super().__init__(outer, name, elemtype)
assert type(lenfield.fieldtype) is LengthFieldType
self.lenfield = lenfield
def val_from_bin(self, bytestream, otherfields):
return super().arr_from_bin(bytestream, otherfields,
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))

231
contrib/pyln-proto/pyln/proto/message/fundamental_types.py

@ -0,0 +1,231 @@
#! /usr/bin/python3
import struct
def split_field(s):
"""Helper to split string into first part and remainder"""
def len_without(s, delim):
pos = s.find(delim)
if pos == -1:
return len(s)
return pos
firstlen = min([len_without(s, d) for d in (',', '}', ']')])
return s[:firstlen], s[firstlen:]
class FieldType(object):
"""A (abstract) class representing the underlying type of a field.
These are further specialized.
"""
def __init__(self, name):
self.name = name
def only_at_tlv_end(self):
"""Some types only make sense inside a tlv, at the end"""
return False
def name_and_val(self, name, v):
"""This is overridden by LengthFieldType to return nothing"""
return " {}={}".format(name, self.val_to_str(v, None))
def is_optional(self):
"""Overridden for tlv fields and optional fields"""
return False
def len_fields_bad(self, fieldname, fieldvals):
"""Overridden by length fields for arrays"""
return []
def __str__(self):
return self.name
def __repr__(self):
return self.name
class IntegerType(FieldType):
def __init__(self, name, bytelen, structfmt):
super().__init__(name)
self.bytelen = bytelen
self.structfmt = structfmt
def val_to_str(self, v, otherfields):
return "{}".format(int(v))
def val_from_str(self, s):
a, b = split_field(s)
return int(a), b
def val_to_bin(self, v, otherfields):
return struct.pack(self.structfmt, v)
def val_from_bin(self, bytestream, otherfields):
"Returns value, bytesused"
if self.bytelen > len(bytestream):
raise ValueError('{}: not enough remaining to read'.format(self))
return struct.unpack_from(self.structfmt,
bytestream)[0], self.bytelen
class ShortChannelIDType(IntegerType):
"""short_channel_id has a special string representation, but is
basically a u64.
"""
def __init__(self, name):
super().__init__(name, 8, '>Q')
def val_to_str(self, v, otherfields):
# See BOLT #7: ## Definition of `short_channel_id`
return "{}x{}x{}".format(v >> 40, (v >> 16) & 0xFFFFFF, v & 0xFFFF)
def val_from_str(self, s):
a, b = split_field(s)
parts = a.split('x')
if len(parts) != 3:
raise ValueError("short_channel_id should be NxNxN")
return ((int(parts[0]) << 40)
| (int(parts[1]) << 16)
| (int(parts[2]))), b
class TruncatedIntType(FieldType):
"""Truncated integer types"""
def __init__(self, name, maxbytes):
super().__init__(name)
self.maxbytes = maxbytes
def val_to_str(self, v, otherfields):
return "{}".format(int(v))
def only_at_tlv_end(self):
"""These only make sense at the end of a TLV"""
return True
def val_from_str(self, s):
a, b = split_field(s)
if int(a) >= (1 << (self.maxbytes * 8)):
raise ValueError('{} exceeds maximum {} capacity'
.format(a, self.name))
return int(a), b
def val_to_bin(self, v, otherfields):
binval = struct.pack('>Q', v)
while len(binval) != 0 and binval[0] == 0:
binval = binval[1:]
if len(binval) > self.maxbytes:
raise ValueError('{} exceeds maximum {} capacity'
.format(v, self.name))
return binval
def val_from_bin(self, bytestream, otherfields):
"Returns value, bytesused"
binval = bytes()
while len(binval) < len(bytestream):
if len(binval) == 0 and bytestream[len(binval)] == 0:
raise ValueError('{} encoding is not minimal: {}'
.format(self.name, bytestream))
binval += bytes([bytestream[len(binval)]])
if len(binval) > self.maxbytes:
raise ValueError('{} is too long for {}'.format(binval, self.name))
# Pad with zeroes and convert as u64
return (struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0],
len(binval))
class FundamentalHexType(FieldType):
"""The remaining fundamental types are simply represented as hex strings"""
def __init__(self, name, bytelen):
super().__init__(name)
self.bytelen = bytelen
def val_to_str(self, v, otherfields):
if len(bytes(v)) != self.bytelen:
raise ValueError("Length of {} != {}", v, self.bytelen)
return v.hex()
def val_from_str(self, s):
a, b = split_field(s)
ret = bytes.fromhex(a)
if len(ret) != self.bytelen:
raise ValueError("Length of {} != {}", a, self.bytelen)
return ret, b
def val_to_bin(self, v, otherfields):
if len(bytes(v)) != self.bytelen:
raise ValueError("Length of {} != {}", v, self.bytelen)
return bytes(v)
def val_from_bin(self, bytestream, otherfields):
"Returns value, size from bytestream"
if self.bytelen > len(bytestream):
raise ValueError('{}: not enough remaining'.format(self))
return bytestream[:self.bytelen], self.bytelen
class BigSizeType(FieldType):
"""BigSize type, mainly used to encode TLV headers"""
def __init__(self, name):
super().__init__(name)
def val_from_str(self, s):
a, b = split_field(s)
return int(a), b
# For the convenience of TLV header parsing
@staticmethod
def to_bin(v):
if v < 253:
return bytes([v])
elif v < 2**16:
return bytes([253]) + struct.pack('>H', v)
elif v < 2**32:
return bytes([254]) + struct.pack('>I', v)
else:
return bytes([255]) + struct.pack('>Q', v)
@staticmethod
def from_bin(bytestream):
"Returns value, bytesused"
if bytestream[0] < 253:
return int(bytestream[0]), 1
elif bytestream[0] == 253:
return struct.unpack_from('>H', bytestream[1:])[0], 3
elif bytestream[0] == 254:
return struct.unpack_from('>I', bytestream[1:])[0], 5
else:
return struct.unpack_from('>Q', bytestream[1:])[0], 9
def val_to_str(self, v, otherfields):
return "{}".format(int(v))
def val_to_bin(self, v, otherfields):
return self.to_bin(v)
def val_from_bin(self, bytestream, otherfields):
return self.from_bin(bytestream)
def fundamental_types():
# From 01-messaging.md#fundamental-types:
return [IntegerType('byte', 1, 'B'),
IntegerType('u16', 2, '>H'),
IntegerType('u32', 4, '>I'),
IntegerType('u64', 8, '>Q'),
TruncatedIntType('tu16', 2),
TruncatedIntType('tu32', 4),
TruncatedIntType('tu64', 8),
FundamentalHexType('chain_hash', 32),
FundamentalHexType('channel_id', 32),
FundamentalHexType('sha256', 32),
FundamentalHexType('point', 33),
ShortChannelIDType('short_channel_id'),
FundamentalHexType('signature', 64),
BigSizeType('bigsize'),
# FIXME: See https://github.com/lightningnetwork/lightning-rfc/pull/778
BigSizeType('varint'),
]

611
contrib/pyln-proto/pyln/proto/message/message.py

@ -0,0 +1,611 @@
#! /usr/bin/python3
import struct
from .fundamental_types import fundamental_types, BigSizeType, split_field
from .array_types import (
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
)
class MessageNamespace(object):
"""A class which contains all FieldTypes and Messages in a particular
domain, such as within a given BOLT"""
def __init__(self, csv_lines=[]):
self.subtypes = {}
self.tlvtypes = {}
self.messagetypes = {}
# For convenience, basic types go in every namespace
for t in fundamental_types():
self.add_subtype(t)
self.load_csv(csv_lines)
def add_subtype(self, t):
prev = self.get_type(t.name)
if prev:
return ValueError('Already have {}'.format(prev))
self.subtypes[t.name] = t
def add_tlvtype(self, t):
prev = self.get_type(t.name)
if prev:
return ValueError('Already have {}'.format(prev))
self.tlvtypes[t.name] = t
def add_messagetype(self, m):
if self.get_msgtype(m.name):
return ValueError('{}: message already exists'.format(m.name))
if self.get_msgtype_by_number(m.number):
return ValueError('{}: message {} already number {}'.format(
m.name, self.get_msg_by_number(m.number), m.number))
self.messagetypes[m.name] = m
def get_msgtype(self, name):
if name in self.messagetypes:
return self.messagetypes[name]
return None
def get_msgtype_by_number(self, num):
for m in self.messagetypes.values():
if m.number == num:
return m
return None
def get_subtype(self, name):
if name in self.subtypes:
return self.subtypes[name]
return None
def get_tlvtype(self, name):
if name in self.tlvtypes:
return self.tlvtypes[name]
return None
def get_type(self, name):
t = self.get_subtype(name)
if not t:
t = self.get_tlvtype(name)
return t
def get_tlv_by_number(self, num):
for t in self.tlvtypes:
if t.number == num:
return t
return None
def load_csv(self, lines):
"""Load a series of comma-separate-value lines into the namespace"""
vals = {'msgtype': [],
'msgdata': [],
'tlvtype': [],
'tlvdata': [],
'subtype': [],
'subtypedata': []}
for l in lines:
parts = l.split(',')
if parts[0] not in vals:
raise ValueError("Unknown type {} in {}".format(parts[0], l))
vals[parts[0]].append(parts[1:])
# Types can refer to other types, so add data last.
for parts in vals['msgtype']:
self.add_messagetype(MessageType.type_from_csv(parts))
for parts in vals['subtype']:
self.add_subtype(SubtypeType.type_from_csv(parts))
for parts in vals['tlvtype']:
TlvStreamType.type_from_csv(self, parts)
for parts in vals['msgdata']:
MessageType.field_from_csv(self, parts)
for parts in vals['subtypedata']:
SubtypeType.field_from_csv(self, parts)
for parts in vals['tlvdata']:
TlvStreamType.field_from_csv(self, parts)
class MessageTypeField(object):
"""A field within a particular message type or subtype"""
def __init__(self, ownername, name, fieldtype):
self.full_name = "{}.{}".format(ownername, name)
self.name = name
self.fieldtype = fieldtype
def missing_fields(self, fields):
"""Return this field if it's not in fields"""
if self.name not in fields and not self.fieldtype.is_optional():
return [self]
return []
def len_fields_bad(self, fieldname, otherfields):
return self.fieldtype.len_fields_bad(fieldname, otherfields)
def __str__(self):
return self.full_name
def __repr__(self):
"""Yuck, but this is what format() uses for lists"""
return self.full_name
class SubtypeType(object):
"""This defines a 'subtype' in BOLT-speak. It consists of fields of
other types. Since 'msgtype' and 'tlvtype' are almost identical, they
inherit from this too.
"""
def __init__(self, name):
self.name = name
self.fields = []
def find_field(self, fieldname):
for f in self.fields:
if f.name == fieldname:
return f
return None
def add_field(self, field):
if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field)
def __str__(self):
return "subtype-{}".format(self.name)
def len_fields_bad(self, fieldname, otherfields):
bad_fields = []
for f in self.fields:
bad_fields += f.len_fields_bad('{}.{}'.format(fieldname, f.name),
otherfields)
return bad_fields
@staticmethod
def type_from_csv(parts):
"""e.g subtype,channel_update_timestamps"""
if len(parts) != 1:
raise ValueError("subtype expected 2 CSV parts, not {}"
.format(parts))
return SubtypeType(parts[0])
def _field_from_csv(self, namespace, parts, ellipsisok=False):
"""Takes msgdata/subtypedata after first two fields
e.g. [...]timestamp_node_id_1,u32,
"""
basetype = namespace.get_type(parts[1])
if not basetype:
raise ValueError('Unknown type {}'.format(parts[1]))
# Fixed number, or another field.
if parts[2] != '':
lenfield = self.find_field(parts[2])
if lenfield is not None:
# If we didn't know that field was a length, we do now!
if type(lenfield.fieldtype) is not LengthFieldType:
lenfield.fieldtype = LengthFieldType(lenfield.fieldtype)
field = MessageTypeField(self.name, parts[0],
DynamicArrayType(self,
parts[0],
basetype,
lenfield))
lenfield.fieldtype.add_length_for(field)
elif ellipsisok and parts[2] == '...':
field = MessageTypeField(self.name, parts[0],
EllipsisArrayType(self,
parts[0], basetype))
else:
field = MessageTypeField(self.name, parts[0],
SizedArrayType(self,
parts[0], basetype,
int(parts[2])))
else:
field = MessageTypeField(self.name, parts[0], basetype)
return field
def val_from_str(self, s):
if not s.startswith('{'):
raise ValueError("subtype {} must be wrapped in '{{}}': bad {}"
.format(self, s))
s = s[1:]
ret = {}
# FIXME: perhaps allow unlabelled fields to imply assign fields in order?
while not s.startswith('}'):
fieldname, s = s.split('=', 1)
f = self.find_field(fieldname)
if f is None:
raise ValueError("Unknown field name {}".format(fieldname))
ret[fieldname], s = f.fieldtype.val_from_str(s)
if s[0] == ',':
s = s[1:]
# All non-optional fields must be specified.
for f in self.fields:
if not f.fieldtype.is_optional() and f.name not in ret:
raise ValueError("{} missing field {}".format(self, f))
return ret, s[1:]
def _raise_if_badvals(self, v):
# Every non-optional value must be specified, and no others.
defined = set([f.name for f in self.fields])
have = set(v)
unknown = have.difference(defined)
if unknown:
raise ValueError("Unknown fields specified: {}".format(unknown))
for f in defined.difference(have):
if not f.fieldtype.is_optional():
raise ValueError("Missing value for {}".format(f))
def val_to_str(self, v, otherfields):
self._raise_if_badvals(v)
s = ''
sep = ''
for fname, val in v.items():
field = self.find_field(fname)
s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields)
sep = ','
return '{' + s + '}'
def val_to_bin(self, v, otherfields):
self._raise_if_badvals(v)
b = bytes()
for fname, val in v.items():
field = self.find_field(fname)
b += field.fieldtype.val_to_bin(val, otherfields)
return b
def val_from_bin(self, bytestream, otherfields):
totsize = 0
vals = {}
for field in self.fields:
val, size = field.fieldtype.val_from_bin(bytestream[totsize:],
otherfields)
totsize += size
vals[field.name] = val
return vals, totsize
@staticmethod
def field_from_csv(namespace, parts):
"""e.g
subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,"""
if len(parts) != 4:
raise ValueError("subtypedata expected 4 CSV parts, not {}"
.format(parts))
subtype = namespace.get_subtype(parts[0])
if subtype is None:
raise ValueError("unknown subtype {}".format(parts[0]))
field = subtype._field_from_csv(namespace, parts[1:])
if field.fieldtype.only_at_tlv_end():
raise ValueError("{}: cannot have TLV field {}"
.format(subtype, field))
subtype.add_field(field)
class MessageType(SubtypeType):
"""Each MessageType has a specific value, eg 17 is error"""
# * 0x8000 (BADONION): unparsable onion encrypted by sending peer
# * 0x4000 (PERM): permanent failure (otherwise transient)
# * 0x2000 (NODE): node failure (otherwise channel)
# * 0x1000 (UPDATE): new channel update enclosed
onion_types = {'BADONION': 0x8000,
'PERM': 0x4000,
'NODE': 0x2000,
'UPDATE': 0x1000}
def __init__(self, name, value):
super().__init__(name)
self.number = self.parse_value(value)
def parse_value(self, value):
result = 0
for token in value.split('|'):
if token in self.onion_types.keys():
result |= self.onion_types[token]
else:
result |= int(token)
return result
def __str__(self):
return "msgtype-{}".format(self.name)
@staticmethod
def type_from_csv(parts):
"""e.g msgtype,open_channel,32"""
if len(parts) != 2:
raise ValueError("msgtype expected 3 CSV parts, not {}"
.format(parts))
return MessageType(parts[0], parts[1])
@staticmethod
def field_from_csv(namespace, parts):
"""e.g msgdata,open_channel,temporary_channel_id,byte,32"""
if len(parts) != 4:
raise ValueError("msgdata expected 4 CSV parts, not {}"
.format(parts))
messagetype = namespace.get_msgtype(parts[0])
if not messagetype:
raise ValueError("unknown subtype {}".format(parts[0]))
field = messagetype._field_from_csv(namespace, parts[1:])
messagetype.add_field(field)
class TlvStreamType(SubtypeType):
"""A TlvStreamType is just a Subtype, but its fields are
TlvMessageTypes. In the CSV format these are created implicitly, when
a tlvtype line (which defines a TlvMessageType within the TlvType,
confusingly) refers to them.
"""
def __init__(self, name):
super().__init__(name)
def __str__(self):
return "tlvstreamtype-{}".format(self.name)
def find_field_by_number(self, num):
for f in self.fields:
if f.number == num:
return f
return None
def is_optional(self):
"""You can omit a tlvstream= altogether"""
return True
@staticmethod
def type_from_csv(namespace, parts):
"""e.g tlvtype,reply_channel_range_tlvs,timestamps_tlv,1"""
if len(parts) != 3:
raise ValueError("tlvtype expected 4 CSV parts, not {}"
.format(parts))
tlvstream = namespace.get_tlvtype(parts[0])
if not tlvstream:
tlvstream = TlvStreamType(parts[0])
namespace.add_tlvtype(tlvstream)
tlvstream.add_field(TlvMessageType(parts[1], parts[2]))
@staticmethod
def field_from_csv(namespace, parts):
"""e.g
tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
"""
if len(parts) != 5:
raise ValueError("tlvdata expected 6 CSV parts, not {}"
.format(parts))
tlvstream = namespace.get_tlvtype(parts[0])
if not tlvstream:
raise ValueError("unknown tlvtype {}".format(parts[0]))
field = tlvstream.find_field(parts[1])
if field is None:
raise ValueError("Unknown tlv field {}.{}"
.format(tlvstream, parts[1]))
subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True)
field.add_field(subfield)
def val_from_str(self, s):
"""{fieldname={...},...}. Returns dict of fieldname->val"""
if not s.startswith('{'):
raise ValueError("tlvtype {} must be wrapped in '{{}}': bad {}"
.format(self, s))
s = s[1:]
ret = {}
while not s.startswith('}'):
fieldname, s = s.split('=', 1)
f = self.find_field(fieldname)
if f is None:
# Unknown fields are number=hexstring
hexstring, s = split_field(s)
ret[int(fieldname)] = bytes.fromhex(hexstring)
else:
ret[fieldname], s = f.val_from_str(s)
if s[0] == ',':
s = s[1:]
return ret, s[1:]
def val_to_str(self, v, otherfields):
s = ''
sep = ''
for fieldname in v:
f = self.find_field(fieldname)
s += sep
if f is None:
s += str(int(fieldname)) + '=' + v[fieldname].hex()
else:
s += f.name + '=' + f.val_to_str(v[fieldname], otherfields)
sep = ','
return '{' + s + '}'
def val_to_bin(self, v, otherfields):
b = bytes()
# If they didn't specify this tlvstream, it's empty.
if v is None:
return b
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# ascending order as TLV spec requires.
def copy_val(val, otherfields):
return val
def get_value(tup):
"""Get value from num, fun, val tuple"""
return tup[0]
ordered = []
for fieldname in v:
f = self.find_field(fieldname)
if f is None:
# fieldname can be an integer for a raw field.
ordered.append((int(fieldname), copy_val, v[fieldname]))
else:
ordered.append((f.number, f.val_to_bin, v[fieldname]))
ordered.sort(key=get_value)
for tup in ordered:
value = tup[1](tup[2], otherfields)
b += (BigSizeType.to_bin(tup[0])
+ BigSizeType.to_bin(len(value))
+ value)
return b
def val_from_bin(self, bytestream, otherfields):
totsize = 0
vals = {}
while totsize < len(bytestream):
tlv_type, size = BigSizeType.from_bin(bytestream[totsize:])
totsize += size
tlv_len, size = BigSizeType.from_bin(bytestream[totsize:])
totsize += size
f = self.find_field_by_number(tlv_type)
if f is None:
vals[tlv_type] = bytestream[totsize:totsize + tlv_len]
size = len(vals[tlv_type])
else:
vals[f.name], size = f.val_from_bin(bytestream
[totsize:totsize
+ tlv_len],
otherfields)
if size != tlv_len:
raise ValueError("Truncated tlv field")
totsize += size
return vals, totsize
def name_and_val(self, name, v):
"""This is overridden by LengthFieldType to return nothing"""
return " {}={}".format(name, self.val_to_str(v, None))
class TlvMessageType(MessageType):
"""A 'tlvtype' in BOLT-speak"""
def __init__(self, name, value):
super().__init__(name, value)
def __str__(self):
return "tlvmsgtype-{}".format(self.name)
class Message(object):
"""A particular message instance"""
def __init__(self, messagetype, **kwargs):
"""MessageType is the type of this msg, with fields. Fields can either be valid values for the type, or if they are strings they are converted according to the field type"""
self.messagetype = messagetype
self.fields = {}
# Convert arguments from strings to values if necessary.
for field in kwargs:
f = self.messagetype.find_field(field)
if f is None:
raise ValueError("Unknown field {}".format(field))
v = kwargs[field]
if isinstance(v, str):
v, remainder = f.fieldtype.val_from_str(v)
if remainder != '':
raise ValueError('Unexpected {} at end of initializer for {}'.format(remainder, field))
self.fields[field] = v
bad_lens = self.messagetype.len_fields_bad(self.messagetype.name,
self.fields)
if bad_lens:
raise ValueError("Inconsistent length fields: {}".format(bad_lens))
def missing_fields(self):
"""Are any required fields missing?"""
missing = []
for ftype in self.messagetype.fields:
missing += ftype.missing_fields(self.fields)
return missing
@staticmethod
def from_bin(namespace, binmsg):
"""Decode a binary wire format to a Message within that namespace"""
typenum = struct.unpack_from(">H", binmsg)[0]
off = 2
mtype = namespace.get_msgtype_by_number(typenum)
if not mtype:
raise ValueError('Unknown message type number {}'.format(typenum))
fields = {}
for f in mtype.fields:
v, size = f.fieldtype.val_from_bin(binmsg[off:], fields)
off += size
fields[f.name] = v
return Message(mtype, **fields)
@staticmethod
def from_str(namespace, s, incomplete_ok=False):
"""Decode a string to a Message within that namespace, of format
msgname [ field=...]*."""
parts = s.split()
mtype = namespace.get_msgtype(parts[0])
if not mtype:
raise ValueError('Unknown message type name {}'.format(parts[0]))
args = {}
for p in parts[1:]:
assign = p.split('=', 1)
args[assign[0]] = assign[1]
m = Message(mtype, **args)
if not incomplete_ok:
missing = m.missing_fields()
if len(missing):
raise ValueError('Missing fields: {}'.format(missing))
return m
def to_bin(self):
"""Encode a Message into its wire format (must not have missing
fields)"""
if self.missing_fields():
raise ValueError('Missing fields: {}'
.format(self.missing_fields()))
ret = struct.pack(">H", self.messagetype.number)
for f in self.messagetype.fields:
# Optional fields get val == None. Usually this means they don't
# write anything, but length fields are an exception: they intuit
# their value from other fields.
if f.name in self.fields:
val = self.fields[f.name]
else:
val = None
ret += f.fieldtype.val_to_bin(val, self.fields)
return ret
def to_str(self):
"""Encode a Message into a string"""
ret = "{}".format(self.messagetype.name)
for f in self.messagetype.fields:
if f.name in self.fields:
ret += f.fieldtype.name_and_val(f.name, self.fields[f.name])
return ret

2
contrib/pyln-proto/setup.py

@ -17,7 +17,7 @@ setup(name='pyln-proto',
author='Christian Decker',
author_email='decker.christian@gmail.com',
license='MIT',
packages=['pyln.proto'],
packages=['pyln.proto', 'pyln.proto.message'],
scripts=[],
zip_safe=True,
install_requires=requirements)

119
contrib/pyln-proto/tests/test_array_types.py

@ -0,0 +1,119 @@
#! /usr/bin/python3
from pyln.proto.message.fundamental_types import fundamental_types
from pyln.proto.message.array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType, LengthFieldType
def test_sized_array():
# Steal two fundamental types for testing
for t in fundamental_types():
if t.name == 'byte':
byte = t
if t.name == 'u16':
u16 = t
if t.name == 'short_channel_id':
scid = t
# Simple class to make outer work.
class dummy:
def __init__(self, name):
self.name = name
for test in [[SizedArrayType(dummy("test1"), "test_arr", byte, 4),
"00010203",
bytes([0, 1, 2, 3])],
[SizedArrayType(dummy("test2"), "test_arr", u16, 4),
"[0,1,2,256]",
bytes([0, 0, 0, 1, 0, 2, 1, 0])],
[SizedArrayType(dummy("test3"), "test_arr", scid, 4),
"[1x2x3,4x5x6,7x8x9,10x11x12]",
bytes([0, 0, 1, 0, 0, 2, 0, 3]
+ [0, 0, 4, 0, 0, 5, 0, 6]
+ [0, 0, 7, 0, 0, 8, 0, 9]
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
v, _ = test[0].val_from_str(test[1])
assert test[0].val_to_str(v, None) == test[1]
v2, _ = test[0].val_from_bin(test[2], None)
assert v2 == v
assert test[0].val_to_bin(v, None) == test[2]
def test_ellipsis_array():
# Steal two fundamental types for testing
for t in fundamental_types():
if t.name == 'byte':
byte = t
if t.name == 'u16':
u16 = t
if t.name == 'short_channel_id':
scid = t
# Simple class to make outer work.
class dummy:
def __init__(self, name):
self.name = name
for test in [[EllipsisArrayType(dummy("test1"), "test_arr", byte),
"00010203",
bytes([0, 1, 2, 3])],
[EllipsisArrayType(dummy("test2"), "test_arr", u16),
"[0,1,2,256]",
bytes([0, 0, 0, 1, 0, 2, 1, 0])],
[EllipsisArrayType(dummy("test3"), "test_arr", scid),
"[1x2x3,4x5x6,7x8x9,10x11x12]",
bytes([0, 0, 1, 0, 0, 2, 0, 3]
+ [0, 0, 4, 0, 0, 5, 0, 6]
+ [0, 0, 7, 0, 0, 8, 0, 9]
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
v, _ = test[0].val_from_str(test[1])
assert test[0].val_to_str(v, None) == test[1]
v2, _ = test[0].val_from_bin(test[2], None)
assert v2 == v
assert test[0].val_to_bin(v, None) == test[2]
def test_dynamic_array():
# Steal two fundamental types for testing
for t in fundamental_types():
if t.name == 'byte':
byte = t
if t.name == 'u16':
u16 = t
if t.name == 'short_channel_id':
scid = t
# Simple class to make outer.
class dummy:
def __init__(self, name):
self.name = name
class field_dummy:
def __init__(self, name, ftype):
self.fieldtype = ftype
self.name = name
lenfield = field_dummy('lenfield', LengthFieldType(u16))
for test in [[DynamicArrayType(dummy("test1"), "test_arr", byte,
lenfield),
"00010203",
bytes([0, 1, 2, 3])],
[DynamicArrayType(dummy("test2"), "test_arr", u16,
lenfield),
"[0,1,2,256]",
bytes([0, 0, 0, 1, 0, 2, 1, 0])],
[DynamicArrayType(dummy("test3"), "test_arr", scid,
lenfield),
"[1x2x3,4x5x6,7x8x9,10x11x12]",
bytes([0, 0, 1, 0, 0, 2, 0, 3]
+ [0, 0, 4, 0, 0, 5, 0, 6]
+ [0, 0, 7, 0, 0, 8, 0, 9]
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
lenfield.fieldtype.add_length_for(field_dummy(test[1], test[0]))
v, _ = test[0].val_from_str(test[1])
otherfields = {test[1]: v}
assert test[0].val_to_str(v, otherfields) == test[1]
v2, _ = test[0].val_from_bin(test[2], otherfields)
assert v2 == v
assert test[0].val_to_bin(v, otherfields) == test[2]
lenfield.fieldtype.len_for = []

74
contrib/pyln-proto/tests/test_fundamental_types.py

@ -0,0 +1,74 @@
#! /usr/bin/python3
from pyln.proto.message.fundamental_types import fundamental_types
def test_fundamental_types():
expect = {'byte': [['255', b'\xff'],
['0', b'\x00']],
'u16': [['65535', b'\xff\xff'],
['0', b'\x00\x00']],
'u32': [['4294967295', b'\xff\xff\xff\xff'],
['0', b'\x00\x00\x00\x00']],
'u64': [['18446744073709551615',
b'\xff\xff\xff\xff\xff\xff\xff\xff'],
['0', b'\x00\x00\x00\x00\x00\x00\x00\x00']],
'tu16': [['65535', b'\xff\xff'],
['256', b'\x01\x00'],
['255', b'\xff'],
['0', b'']],
'tu32': [['4294967295', b'\xff\xff\xff\xff'],
['65536', b'\x01\x00\x00'],
['65535', b'\xff\xff'],
['256', b'\x01\x00'],
['255', b'\xff'],
['0', b'']],
'tu64': [['18446744073709551615',
b'\xff\xff\xff\xff\xff\xff\xff\xff'],
['4294967296', b'\x01\x00\x00\x00\x00'],
['4294967295', b'\xff\xff\xff\xff'],
['65536', b'\x01\x00\x00'],
['65535', b'\xff\xff'],
['256', b'\x01\x00'],
['255', b'\xff'],
['0', b'']],
'chain_hash': [['0102030405060708090a0b0c0d0e0f10'
'1112131415161718191a1b1c1d1e1f20',
bytes(range(1, 33))]],
'channel_id': [['0102030405060708090a0b0c0d0e0f10'
'1112131415161718191a1b1c1d1e1f20',
bytes(range(1, 33))]],
'sha256': [['0102030405060708090a0b0c0d0e0f10'
'1112131415161718191a1b1c1d1e1f20',
bytes(range(1, 33))]],
'signature': [['0102030405060708090a0b0c0d0e0f10'
'1112131415161718191a1b1c1d1e1f20'
'2122232425262728292a2b2c2d2e2f30'
'3132333435363738393a3b3c3d3e3f40',
bytes(range(1, 65))]],
'point': [['02030405060708090a0b0c0d0e0f10'
'1112131415161718191a1b1c1d1e1f20'
'2122',
bytes(range(2, 35))]],
'short_channel_id': [['1x2x3', bytes([0, 0, 1, 0, 0, 2, 0, 3])]],
'bigsize': [['0', bytes([0])],
['252', bytes([252])],
['253', bytes([253, 0, 253])],
['65535', bytes([253, 255, 255])],
['65536', bytes([254, 0, 1, 0, 0])],
['4294967295', bytes([254, 255, 255, 255, 255])],
['4294967296', bytes([255, 0, 0, 0, 1, 0, 0, 0, 0])]],
}
untested = set()
for t in fundamental_types():
if t.name not in expect:
untested.add(t.name)
continue
for test in expect[t.name]:
v, _ = t.val_from_str(test[0])
assert t.val_to_str(v, None) == test[0]
v2, _ = t.val_from_bin(test[1], None)
assert v2 == v
assert t.val_to_bin(v, None) == test[1]
assert untested == set(['varint'])

169
contrib/pyln-proto/tests/test_message.py

@ -0,0 +1,169 @@
#! /usr/bin/python3
from pyln.proto.message import MessageNamespace, Message
import pytest
def test_fundamental():
ns = MessageNamespace()
ns.load_csv(['msgtype,test,1',
'msgdata,test,test_byte,byte,',
'msgdata,test,test_u16,u16,',
'msgdata,test,test_u32,u32,',
'msgdata,test,test_u64,u64,',
'msgdata,test,test_chain_hash,chain_hash,',
'msgdata,test,test_channel_id,channel_id,',
'msgdata,test,test_sha256,sha256,',
'msgdata,test,test_signature,signature,',
'msgdata,test,test_point,point,',
'msgdata,test,test_short_channel_id,short_channel_id,',
])
mstr = """test
test_byte=255
test_u16=65535
test_u32=4294967295
test_u64=18446744073709551615
test_chain_hash=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
test_channel_id=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
test_sha256=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
test_signature=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f40
test_point=0201030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f2021
test_short_channel_id=1x2x3"""
m = Message.from_str(ns, mstr)
# Same (ignoring whitespace differences)
assert m.to_str().split() == mstr.split()
def test_static_array():
ns = MessageNamespace()
ns.load_csv(['msgtype,test1,1',
'msgdata,test1,test_arr,byte,4'])
ns.load_csv(['msgtype,test2,2',
'msgdata,test2,test_arr,short_channel_id,4'])
for test in [["test1 test_arr=00010203", bytes([0, 1] + [0, 1, 2, 3])],
["test2 test_arr=[0x1x2,4x5x6,7x8x9,10x11x12]",
bytes([0, 2]
+ [0, 0, 0, 0, 0, 1, 0, 2]
+ [0, 0, 4, 0, 0, 5, 0, 6]
+ [0, 0, 7, 0, 0, 8, 0, 9]
+ [0, 0, 10, 0, 0, 11, 0, 12])]]:
m = Message.from_str(ns, test[0])
assert m.to_str() == test[0]
v = m.to_bin()
assert v == test[1]
assert Message.from_bin(ns, test[1]).to_str() == test[0]
def test_subtype():
ns = MessageNamespace()
ns.load_csv(['msgtype,test1,1',
'msgdata,test1,test_sub,channel_update_timestamps,4',
'subtype,channel_update_timestamps',
'subtypedata,'
+ 'channel_update_timestamps,timestamp_node_id_1,u32,',
'subtypedata,'
+ 'channel_update_timestamps,timestamp_node_id_2,u32,'])
for test in [["test1 test_sub=["
"{timestamp_node_id_1=1,timestamp_node_id_2=2}"
",{timestamp_node_id_1=3,timestamp_node_id_2=4}"
",{timestamp_node_id_1=5,timestamp_node_id_2=6}"
",{timestamp_node_id_1=7,timestamp_node_id_2=8}]",
bytes([0, 1]
+ [0, 0, 0, 1, 0, 0, 0, 2]
+ [0, 0, 0, 3, 0, 0, 0, 4]
+ [0, 0, 0, 5, 0, 0, 0, 6]
+ [0, 0, 0, 7, 0, 0, 0, 8])]]:
m = Message.from_str(ns, test[0])
assert m.to_str() == test[0]
v = m.to_bin()
assert v == test[1]
assert Message.from_bin(ns, test[1]).to_str() == test[0]
# Test missing field logic.
m = Message.from_str(ns, "test1", incomplete_ok=True)
assert m.missing_fields()
def test_tlv():
ns = MessageNamespace()
ns.load_csv(['msgtype,test1,1',
'msgdata,test1,tlvs,test_tlvstream,',
'tlvtype,test_tlvstream,tlv1,1',
'tlvdata,test_tlvstream,tlv1,field1,byte,4',
'tlvdata,test_tlvstream,tlv1,field2,u32,',
'tlvtype,test_tlvstream,tlv2,255',
'tlvdata,test_tlvstream,tlv2,field3,byte,...'])
for test in [["test1 tlvs={tlv1={field1=01020304,field2=5}}",
bytes([0, 1]
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5])],
["test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304}}",
bytes([0, 1]
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
+ [253, 0, 255, 4, 1, 2, 3, 4])],
["test1 tlvs={tlv1={field1=01020304,field2=5},4=010203,tlv2={field3=01020304}}",
bytes([0, 1]
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
+ [4, 3, 1, 2, 3]
+ [253, 0, 255, 4, 1, 2, 3, 4])]]:
m = Message.from_str(ns, test[0])
assert m.to_str() == test[0]
v = m.to_bin()
assert v == test[1]
assert Message.from_bin(ns, test[1]).to_str() == test[0]
# Ordering test (turns into canonical ordering)
m = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')
assert m.to_bin() == bytes([0, 1]
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
+ [4, 3, 1, 2, 3]
+ [253, 0, 255, 4, 1, 2, 3, 4])
def test_message_constructor():
ns = MessageNamespace(['msgtype,test1,1',
'msgdata,test1,tlvs,test_tlvstream,',
'tlvtype,test_tlvstream,tlv1,1',
'tlvdata,test_tlvstream,tlv1,field1,byte,4',
'tlvdata,test_tlvstream,tlv1,field2,u32,',
'tlvtype,test_tlvstream,tlv2,255',
'tlvdata,test_tlvstream,tlv2,field3,byte,...'])
m = Message(ns.get_msgtype('test1'),
tlvs='{tlv1={field1=01020304,field2=5}'
',tlv2={field3=01020304},4=010203}')
assert m.to_bin() == bytes([0, 1]
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
+ [4, 3, 1, 2, 3]
+ [253, 0, 255, 4, 1, 2, 3, 4])
def test_dynamic_array():
"""Test that dynamic array types enforce matching lengths"""
ns = MessageNamespace(['msgtype,test1,1',
'msgdata,test1,count,u16,',
'msgdata,test1,arr1,byte,count',
'msgdata,test1,arr2,u32,count'])
# This one is fine.
m = Message(ns.get_msgtype('test1'),
arr1='01020304', arr2='[1,2,3,4]')
assert m.to_bin() == bytes([0, 1]
+ [0, 4]
+ [1, 2, 3, 4]
+ [0, 0, 0, 1,
0, 0, 0, 2,
0, 0, 0, 3,
0, 0, 0, 4])
# These ones are not
with pytest.raises(ValueError, match='Inconsistent length.*count'):
m = Message(ns.get_msgtype('test1'),
arr1='01020304', arr2='[1,2,3]')
with pytest.raises(ValueError, match='Inconsistent length.*count'):
m = Message(ns.get_msgtype('test1'),
arr1='01020304', arr2='[1,2,3,4,5]')
Loading…
Cancel
Save