diff --git a/contrib/pyln-proto/Makefile b/contrib/pyln-proto/Makefile index c111cd6ef..81edb3985 100644 --- a/contrib/pyln-proto/Makefile +++ b/contrib/pyln-proto/Makefile @@ -17,7 +17,7 @@ check-flake8: # mypy . does not recurse. I have no idea why... check-mypy: - mypy --ignore-missing-imports `find * -name '*.py'` + mypy --ignore-missing-imports `find pyln/proto/message/ -name '*.py'` $(SDIST_FILE): python3 setup.py sdist diff --git a/contrib/pyln-proto/pyln/proto/message/array_types.py b/contrib/pyln-proto/pyln/proto/message/array_types.py index ea8616d2d..60c7011da 100644 --- a/contrib/pyln-proto/pyln/proto/message/array_types.py +++ b/contrib/pyln-proto/pyln/proto/message/array_types.py @@ -1,8 +1,8 @@ from .fundamental_types import FieldType, IntegerType, split_field -from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union +from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union, cast from io import BufferedIOBase if TYPE_CHECKING: - from .message import SubtypeType, TlvStreamType + from .message import SubtypeType, TlvMessageType, MessageTypeField class ArrayType(FieldType): @@ -98,7 +98,7 @@ class SizedArrayType(ArrayType): class EllipsisArrayType(ArrayType): """This is used for ... fields at the end of a tlv: the array ends when the tlv ends""" - def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType): + def __init__(self, tlv: 'TlvMessageType', name: str, elemtype: FieldType): super().__init__(tlv, name, elemtype) def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]: @@ -119,13 +119,13 @@ class LengthFieldType(FieldType): super().__init__(inttype.name) self.underlying_type = inttype # You can be length for more than one field! - self.len_for: List[DynamicArrayType] = [] + self.len_for: List['MessageTypeField'] = [] def is_optional(self) -> bool: """This field value is always implies, never specified directly""" return True - def add_length_for(self, field: 'DynamicArrayType') -> None: + def add_length_for(self, field: 'MessageTypeField') -> None: assert isinstance(field.fieldtype, DynamicArrayType) self.len_for.append(field) @@ -160,7 +160,7 @@ class LengthFieldType(FieldType): they're implied by the length of other fields""" return '' - def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None: + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]: """We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)""" return self.underlying_type.read(io_in, otherfields) @@ -186,11 +186,11 @@ they're implied by the length of other fields""" class DynamicArrayType(ArrayType): """This is used for arrays where another field controls the size""" - def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType): + def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: 'MessageTypeField'): super().__init__(outer, name, elemtype) assert type(lenfield.fieldtype) is LengthFieldType self.lenfield = lenfield def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]: return super().read_arr(io_in, otherfields, - self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields)) + cast(LengthFieldType, self.lenfield.fieldtype)._maybe_calc_value(self.lenfield.name, otherfields)) diff --git a/contrib/pyln-proto/pyln/proto/message/fundamental_types.py b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py index 8341a5c90..11f26704a 100644 --- a/contrib/pyln-proto/pyln/proto/message/fundamental_types.py +++ b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py @@ -59,6 +59,15 @@ These are further specialized. def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str: raise NotImplementedError() + def val_from_str(self, s: str) -> Tuple[Any, str]: + raise NotImplementedError() + + def write(self, io_out: BufferedIOBase, v: Any, otherfields: Dict[str, Any]) -> None: + raise NotImplementedError() + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Any: + raise NotImplementedError() + def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any: """Convert to a python object: for simple fields, this means a string""" return self.val_to_str(v, otherfields) @@ -83,7 +92,7 @@ class IntegerType(FieldType): a, b = split_field(s) return int(a), b - def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> int: + def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any: """Convert to a python object: for integer fields, this means an int""" return int(v) @@ -240,7 +249,7 @@ class BigSizeType(FieldType): return int(v) -def fundamental_types(): +def fundamental_types() -> List[FieldType]: # From 01-messaging.md#fundamental-types: return [IntegerType('byte', 1, 'B'), IntegerType('u16', 2, '>H'), diff --git a/contrib/pyln-proto/pyln/proto/message/message.py b/contrib/pyln-proto/pyln/proto/message/message.py index d8b5d848d..33b879c9b 100644 --- a/contrib/pyln-proto/pyln/proto/message/message.py +++ b/contrib/pyln-proto/pyln/proto/message/message.py @@ -1,10 +1,10 @@ import struct from io import BufferedIOBase, BytesIO -from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType +from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType, IntegerType from .array_types import ( SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType ) -from typing import Dict, List, Optional, Tuple, Any, Union, cast +from typing import Dict, List, Optional, Tuple, Any, Union, Callable, cast class MessageNamespace(object): @@ -12,7 +12,7 @@ class MessageNamespace(object): domain, such as within a given BOLT""" def __init__(self, csv_lines: List[str] = []): self.subtypes: Dict[str, SubtypeType] = {} - self.fundamentaltypes: Dict[str, SubtypeType] = {} + self.fundamentaltypes: Dict[str, FieldType] = {} self.tlvtypes: Dict[str, TlvStreamType] = {} self.messagetypes: Dict[str, MessageType] = {} @@ -28,27 +28,35 @@ domain, such as within a given BOLT""" for v in other.subtypes.values(): ret.add_subtype(v) ret.tlvtypes = self.tlvtypes.copy() - for v in other.tlvtypes.values(): - ret.add_tlvtype(v) + for tlv in other.tlvtypes.values(): + ret.add_tlvtype(tlv) ret.messagetypes = self.messagetypes.copy() for v in other.messagetypes.values(): ret.add_messagetype(v) return ret + def _check_unique(self, name: str) -> None: + """Raise an exception if name already used""" + funtype = self.get_fundamentaltype(name) + if funtype: + raise ValueError('Already have {}'.format(funtype)) + subtype = self.get_subtype(name) + if subtype: + raise ValueError('Already have {}'.format(subtype)) + tlvtype = self.get_tlvtype(name) + if tlvtype: + raise ValueError('Already have {}'.format(tlvtype)) + def add_subtype(self, t: 'SubtypeType') -> None: - prev = self.get_type(t.name) - if prev: - raise ValueError('Already have {}'.format(prev)) + self._check_unique(t.name) self.subtypes[t.name] = t - def add_fundamentaltype(self, t: 'SubtypeType') -> None: - assert not self.get_type(t.name) + def add_fundamentaltype(self, t: FieldType) -> None: + self._check_unique(t.name) self.fundamentaltypes[t.name] = t def add_tlvtype(self, t: 'TlvStreamType') -> None: - prev = self.get_type(t.name) - if prev: - raise ValueError('Already have {}'.format(prev)) + self._check_unique(t.name) self.tlvtypes[t.name] = t def add_messagetype(self, m: 'MessageType') -> None: @@ -70,7 +78,7 @@ domain, such as within a given BOLT""" return m return None - def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']: + def get_fundamentaltype(self, name: str) -> Optional[FieldType]: if name in self.fundamentaltypes: return self.fundamentaltypes[name] return None @@ -85,14 +93,6 @@ domain, such as within a given BOLT""" return self.tlvtypes[name] return None - def get_type(self, name: str) -> Optional['SubtypeType']: - t = self.get_fundamentaltype(name) - if t is None: - t = self.get_subtype(name) - if t is None: - t = self.get_tlvtype(name) - return t - def load_csv(self, lines: List[str]) -> None: """Load a series of comma-separate-value lines into the namespace""" vals: Dict[str, List[List[str]]] = {'msgtype': [], @@ -152,23 +152,22 @@ class MessageTypeField(object): return self.full_name -class SubtypeType(object): +class SubtypeType(FieldType): """This defines a 'subtype' in BOLT-speak. It consists of fields of -other types. Since 'msgtype' and 'tlvtype' are almost identical, they -inherit from this too. +other types. Since 'msgtype' is almost identical, it inherits from this too. """ def __init__(self, name: str): - self.name = name - self.fields: List[FieldType] = [] + super().__init__(name) + self.fields: List[MessageTypeField] = [] - def find_field(self, fieldname: str): + def find_field(self, fieldname: str) -> Optional[MessageTypeField]: for f in self.fields: if f.name == fieldname: return f return None - def add_field(self, field: FieldType): + def add_field(self, field: MessageTypeField) -> None: if self.find_field(field.name): raise ValueError("{}: duplicate field {}".format(self, field)) self.fields.append(field) @@ -192,12 +191,16 @@ inherit from this too. .format(parts)) return SubtypeType(parts[0]) - def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField: + def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], option: str = None) -> MessageTypeField: """Takes msgdata/subtypedata after first two fields e.g. [...]timestamp_node_id_1,u32, """ - basetype = namespace.get_type(parts[1]) + basetype = namespace.get_fundamentaltype(parts[1]) + if basetype is None: + basetype = namespace.get_subtype(parts[1]) + if basetype is None: + basetype = namespace.get_tlvtype(parts[1]) if basetype is None: raise ValueError('Unknown type {}'.format(parts[1])) @@ -206,7 +209,8 @@ inherit from this too. lenfield = self.find_field(parts[2]) if lenfield is not None: # If we didn't know that field was a length, we do now! - if type(lenfield.fieldtype) is not LengthFieldType: + if not isinstance(lenfield.fieldtype, LengthFieldType): + assert isinstance(lenfield.fieldtype, IntegerType) lenfield.fieldtype = LengthFieldType(lenfield.fieldtype) field = MessageTypeField(self.name, parts[0], DynamicArrayType(self, @@ -215,7 +219,9 @@ inherit from this too. lenfield), option) lenfield.fieldtype.add_length_for(field) - elif ellipsisok and parts[2] == '...': + elif parts[2] == '...': + # ... is only valid for a TLV. + assert isinstance(self, TlvMessageType) field = MessageTypeField(self.name, parts[0], EllipsisArrayType(self, parts[0], basetype), @@ -264,8 +270,10 @@ inherit from this too. raise ValueError("Unknown fields specified: {}".format(unknown)) for f in defined.difference(have): - if not f.fieldtype.is_optional(): - raise ValueError("Missing value for {}".format(f)) + field = self.find_field(f) + assert field + if not field.fieldtype.is_optional(): + raise ValueError("Missing value for {}".format(field)) def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str: self._raise_if_badvals(v) @@ -273,6 +281,7 @@ inherit from this too. sep = '' for fname, val in v.items(): field = self.find_field(fname) + assert field s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields) sep = ',' @@ -281,16 +290,19 @@ inherit from this too. def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]: ret: Dict[str, Any] = {} for k, v in val.items(): - ret[k] = self.find_field(k).fieldtype.val_to_py(v, val) + field = self.find_field(k) + assert field + ret[k] = field.fieldtype.val_to_py(v, val) return ret def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None: self._raise_if_badvals(v) for fname, val in v.items(): field = self.find_field(fname) + assert field field.fieldtype.write(io_out, val, otherfields) - def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]: + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]: vals = {} for field in self.fields: val = field.fieldtype.read(io_in, otherfields) @@ -383,25 +395,46 @@ class MessageType(SubtypeType): messagetype.add_field(field) -class TlvStreamType(SubtypeType): - """A TlvStreamType is just a Subtype, but its fields are -TlvMessageTypes. In the CSV format these are created implicitly, when -a tlvtype line (which defines a TlvMessageType within the TlvType, -confusingly) refers to them. +class TlvMessageType(MessageType): + """A 'tlvtype' in BOLT-speak""" + + def __init__(self, name: str, value: str): + super().__init__(name, value) + + def __str__(self): + return "tlvmsgtype-{}".format(self.name) + + +class TlvStreamType(FieldType): + """A TlvStreamType's fields are TlvMessageTypes. In the CSV format +these are created implicitly, when a tlvtype line (which defines a +TlvMessageType within the TlvType, confusingly) refers to them. """ def __init__(self, name): super().__init__(name) + self.fields: List[TlvMessageType] = [] def __str__(self): return "tlvstreamtype-{}".format(self.name) - def find_field_by_number(self, num: int) -> Optional['TlvMessageType']: + def find_field(self, fieldname: str) -> Optional[TlvMessageType]: + for f in self.fields: + if f.name == fieldname: + return f + return None + + def find_field_by_number(self, num: int) -> Optional[TlvMessageType]: for f in self.fields: if f.number == num: return f return None + def add_field(self, field: TlvMessageType) -> None: + if self.find_field(field.name): + raise ValueError("{}: duplicate field {}".format(self, field)) + self.fields.append(field) + def is_optional(self) -> bool: """You can omit a tlvstream= altogether""" return True @@ -438,7 +471,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, raise ValueError("Unknown tlv field {}.{}" .format(tlvstream, parts[1])) - subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True) + subfield = field._field_from_csv(namespace, parts[2:]) field.add_field(subfield) def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]: @@ -480,7 +513,9 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]: ret: Dict[str, Any] = {} for k, v in val.items(): - ret[k] = self.find_field(k).val_to_py(v, val) + field = self.find_field(k) + assert field + ret[k] = field.val_to_py(v, val) return ret def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None: @@ -490,14 +525,16 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, # Make a tuple of (fieldnum, val_to_bin, val) so we can sort into # ascending order as TLV spec requires. - def write_raw_val(iobuf, val, otherfields: Dict[str, Any]): + def write_raw_val(iobuf: BufferedIOBase, val: Any, otherfields: Dict[str, Any]) -> None: iobuf.write(val) def get_value(tup): """Get value from num, fun, val tuple""" return tup[0] - ordered = [] + ordered: List[Tuple[int, + Callable[[BufferedIOBase, Any, Dict[str, Any]], None], + Any]] = [] for fieldname in v: f = self.find_field(fieldname) if f is None: @@ -510,13 +547,13 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, for typenum, writefunc, val in ordered: buf = BytesIO() - writefunc(buf, val, otherfields) + writefunc(cast(BufferedIOBase, buf), val, otherfields) BigSizeType.write(io_out, typenum) BigSizeType.write(io_out, len(buf.getvalue())) io_out.write(buf.getvalue()) - def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]: - vals: Dict[str, Any] = {} + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[Union[str, int], Any]: + vals: Dict[Union[str, int], Any] = {} while True: tlv_type = BigSizeType.read(io_in) @@ -543,16 +580,6 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, return " {}={}".format(name, self.val_to_str(v, {})) -class TlvMessageType(MessageType): - """A 'tlvtype' in BOLT-speak""" - - def __init__(self, name: str, value: str): - super().__init__(name, value) - - def __str__(self): - return "tlvmsgtype-{}".format(self.name) - - class Message(object): """A particular message instance""" def __init__(self, messagetype: MessageType, **kwargs): @@ -679,7 +706,8 @@ Must not have missing fields. """Convert to a Python native object: dicts, lists, strings, ints""" ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {} for f, v in self.fields.items(): - fieldtype = self.messagetype.find_field(f).fieldtype - ret[f] = fieldtype.val_to_py(v, self.fields) + field = self.messagetype.find_field(f) + assert field + ret[f] = field.fieldtype.val_to_py(v, self.fields) return ret