Browse Source

pyln.proto.message: more mypy fixes.

This includes some real bugfixes, since it noticed some places we were
being loose with different types!

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
paymod-01
Rusty Russell 5 years ago
committed by Christian Decker
parent
commit
11a0de877e
  1. 2
      contrib/pyln-proto/Makefile
  2. 16
      contrib/pyln-proto/pyln/proto/message/array_types.py
  3. 13
      contrib/pyln-proto/pyln/proto/message/fundamental_types.py
  4. 152
      contrib/pyln-proto/pyln/proto/message/message.py

2
contrib/pyln-proto/Makefile

@ -17,7 +17,7 @@ check-flake8:
# mypy . does not recurse. I have no idea why... # mypy . does not recurse. I have no idea why...
check-mypy: check-mypy:
mypy --ignore-missing-imports `find * -name '*.py'` mypy --ignore-missing-imports `find pyln/proto/message/ -name '*.py'`
$(SDIST_FILE): $(SDIST_FILE):
python3 setup.py sdist python3 setup.py sdist

16
contrib/pyln-proto/pyln/proto/message/array_types.py

@ -1,8 +1,8 @@
from .fundamental_types import FieldType, IntegerType, split_field from .fundamental_types import FieldType, IntegerType, split_field
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union, cast
from io import BufferedIOBase from io import BufferedIOBase
if TYPE_CHECKING: if TYPE_CHECKING:
from .message import SubtypeType, TlvStreamType from .message import SubtypeType, TlvMessageType, MessageTypeField
class ArrayType(FieldType): class ArrayType(FieldType):
@ -98,7 +98,7 @@ class SizedArrayType(ArrayType):
class EllipsisArrayType(ArrayType): class EllipsisArrayType(ArrayType):
"""This is used for ... fields at the end of a tlv: the array ends """This is used for ... fields at the end of a tlv: the array ends
when the tlv ends""" when the tlv ends"""
def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType): def __init__(self, tlv: 'TlvMessageType', name: str, elemtype: FieldType):
super().__init__(tlv, name, elemtype) super().__init__(tlv, name, elemtype)
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]: def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
@ -119,13 +119,13 @@ class LengthFieldType(FieldType):
super().__init__(inttype.name) super().__init__(inttype.name)
self.underlying_type = inttype self.underlying_type = inttype
# You can be length for more than one field! # You can be length for more than one field!
self.len_for: List[DynamicArrayType] = [] self.len_for: List['MessageTypeField'] = []
def is_optional(self) -> bool: def is_optional(self) -> bool:
"""This field value is always implies, never specified directly""" """This field value is always implies, never specified directly"""
return True return True
def add_length_for(self, field: 'DynamicArrayType') -> None: def add_length_for(self, field: 'MessageTypeField') -> None:
assert isinstance(field.fieldtype, DynamicArrayType) assert isinstance(field.fieldtype, DynamicArrayType)
self.len_for.append(field) self.len_for.append(field)
@ -160,7 +160,7 @@ class LengthFieldType(FieldType):
they're implied by the length of other fields""" they're implied by the length of other fields"""
return '' return ''
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None: def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]:
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)""" """We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
return self.underlying_type.read(io_in, otherfields) return self.underlying_type.read(io_in, otherfields)
@ -186,11 +186,11 @@ they're implied by the length of other fields"""
class DynamicArrayType(ArrayType): class DynamicArrayType(ArrayType):
"""This is used for arrays where another field controls the size""" """This is used for arrays where another field controls the size"""
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType): def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: 'MessageTypeField'):
super().__init__(outer, name, elemtype) super().__init__(outer, name, elemtype)
assert type(lenfield.fieldtype) is LengthFieldType assert type(lenfield.fieldtype) is LengthFieldType
self.lenfield = lenfield self.lenfield = lenfield
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]: def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
return super().read_arr(io_in, otherfields, return super().read_arr(io_in, otherfields,
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields)) cast(LengthFieldType, self.lenfield.fieldtype)._maybe_calc_value(self.lenfield.name, otherfields))

13
contrib/pyln-proto/pyln/proto/message/fundamental_types.py

@ -59,6 +59,15 @@ These are further specialized.
def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str: def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str:
raise NotImplementedError() raise NotImplementedError()
def val_from_str(self, s: str) -> Tuple[Any, str]:
raise NotImplementedError()
def write(self, io_out: BufferedIOBase, v: Any, otherfields: Dict[str, Any]) -> None:
raise NotImplementedError()
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Any:
raise NotImplementedError()
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any: def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
"""Convert to a python object: for simple fields, this means a string""" """Convert to a python object: for simple fields, this means a string"""
return self.val_to_str(v, otherfields) return self.val_to_str(v, otherfields)
@ -83,7 +92,7 @@ class IntegerType(FieldType):
a, b = split_field(s) a, b = split_field(s)
return int(a), b return int(a), b
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> int: def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
"""Convert to a python object: for integer fields, this means an int""" """Convert to a python object: for integer fields, this means an int"""
return int(v) return int(v)
@ -240,7 +249,7 @@ class BigSizeType(FieldType):
return int(v) return int(v)
def fundamental_types(): def fundamental_types() -> List[FieldType]:
# From 01-messaging.md#fundamental-types: # From 01-messaging.md#fundamental-types:
return [IntegerType('byte', 1, 'B'), return [IntegerType('byte', 1, 'B'),
IntegerType('u16', 2, '>H'), IntegerType('u16', 2, '>H'),

152
contrib/pyln-proto/pyln/proto/message/message.py

@ -1,10 +1,10 @@
import struct import struct
from io import BufferedIOBase, BytesIO from io import BufferedIOBase, BytesIO
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType, IntegerType
from .array_types import ( from .array_types import (
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
) )
from typing import Dict, List, Optional, Tuple, Any, Union, cast from typing import Dict, List, Optional, Tuple, Any, Union, Callable, cast
class MessageNamespace(object): class MessageNamespace(object):
@ -12,7 +12,7 @@ class MessageNamespace(object):
domain, such as within a given BOLT""" domain, such as within a given BOLT"""
def __init__(self, csv_lines: List[str] = []): def __init__(self, csv_lines: List[str] = []):
self.subtypes: Dict[str, SubtypeType] = {} self.subtypes: Dict[str, SubtypeType] = {}
self.fundamentaltypes: Dict[str, SubtypeType] = {} self.fundamentaltypes: Dict[str, FieldType] = {}
self.tlvtypes: Dict[str, TlvStreamType] = {} self.tlvtypes: Dict[str, TlvStreamType] = {}
self.messagetypes: Dict[str, MessageType] = {} self.messagetypes: Dict[str, MessageType] = {}
@ -28,27 +28,35 @@ domain, such as within a given BOLT"""
for v in other.subtypes.values(): for v in other.subtypes.values():
ret.add_subtype(v) ret.add_subtype(v)
ret.tlvtypes = self.tlvtypes.copy() ret.tlvtypes = self.tlvtypes.copy()
for v in other.tlvtypes.values(): for tlv in other.tlvtypes.values():
ret.add_tlvtype(v) ret.add_tlvtype(tlv)
ret.messagetypes = self.messagetypes.copy() ret.messagetypes = self.messagetypes.copy()
for v in other.messagetypes.values(): for v in other.messagetypes.values():
ret.add_messagetype(v) ret.add_messagetype(v)
return ret return ret
def _check_unique(self, name: str) -> None:
"""Raise an exception if name already used"""
funtype = self.get_fundamentaltype(name)
if funtype:
raise ValueError('Already have {}'.format(funtype))
subtype = self.get_subtype(name)
if subtype:
raise ValueError('Already have {}'.format(subtype))
tlvtype = self.get_tlvtype(name)
if tlvtype:
raise ValueError('Already have {}'.format(tlvtype))
def add_subtype(self, t: 'SubtypeType') -> None: def add_subtype(self, t: 'SubtypeType') -> None:
prev = self.get_type(t.name) self._check_unique(t.name)
if prev:
raise ValueError('Already have {}'.format(prev))
self.subtypes[t.name] = t self.subtypes[t.name] = t
def add_fundamentaltype(self, t: 'SubtypeType') -> None: def add_fundamentaltype(self, t: FieldType) -> None:
assert not self.get_type(t.name) self._check_unique(t.name)
self.fundamentaltypes[t.name] = t self.fundamentaltypes[t.name] = t
def add_tlvtype(self, t: 'TlvStreamType') -> None: def add_tlvtype(self, t: 'TlvStreamType') -> None:
prev = self.get_type(t.name) self._check_unique(t.name)
if prev:
raise ValueError('Already have {}'.format(prev))
self.tlvtypes[t.name] = t self.tlvtypes[t.name] = t
def add_messagetype(self, m: 'MessageType') -> None: def add_messagetype(self, m: 'MessageType') -> None:
@ -70,7 +78,7 @@ domain, such as within a given BOLT"""
return m return m
return None return None
def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']: def get_fundamentaltype(self, name: str) -> Optional[FieldType]:
if name in self.fundamentaltypes: if name in self.fundamentaltypes:
return self.fundamentaltypes[name] return self.fundamentaltypes[name]
return None return None
@ -85,14 +93,6 @@ domain, such as within a given BOLT"""
return self.tlvtypes[name] return self.tlvtypes[name]
return None return None
def get_type(self, name: str) -> Optional['SubtypeType']:
t = self.get_fundamentaltype(name)
if t is None:
t = self.get_subtype(name)
if t is None:
t = self.get_tlvtype(name)
return t
def load_csv(self, lines: List[str]) -> None: def load_csv(self, lines: List[str]) -> None:
"""Load a series of comma-separate-value lines into the namespace""" """Load a series of comma-separate-value lines into the namespace"""
vals: Dict[str, List[List[str]]] = {'msgtype': [], vals: Dict[str, List[List[str]]] = {'msgtype': [],
@ -152,23 +152,22 @@ class MessageTypeField(object):
return self.full_name return self.full_name
class SubtypeType(object): class SubtypeType(FieldType):
"""This defines a 'subtype' in BOLT-speak. It consists of fields of """This defines a 'subtype' in BOLT-speak. It consists of fields of
other types. Since 'msgtype' and 'tlvtype' are almost identical, they other types. Since 'msgtype' is almost identical, it inherits from this too.
inherit from this too.
""" """
def __init__(self, name: str): def __init__(self, name: str):
self.name = name super().__init__(name)
self.fields: List[FieldType] = [] self.fields: List[MessageTypeField] = []
def find_field(self, fieldname: str): def find_field(self, fieldname: str) -> Optional[MessageTypeField]:
for f in self.fields: for f in self.fields:
if f.name == fieldname: if f.name == fieldname:
return f return f
return None return None
def add_field(self, field: FieldType): def add_field(self, field: MessageTypeField) -> None:
if self.find_field(field.name): if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field)) raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field) self.fields.append(field)
@ -192,12 +191,16 @@ inherit from this too.
.format(parts)) .format(parts))
return SubtypeType(parts[0]) return SubtypeType(parts[0])
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField: def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], option: str = None) -> MessageTypeField:
"""Takes msgdata/subtypedata after first two fields """Takes msgdata/subtypedata after first two fields
e.g. [...]timestamp_node_id_1,u32, e.g. [...]timestamp_node_id_1,u32,
""" """
basetype = namespace.get_type(parts[1]) basetype = namespace.get_fundamentaltype(parts[1])
if basetype is None:
basetype = namespace.get_subtype(parts[1])
if basetype is None:
basetype = namespace.get_tlvtype(parts[1])
if basetype is None: if basetype is None:
raise ValueError('Unknown type {}'.format(parts[1])) raise ValueError('Unknown type {}'.format(parts[1]))
@ -206,7 +209,8 @@ inherit from this too.
lenfield = self.find_field(parts[2]) lenfield = self.find_field(parts[2])
if lenfield is not None: if lenfield is not None:
# If we didn't know that field was a length, we do now! # If we didn't know that field was a length, we do now!
if type(lenfield.fieldtype) is not LengthFieldType: if not isinstance(lenfield.fieldtype, LengthFieldType):
assert isinstance(lenfield.fieldtype, IntegerType)
lenfield.fieldtype = LengthFieldType(lenfield.fieldtype) lenfield.fieldtype = LengthFieldType(lenfield.fieldtype)
field = MessageTypeField(self.name, parts[0], field = MessageTypeField(self.name, parts[0],
DynamicArrayType(self, DynamicArrayType(self,
@ -215,7 +219,9 @@ inherit from this too.
lenfield), lenfield),
option) option)
lenfield.fieldtype.add_length_for(field) lenfield.fieldtype.add_length_for(field)
elif ellipsisok and parts[2] == '...': elif parts[2] == '...':
# ... is only valid for a TLV.
assert isinstance(self, TlvMessageType)
field = MessageTypeField(self.name, parts[0], field = MessageTypeField(self.name, parts[0],
EllipsisArrayType(self, EllipsisArrayType(self,
parts[0], basetype), parts[0], basetype),
@ -264,8 +270,10 @@ inherit from this too.
raise ValueError("Unknown fields specified: {}".format(unknown)) raise ValueError("Unknown fields specified: {}".format(unknown))
for f in defined.difference(have): for f in defined.difference(have):
if not f.fieldtype.is_optional(): field = self.find_field(f)
raise ValueError("Missing value for {}".format(f)) assert field
if not field.fieldtype.is_optional():
raise ValueError("Missing value for {}".format(field))
def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str: def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
self._raise_if_badvals(v) self._raise_if_badvals(v)
@ -273,6 +281,7 @@ inherit from this too.
sep = '' sep = ''
for fname, val in v.items(): for fname, val in v.items():
field = self.find_field(fname) field = self.find_field(fname)
assert field
s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields) s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields)
sep = ',' sep = ','
@ -281,16 +290,19 @@ inherit from this too.
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]: def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
ret: Dict[str, Any] = {} ret: Dict[str, Any] = {}
for k, v in val.items(): for k, v in val.items():
ret[k] = self.find_field(k).fieldtype.val_to_py(v, val) field = self.find_field(k)
assert field
ret[k] = field.fieldtype.val_to_py(v, val)
return ret return ret
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None: def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
self._raise_if_badvals(v) self._raise_if_badvals(v)
for fname, val in v.items(): for fname, val in v.items():
field = self.find_field(fname) field = self.find_field(fname)
assert field
field.fieldtype.write(io_out, val, otherfields) field.fieldtype.write(io_out, val, otherfields)
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]: def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
vals = {} vals = {}
for field in self.fields: for field in self.fields:
val = field.fieldtype.read(io_in, otherfields) val = field.fieldtype.read(io_in, otherfields)
@ -383,25 +395,46 @@ class MessageType(SubtypeType):
messagetype.add_field(field) messagetype.add_field(field)
class TlvStreamType(SubtypeType): class TlvMessageType(MessageType):
"""A TlvStreamType is just a Subtype, but its fields are """A 'tlvtype' in BOLT-speak"""
TlvMessageTypes. In the CSV format these are created implicitly, when
a tlvtype line (which defines a TlvMessageType within the TlvType, def __init__(self, name: str, value: str):
confusingly) refers to them. super().__init__(name, value)
def __str__(self):
return "tlvmsgtype-{}".format(self.name)
class TlvStreamType(FieldType):
"""A TlvStreamType's fields are TlvMessageTypes. In the CSV format
these are created implicitly, when a tlvtype line (which defines a
TlvMessageType within the TlvType, confusingly) refers to them.
""" """
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
self.fields: List[TlvMessageType] = []
def __str__(self): def __str__(self):
return "tlvstreamtype-{}".format(self.name) return "tlvstreamtype-{}".format(self.name)
def find_field_by_number(self, num: int) -> Optional['TlvMessageType']: def find_field(self, fieldname: str) -> Optional[TlvMessageType]:
for f in self.fields:
if f.name == fieldname:
return f
return None
def find_field_by_number(self, num: int) -> Optional[TlvMessageType]:
for f in self.fields: for f in self.fields:
if f.number == num: if f.number == num:
return f return f
return None return None
def add_field(self, field: TlvMessageType) -> None:
if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field)
def is_optional(self) -> bool: def is_optional(self) -> bool:
"""You can omit a tlvstream= altogether""" """You can omit a tlvstream= altogether"""
return True return True
@ -438,7 +471,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
raise ValueError("Unknown tlv field {}.{}" raise ValueError("Unknown tlv field {}.{}"
.format(tlvstream, parts[1])) .format(tlvstream, parts[1]))
subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True) subfield = field._field_from_csv(namespace, parts[2:])
field.add_field(subfield) field.add_field(subfield)
def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]: def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
@ -480,7 +513,9 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]: def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
ret: Dict[str, Any] = {} ret: Dict[str, Any] = {}
for k, v in val.items(): for k, v in val.items():
ret[k] = self.find_field(k).val_to_py(v, val) field = self.find_field(k)
assert field
ret[k] = field.val_to_py(v, val)
return ret return ret
def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None: def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None:
@ -490,14 +525,16 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into # Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# ascending order as TLV spec requires. # ascending order as TLV spec requires.
def write_raw_val(iobuf, val, otherfields: Dict[str, Any]): def write_raw_val(iobuf: BufferedIOBase, val: Any, otherfields: Dict[str, Any]) -> None:
iobuf.write(val) iobuf.write(val)
def get_value(tup): def get_value(tup):
"""Get value from num, fun, val tuple""" """Get value from num, fun, val tuple"""
return tup[0] return tup[0]
ordered = [] ordered: List[Tuple[int,
Callable[[BufferedIOBase, Any, Dict[str, Any]], None],
Any]] = []
for fieldname in v: for fieldname in v:
f = self.find_field(fieldname) f = self.find_field(fieldname)
if f is None: if f is None:
@ -510,13 +547,13 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
for typenum, writefunc, val in ordered: for typenum, writefunc, val in ordered:
buf = BytesIO() buf = BytesIO()
writefunc(buf, val, otherfields) writefunc(cast(BufferedIOBase, buf), val, otherfields)
BigSizeType.write(io_out, typenum) BigSizeType.write(io_out, typenum)
BigSizeType.write(io_out, len(buf.getvalue())) BigSizeType.write(io_out, len(buf.getvalue()))
io_out.write(buf.getvalue()) io_out.write(buf.getvalue())
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]: def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[Union[str, int], Any]:
vals: Dict[str, Any] = {} vals: Dict[Union[str, int], Any] = {}
while True: while True:
tlv_type = BigSizeType.read(io_in) tlv_type = BigSizeType.read(io_in)
@ -543,16 +580,6 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
return " {}={}".format(name, self.val_to_str(v, {})) return " {}={}".format(name, self.val_to_str(v, {}))
class TlvMessageType(MessageType):
"""A 'tlvtype' in BOLT-speak"""
def __init__(self, name: str, value: str):
super().__init__(name, value)
def __str__(self):
return "tlvmsgtype-{}".format(self.name)
class Message(object): class Message(object):
"""A particular message instance""" """A particular message instance"""
def __init__(self, messagetype: MessageType, **kwargs): def __init__(self, messagetype: MessageType, **kwargs):
@ -679,7 +706,8 @@ Must not have missing fields.
"""Convert to a Python native object: dicts, lists, strings, ints""" """Convert to a Python native object: dicts, lists, strings, ints"""
ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {} ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {}
for f, v in self.fields.items(): for f, v in self.fields.items():
fieldtype = self.messagetype.find_field(f).fieldtype field = self.messagetype.find_field(f)
ret[f] = fieldtype.val_to_py(v, self.fields) assert field
ret[f] = field.fieldtype.val_to_py(v, self.fields)
return ret return ret

Loading…
Cancel
Save