Browse Source

pyln.proto.message: use BufferedIOBase instead of bytes for binary ops.

Instead of val_to_bin/val_from_bin which deal with bytes, we implement
read and write which use streams.  This simplifies the API. 

Suggested-by: Christian Decker
Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
nifty/pset-pre
Rusty Russell 5 years ago
committed by Christian Decker
parent
commit
47631cc23c
  1. 62
      contrib/pyln-proto/pyln/proto/message/array_types.py
  2. 103
      contrib/pyln-proto/pyln/proto/message/fundamental_types.py
  3. 123
      contrib/pyln-proto/pyln/proto/message/message.py
  4. 1
      contrib/pyln-proto/requirements.txt
  5. 85
      contrib/pyln-proto/tests/test_array_types.py
  6. 7
      contrib/pyln-proto/tests/test_fundamental_types.py
  7. 58
      contrib/pyln-proto/tests/test_message.py

62
contrib/pyln-proto/pyln/proto/message/array_types.py

@ -42,28 +42,26 @@ wants an array of some type.
return '[' + s + ']' return '[' + s + ']'
def val_to_bin(self, v, otherfields): def write(self, io_out, v, otherfields):
b = bytes()
for i in v: for i in v:
b += self.elemtype.val_to_bin(i, otherfields) self.elemtype.write(io_out, i, otherfields)
return b
def arr_from_bin(self, bytestream, otherfields, arraysize): def read_arr(self, io_in, otherfields, arraysize):
"""arraysize None means take rest of bytestream exactly""" """arraysize None means take rest of io entirely and exactly"""
totsize = 0
vals = [] vals = []
i = 0 while arraysize is None or len(vals) < arraysize:
while True: # Throws an exception on partial read, so None means completely empty.
if arraysize is None and totsize == len(bytestream): val = self.elemtype.read(io_in, otherfields)
return vals, totsize if val is None:
elif i == arraysize: if arraysize is not None:
return vals, totsize raise ValueError('{}: not enough remaining to read'
val, size = self.elemtype.val_from_bin(bytestream[totsize:], .format(self))
otherfields) break
totsize += size
i += 1
vals.append(val) vals.append(val)
return vals
class SizedArrayType(ArrayType): class SizedArrayType(ArrayType):
"""A fixed-size array""" """A fixed-size array"""
@ -82,13 +80,13 @@ class SizedArrayType(ArrayType):
raise ValueError("Length of {} != {}", s, self.arraysize) raise ValueError("Length of {} != {}", s, self.arraysize)
return a, b return a, b
def val_to_bin(self, v, otherfields): def write(self, io_out, v, otherfields):
if len(v) != self.arraysize: if len(v) != self.arraysize:
raise ValueError("Length of {} != {}", v, self.arraysize) raise ValueError("Length of {} != {}", v, self.arraysize)
return super().val_to_bin(v, otherfields) return super().write(io_out, v, otherfields)
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
return super().arr_from_bin(bytestream, otherfields, self.arraysize) return super().read_arr(io_in, otherfields, self.arraysize)
class EllipsisArrayType(ArrayType): class EllipsisArrayType(ArrayType):
@ -97,9 +95,9 @@ when the tlv ends"""
def __init__(self, tlv, name, elemtype): def __init__(self, tlv, name, elemtype):
super().__init__(tlv, name, elemtype) super().__init__(tlv, name, elemtype)
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
"""Takes rest of bytestream""" """Takes rest of bytestream"""
return super().arr_from_bin(bytestream, otherfields, None) return super().read_arr(io_in, otherfields, None)
def only_at_tlv_end(self): def only_at_tlv_end(self):
"""These only make sense at the end of a TLV""" """These only make sense at the end of a TLV"""
@ -142,10 +140,6 @@ class LengthFieldType(FieldType):
return v return v
return self.calc_value(otherfields) return self.calc_value(otherfields)
def val_to_bin(self, _, otherfields):
return self.underlying_type.val_to_bin(self.calc_value(otherfields),
otherfields)
def val_to_str(self, _, otherfields): def val_to_str(self, _, otherfields):
return self.underlying_type.val_to_str(self.calc_value(otherfields), return self.underlying_type.val_to_str(self.calc_value(otherfields),
otherfields) otherfields)
@ -155,9 +149,13 @@ class LengthFieldType(FieldType):
they're implied by the length of other fields""" they're implied by the length of other fields"""
return '' return ''
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)""" """We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
return self.underlying_type.val_from_bin(bytestream, otherfields) return self.underlying_type.read(io_in, otherfields)
def write(self, io_out, _, otherfields):
self.underlying_type.write(io_out, self.calc_value(otherfields),
otherfields)
def val_from_str(self, s): def val_from_str(self, s):
raise ValueError('{} is implied, cannot be specified'.format(self)) raise ValueError('{} is implied, cannot be specified'.format(self))
@ -182,6 +180,6 @@ class DynamicArrayType(ArrayType):
assert type(lenfield.fieldtype) is LengthFieldType assert type(lenfield.fieldtype) is LengthFieldType
self.lenfield = lenfield self.lenfield = lenfield
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
return super().arr_from_bin(bytestream, otherfields, return super().read_arr(io_in, otherfields,
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields)) self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))

103
contrib/pyln-proto/pyln/proto/message/fundamental_types.py

@ -1,4 +1,22 @@
import struct import struct
import io
from typing import Optional
def try_unpack(name: str,
io_out: io.BufferedIOBase,
structfmt: str,
empty_ok: bool) -> Optional[int]:
"""Unpack a single value using struct.unpack.
If need_all, never return None, otherwise returns None if EOF."""
b = io_out.read(struct.calcsize(structfmt))
if len(b) == 0 and empty_ok:
return None
elif len(b) < struct.calcsize(structfmt):
raise ValueError("{}: not enough bytes", name)
return struct.unpack(structfmt, b)[0]
def split_field(s): def split_field(s):
@ -57,15 +75,11 @@ class IntegerType(FieldType):
a, b = split_field(s) a, b = split_field(s)
return int(a), b return int(a), b
def val_to_bin(self, v, otherfields): def write(self, io_out, v, otherfields):
return struct.pack(self.structfmt, v) io_out.write(struct.pack(self.structfmt, v))
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
"Returns value, bytesused" return try_unpack(self.name, io_in, self.structfmt, empty_ok=True)
if self.bytelen > len(bytestream):
raise ValueError('{}: not enough remaining to read'.format(self))
return struct.unpack_from(self.structfmt,
bytestream)[0], self.bytelen
class ShortChannelIDType(IntegerType): class ShortChannelIDType(IntegerType):
@ -110,30 +124,24 @@ class TruncatedIntType(FieldType):
.format(a, self.name)) .format(a, self.name))
return int(a), b return int(a), b
def val_to_bin(self, v, otherfields): def write(self, io_out, v, otherfields):
binval = struct.pack('>Q', v) binval = struct.pack('>Q', v)
while len(binval) != 0 and binval[0] == 0: while len(binval) != 0 and binval[0] == 0:
binval = binval[1:] binval = binval[1:]
if len(binval) > self.maxbytes: if len(binval) > self.maxbytes:
raise ValueError('{} exceeds maximum {} capacity' raise ValueError('{} exceeds maximum {} capacity'
.format(v, self.name)) .format(v, self.name))
return binval io_out.write(binval)
def val_from_bin(self, bytestream, otherfields):
"Returns value, bytesused"
binval = bytes()
while len(binval) < len(bytestream):
if len(binval) == 0 and bytestream[len(binval)] == 0:
raise ValueError('{} encoding is not minimal: {}'
.format(self.name, bytestream))
binval += bytes([bytestream[len(binval)]])
def read(self, io_in, otherfields):
binval = io_in.read()
if len(binval) > self.maxbytes: if len(binval) > self.maxbytes:
raise ValueError('{} is too long for {}'.format(binval, self.name)) raise ValueError('{} is too long for {}'.format(binval, self.name))
if len(binval) > 0 and binval[0] == 0:
raise ValueError('{} encoding is not minimal: {}'
.format(self.name, binval))
# Pad with zeroes and convert as u64 # Pad with zeroes and convert as u64
return (struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0], return struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0]
len(binval))
class FundamentalHexType(FieldType): class FundamentalHexType(FieldType):
@ -154,16 +162,18 @@ class FundamentalHexType(FieldType):
raise ValueError("Length of {} != {}", a, self.bytelen) raise ValueError("Length of {} != {}", a, self.bytelen)
return ret, b return ret, b
def val_to_bin(self, v, otherfields): def write(self, io_out, v, otherfields):
if len(bytes(v)) != self.bytelen: if len(bytes(v)) != self.bytelen:
raise ValueError("Length of {} != {}", v, self.bytelen) raise ValueError("Length of {} != {}", v, self.bytelen)
return bytes(v) io_out.write(v)
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
"Returns value, size from bytestream" val = io_in.read(self.bytelen)
if self.bytelen > len(bytestream): if len(val) == 0:
return None
elif len(val) != self.bytelen:
raise ValueError('{}: not enough remaining'.format(self)) raise ValueError('{}: not enough remaining'.format(self))
return bytestream[:self.bytelen], self.bytelen return val
class BigSizeType(FieldType): class BigSizeType(FieldType):
@ -177,37 +187,34 @@ class BigSizeType(FieldType):
# For the convenience of TLV header parsing # For the convenience of TLV header parsing
@staticmethod @staticmethod
def to_bin(v): def write(io_out, v, otherfields=None):
if v < 253: if v < 253:
return bytes([v]) io_out.write(bytes([v]))
elif v < 2**16: elif v < 2**16:
return bytes([253]) + struct.pack('>H', v) io_out.write(bytes([253]) + struct.pack('>H', v))
elif v < 2**32: elif v < 2**32:
return bytes([254]) + struct.pack('>I', v) io_out.write(bytes([254]) + struct.pack('>I', v))
else: else:
return bytes([255]) + struct.pack('>Q', v) io_out.write(bytes([255]) + struct.pack('>Q', v))
@staticmethod @staticmethod
def from_bin(bytestream): def read(io_in, otherfields=None):
"Returns value, bytesused" "Returns value, or None on EOF"
if bytestream[0] < 253: b = io_in.read(1)
return int(bytestream[0]), 1 if len(b) == 0:
elif bytestream[0] == 253: return None
return struct.unpack_from('>H', bytestream[1:])[0], 3 if b[0] < 253:
elif bytestream[0] == 254: return int(b[0])
return struct.unpack_from('>I', bytestream[1:])[0], 5 elif b[0] == 253:
return try_unpack('BigSize', io_in, '>H', empty_ok=False)
elif b[0] == 254:
return try_unpack('BigSize', io_in, '>I', empty_ok=False)
else: else:
return struct.unpack_from('>Q', bytestream[1:])[0], 9 return try_unpack('BigSize', io_in, '>Q', empty_ok=False)
def val_to_str(self, v, otherfields): def val_to_str(self, v, otherfields):
return "{}".format(int(v)) return "{}".format(int(v))
def val_to_bin(self, v, otherfields):
return self.to_bin(v)
def val_from_bin(self, bytestream, otherfields):
return self.from_bin(bytestream)
def fundamental_types(): def fundamental_types():
# From 01-messaging.md#fundamental-types: # From 01-messaging.md#fundamental-types:

123
contrib/pyln-proto/pyln/proto/message/message.py

@ -1,5 +1,6 @@
import struct import struct
from .fundamental_types import fundamental_types, BigSizeType, split_field import io
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack
from .array_types import ( from .array_types import (
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
) )
@ -253,24 +254,21 @@ inherit from this too.
return '{' + s + '}' return '{' + s + '}'
def val_to_bin(self, v, otherfields): def write(self, io_out, v, otherfields):
self._raise_if_badvals(v) self._raise_if_badvals(v)
b = bytes()
for fname, val in v.items(): for fname, val in v.items():
field = self.find_field(fname) field = self.find_field(fname)
b += field.fieldtype.val_to_bin(val, otherfields) field.fieldtype.write(io_out, val, otherfields)
return b
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
totsize = 0
vals = {} vals = {}
for field in self.fields: for field in self.fields:
val, size = field.fieldtype.val_from_bin(bytestream[totsize:], val = field.fieldtype.read(io_in, otherfields)
otherfields) if val is None:
totsize += size raise ValueError("{}.{}: short read".format(self, field))
vals[field.name] = val vals[field.name] = val
return vals, totsize return vals
@staticmethod @staticmethod
def field_from_csv(namespace, parts): def field_from_csv(namespace, parts):
@ -433,17 +431,15 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
return '{' + s + '}' return '{' + s + '}'
def val_to_bin(self, v, otherfields): def write(self, iobuf, v, otherfields):
b = bytes()
# If they didn't specify this tlvstream, it's empty. # If they didn't specify this tlvstream, it's empty.
if v is None: if v is None:
return b return
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into # Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# ascending order as TLV spec requires. # ascending order as TLV spec requires.
def copy_val(val, otherfields): def write_raw_val(iobuf, val, otherfields):
return val iobuf.write(val)
def get_value(tup): def get_value(tup):
"""Get value from num, fun, val tuple""" """Get value from num, fun, val tuple"""
@ -454,43 +450,40 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
f = self.find_field(fieldname) f = self.find_field(fieldname)
if f is None: if f is None:
# fieldname can be an integer for a raw field. # fieldname can be an integer for a raw field.
ordered.append((int(fieldname), copy_val, v[fieldname])) ordered.append((int(fieldname), write_raw_val, v[fieldname]))
else: else:
ordered.append((f.number, f.val_to_bin, v[fieldname])) ordered.append((f.number, f.write, v[fieldname]))
ordered.sort(key=get_value) ordered.sort(key=get_value)
for tup in ordered: for typenum, writefunc, val in ordered:
value = tup[1](tup[2], otherfields) buf = io.BytesIO()
b += (BigSizeType.to_bin(tup[0]) writefunc(buf, val, otherfields)
+ BigSizeType.to_bin(len(value)) BigSizeType.write(iobuf, typenum)
+ value) BigSizeType.write(iobuf, len(buf.getvalue()))
iobuf.write(buf.getvalue())
return b
def val_from_bin(self, bytestream, otherfields): def read(self, io_in, otherfields):
totsize = 0
vals = {} vals = {}
while totsize < len(bytestream): while True:
tlv_type, size = BigSizeType.from_bin(bytestream[totsize:]) tlv_type = BigSizeType.read(io_in)
totsize += size if tlv_type is None:
tlv_len, size = BigSizeType.from_bin(bytestream[totsize:]) return vals
totsize += size
tlv_len = BigSizeType.read(io_in)
if tlv_len is None:
raise ValueError("{}: truncated tlv_len field".format(self))
binval = io_in.read(tlv_len)
if len(binval) != tlv_len:
raise ValueError("{}: truncated tlv {} value"
.format(tlv_type, self))
f = self.find_field_by_number(tlv_type) f = self.find_field_by_number(tlv_type)
if f is None: if f is None:
vals[tlv_type] = bytestream[totsize:totsize + tlv_len] # Raw fields are allowed, just index by number.
size = len(vals[tlv_type]) vals[tlv_type] = binval
else: else:
vals[f.name], size = f.val_from_bin(bytestream vals[f.name] = f.read(io.BytesIO(binval), otherfields)
[totsize:totsize
+ tlv_len],
otherfields)
if size != tlv_len:
raise ValueError("Truncated tlv field")
totsize += size
return vals, totsize
def name_and_val(self, name, v): def name_and_val(self, name, v):
"""This is overridden by LengthFieldType to return nothing""" """This is overridden by LengthFieldType to return nothing"""
@ -541,10 +534,15 @@ class Message(object):
return missing return missing
@staticmethod @staticmethod
def from_bin(namespace, binmsg): def read(namespace, io_in):
"""Decode a binary wire format to a Message within that namespace""" """Read and decode a Message within that namespace.
typenum = struct.unpack_from(">H", binmsg)[0]
off = 2 Returns None on EOF
"""
typenum = try_unpack('message_type', io_in, ">H", empty_ok=True)
if typenum is None:
return None
mtype = namespace.get_msgtype_by_number(typenum) mtype = namespace.get_msgtype_by_number(typenum)
if not mtype: if not mtype:
@ -552,16 +550,21 @@ class Message(object):
fields = {} fields = {}
for f in mtype.fields: for f in mtype.fields:
v, size = f.fieldtype.val_from_bin(binmsg[off:], fields) fields[f.name] = f.fieldtype.read(io_in, fields)
off += size if fields[f.name] is None:
fields[f.name] = v # optional fields are OK to be missing at end!
raise ValueError('{}: truncated at field {}'
.format(mtype, f.name))
return Message(mtype, **fields) return Message(mtype, **fields)
@staticmethod @staticmethod
def from_str(namespace, s, incomplete_ok=False): def from_str(namespace, s, incomplete_ok=False):
"""Decode a string to a Message within that namespace, of format """Decode a string to a Message within that namespace.
msgname [ field=...]*."""
Format is msgname [ field=...]*.
"""
parts = s.split() parts = s.split()
mtype = namespace.get_msgtype(parts[0]) mtype = namespace.get_msgtype(parts[0])
@ -582,14 +585,17 @@ msgname [ field=...]*."""
return m return m
def to_bin(self): def write(self, io_out):
"""Encode a Message into its wire format (must not have missing """Write a Message into its wire format.
fields)"""
Must not have missing fields.
"""
if self.missing_fields(): if self.missing_fields():
raise ValueError('Missing fields: {}' raise ValueError('Missing fields: {}'
.format(self.missing_fields())) .format(self.missing_fields()))
ret = struct.pack(">H", self.messagetype.number) io_out.write(struct.pack(">H", self.messagetype.number))
for f in self.messagetype.fields: for f in self.messagetype.fields:
# Optional fields get val == None. Usually this means they don't # Optional fields get val == None. Usually this means they don't
# write anything, but length fields are an exception: they intuit # write anything, but length fields are an exception: they intuit
@ -598,8 +604,7 @@ fields)"""
val = self.fields[f.name] val = self.fields[f.name]
else: else:
val = None val = None
ret += f.fieldtype.val_to_bin(val, self.fields) f.fieldtype.write(io_out, val, self.fields)
return ret
def to_str(self): def to_str(self):
"""Encode a Message into a string""" """Encode a Message into a string"""

1
contrib/pyln-proto/requirements.txt

@ -2,3 +2,4 @@ bitstring==3.1.6
cryptography==2.8 cryptography==2.8
coincurve==13.0.0 coincurve==13.0.0
base58==1.0.2 base58==1.0.2
mypy

85
contrib/pyln-proto/tests/test_array_types.py

@ -1,6 +1,7 @@
#! /usr/bin/python3 #! /usr/bin/python3
from pyln.proto.message.fundamental_types import fundamental_types from pyln.proto.message.fundamental_types import fundamental_types
from pyln.proto.message.array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType, LengthFieldType from pyln.proto.message.array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType, LengthFieldType
import io
def test_sized_array(): def test_sized_array():
@ -32,9 +33,11 @@ def test_sized_array():
+ [0, 0, 10, 0, 0, 11, 0, 12])]]: + [0, 0, 10, 0, 0, 11, 0, 12])]]:
v, _ = arrtype.val_from_str(s) v, _ = arrtype.val_from_str(s)
assert arrtype.val_to_str(v, None) == s assert arrtype.val_to_str(v, None) == s
v2, _ = arrtype.val_from_bin(b, None) v2 = arrtype.read(io.BytesIO(b), None)
assert v2 == v assert v2 == v
assert arrtype.val_to_bin(v, None) == b buf = io.BytesIO()
arrtype.write(buf, v, None)
assert buf.getvalue() == b
def test_ellipsis_array(): def test_ellipsis_array():
@ -52,23 +55,25 @@ def test_ellipsis_array():
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
for test in [[EllipsisArrayType(dummy("test1"), "test_arr", byte), for arrtype, s, b in [[EllipsisArrayType(dummy("test1"), "test_arr", byte),
"00010203", "00010203",
bytes([0, 1, 2, 3])], bytes([0, 1, 2, 3])],
[EllipsisArrayType(dummy("test2"), "test_arr", u16), [EllipsisArrayType(dummy("test2"), "test_arr", u16),
"[0,1,2,256]", "[0,1,2,256]",
bytes([0, 0, 0, 1, 0, 2, 1, 0])], bytes([0, 0, 0, 1, 0, 2, 1, 0])],
[EllipsisArrayType(dummy("test3"), "test_arr", scid), [EllipsisArrayType(dummy("test3"), "test_arr", scid),
"[1x2x3,4x5x6,7x8x9,10x11x12]", "[1x2x3,4x5x6,7x8x9,10x11x12]",
bytes([0, 0, 1, 0, 0, 2, 0, 3] bytes([0, 0, 1, 0, 0, 2, 0, 3]
+ [0, 0, 4, 0, 0, 5, 0, 6] + [0, 0, 4, 0, 0, 5, 0, 6]
+ [0, 0, 7, 0, 0, 8, 0, 9] + [0, 0, 7, 0, 0, 8, 0, 9]
+ [0, 0, 10, 0, 0, 11, 0, 12])]]: + [0, 0, 10, 0, 0, 11, 0, 12])]]:
v, _ = test[0].val_from_str(test[1]) v, _ = arrtype.val_from_str(s)
assert test[0].val_to_str(v, None) == test[1] assert arrtype.val_to_str(v, None) == s
v2, _ = test[0].val_from_bin(test[2], None) v2 = arrtype.read(io.BytesIO(b), None)
assert v2 == v assert v2 == v
assert test[0].val_to_bin(v, None) == test[2] buf = io.BytesIO()
arrtype.write(buf, v, None)
assert buf.getvalue() == b
def test_dynamic_array(): def test_dynamic_array():
@ -93,27 +98,29 @@ def test_dynamic_array():
lenfield = field_dummy('lenfield', LengthFieldType(u16)) lenfield = field_dummy('lenfield', LengthFieldType(u16))
for test in [[DynamicArrayType(dummy("test1"), "test_arr", byte, for arrtype, s, b in [[DynamicArrayType(dummy("test1"), "test_arr", byte,
lenfield), lenfield),
"00010203", "00010203",
bytes([0, 1, 2, 3])], bytes([0, 1, 2, 3])],
[DynamicArrayType(dummy("test2"), "test_arr", u16, [DynamicArrayType(dummy("test2"), "test_arr", u16,
lenfield), lenfield),
"[0,1,2,256]", "[0,1,2,256]",
bytes([0, 0, 0, 1, 0, 2, 1, 0])], bytes([0, 0, 0, 1, 0, 2, 1, 0])],
[DynamicArrayType(dummy("test3"), "test_arr", scid, [DynamicArrayType(dummy("test3"), "test_arr", scid,
lenfield), lenfield),
"[1x2x3,4x5x6,7x8x9,10x11x12]", "[1x2x3,4x5x6,7x8x9,10x11x12]",
bytes([0, 0, 1, 0, 0, 2, 0, 3] bytes([0, 0, 1, 0, 0, 2, 0, 3]
+ [0, 0, 4, 0, 0, 5, 0, 6] + [0, 0, 4, 0, 0, 5, 0, 6]
+ [0, 0, 7, 0, 0, 8, 0, 9] + [0, 0, 7, 0, 0, 8, 0, 9]
+ [0, 0, 10, 0, 0, 11, 0, 12])]]: + [0, 0, 10, 0, 0, 11, 0, 12])]]:
lenfield.fieldtype.add_length_for(field_dummy(test[1], test[0])) lenfield.fieldtype.add_length_for(field_dummy(s, arrtype))
v, _ = test[0].val_from_str(test[1]) v, _ = arrtype.val_from_str(s)
otherfields = {test[1]: v} otherfields = {s: v}
assert test[0].val_to_str(v, otherfields) == test[1] assert arrtype.val_to_str(v, otherfields) == s
v2, _ = test[0].val_from_bin(test[2], otherfields) v2 = arrtype.read(io.BytesIO(b), otherfields)
assert v2 == v assert v2 == v
assert test[0].val_to_bin(v, otherfields) == test[2] buf = io.BytesIO()
arrtype.write(buf, v, None)
assert buf.getvalue() == b
lenfield.fieldtype.len_for = [] lenfield.fieldtype.len_for = []

7
contrib/pyln-proto/tests/test_fundamental_types.py

@ -1,5 +1,6 @@
#! /usr/bin/python3 #! /usr/bin/python3
from pyln.proto.message.fundamental_types import fundamental_types from pyln.proto.message.fundamental_types import fundamental_types
import io
def test_fundamental_types(): def test_fundamental_types():
@ -67,8 +68,10 @@ def test_fundamental_types():
for test in expect[t.name]: for test in expect[t.name]:
v, _ = t.val_from_str(test[0]) v, _ = t.val_from_str(test[0])
assert t.val_to_str(v, None) == test[0] assert t.val_to_str(v, None) == test[0]
v2, _ = t.val_from_bin(test[1], None) v2 = t.read(io.BytesIO(test[1]), None)
assert v2 == v assert v2 == v
assert t.val_to_bin(v, None) == test[1] buf = io.BytesIO()
t.write(buf, v, None)
assert buf.getvalue() == test[1]
assert untested == set(['varint']) assert untested == set(['varint'])

58
contrib/pyln-proto/tests/test_message.py

@ -1,6 +1,7 @@
#! /usr/bin/python3 #! /usr/bin/python3
from pyln.proto.message import MessageNamespace, Message from pyln.proto.message import MessageNamespace, Message
import pytest import pytest
import io
def test_fundamental(): def test_fundamental():
@ -51,9 +52,10 @@ def test_static_array():
+ [0, 0, 10, 0, 0, 11, 0, 12])]]: + [0, 0, 10, 0, 0, 11, 0, 12])]]:
m = Message.from_str(ns, test[0]) m = Message.from_str(ns, test[0])
assert m.to_str() == test[0] assert m.to_str() == test[0]
v = m.to_bin() buf = io.BytesIO()
assert v == test[1] m.write(buf)
assert Message.from_bin(ns, test[1]).to_str() == test[0] assert buf.getvalue() == test[1]
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
def test_subtype(): def test_subtype():
@ -78,9 +80,10 @@ def test_subtype():
+ [0, 0, 0, 7, 0, 0, 0, 8])]]: + [0, 0, 0, 7, 0, 0, 0, 8])]]:
m = Message.from_str(ns, test[0]) m = Message.from_str(ns, test[0])
assert m.to_str() == test[0] assert m.to_str() == test[0]
v = m.to_bin() buf = io.BytesIO()
assert v == test[1] m.write(buf)
assert Message.from_bin(ns, test[1]).to_str() == test[0] assert buf.getvalue() == test[1]
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
# Test missing field logic. # Test missing field logic.
m = Message.from_str(ns, "test1", incomplete_ok=True) m = Message.from_str(ns, "test1", incomplete_ok=True)
@ -111,16 +114,19 @@ def test_tlv():
+ [253, 0, 255, 4, 1, 2, 3, 4])]]: + [253, 0, 255, 4, 1, 2, 3, 4])]]:
m = Message.from_str(ns, test[0]) m = Message.from_str(ns, test[0])
assert m.to_str() == test[0] assert m.to_str() == test[0]
v = m.to_bin() buf = io.BytesIO()
assert v == test[1] m.write(buf)
assert Message.from_bin(ns, test[1]).to_str() == test[0] assert buf.getvalue() == test[1]
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
# Ordering test (turns into canonical ordering) # Ordering test (turns into canonical ordering)
m = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}') m = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')
assert m.to_bin() == bytes([0, 1] buf = io.BytesIO()
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5] m.write(buf)
+ [4, 3, 1, 2, 3] assert buf.getvalue() == bytes([0, 1]
+ [253, 0, 255, 4, 1, 2, 3, 4]) + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
+ [4, 3, 1, 2, 3]
+ [253, 0, 255, 4, 1, 2, 3, 4])
def test_message_constructor(): def test_message_constructor():
@ -135,10 +141,12 @@ def test_message_constructor():
m = Message(ns.get_msgtype('test1'), m = Message(ns.get_msgtype('test1'),
tlvs='{tlv1={field1=01020304,field2=5}' tlvs='{tlv1={field1=01020304,field2=5}'
',tlv2={field3=01020304},4=010203}') ',tlv2={field3=01020304},4=010203}')
assert m.to_bin() == bytes([0, 1] buf = io.BytesIO()
+ [1, 8, 1, 2, 3, 4, 0, 0, 0, 5] m.write(buf)
+ [4, 3, 1, 2, 3] assert buf.getvalue() == bytes([0, 1]
+ [253, 0, 255, 4, 1, 2, 3, 4]) + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
+ [4, 3, 1, 2, 3]
+ [253, 0, 255, 4, 1, 2, 3, 4])
def test_dynamic_array(): def test_dynamic_array():
@ -151,13 +159,15 @@ def test_dynamic_array():
# This one is fine. # This one is fine.
m = Message(ns.get_msgtype('test1'), m = Message(ns.get_msgtype('test1'),
arr1='01020304', arr2='[1,2,3,4]') arr1='01020304', arr2='[1,2,3,4]')
assert m.to_bin() == bytes([0, 1] buf = io.BytesIO()
+ [0, 4] m.write(buf)
+ [1, 2, 3, 4] assert buf.getvalue() == bytes([0, 1]
+ [0, 0, 0, 1, + [0, 4]
0, 0, 0, 2, + [1, 2, 3, 4]
0, 0, 0, 3, + [0, 0, 0, 1,
0, 0, 0, 4]) 0, 0, 0, 2,
0, 0, 0, 3,
0, 0, 0, 4])
# These ones are not # These ones are not
with pytest.raises(ValueError, match='Inconsistent length.*count'): with pytest.raises(ValueError, match='Inconsistent length.*count'):

Loading…
Cancel
Save