Browse Source

pyln.proto.message.*: add type annotations.

Other changes along the way:

1. In a couple of places we passed None as a dummy for for
   `otherfields` where {} is just as good.
2. Turned bytes into hex for errors.
3. Remove nonsensical (unused) get_tlv_by_number() function from MessageNamespace
4. Renamed unrelated-but-overlapping `field_from_csv` and
   `type_from_csv` static methods, since mypy thought they should have
   the same type.
5. Unknown tlv fields are placed in dict as strings, not ints, for
   type simplicity.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
nifty/pset-pre
Rusty Russell 5 years ago
committed by Christian Decker
parent
commit
f52065201b
  1. 1
      .gitignore
  2. 60
      contrib/pyln-proto/pyln/proto/message/array_types.py
  3. 71
      contrib/pyln-proto/pyln/proto/message/fundamental_types.py
  4. 175
      contrib/pyln-proto/pyln/proto/message/message.py

1
.gitignore

@ -11,6 +11,7 @@
*.po *.po
*.pyc *.pyc
.cppcheck-suppress .cppcheck-suppress
.mypy_cache
TAGS TAGS
tags tags
ccan/tools/configurator/configurator ccan/tools/configurator/configurator

60
contrib/pyln-proto/pyln/proto/message/array_types.py

@ -1,4 +1,8 @@
from .fundamental_types import FieldType, IntegerType, split_field from .fundamental_types import FieldType, IntegerType, split_field
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any
from io import BufferedIOBase
if TYPE_CHECKING:
from .message import SubtypeType, TlvStreamType
class ArrayType(FieldType): class ArrayType(FieldType):
@ -8,11 +12,11 @@ These are not in the namespace, but generated when a message says it
wants an array of some type. wants an array of some type.
""" """
def __init__(self, outer, name, elemtype): def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType):
super().__init__("{}.{}".format(outer.name, name)) super().__init__("{}.{}".format(outer.name, name))
self.elemtype = elemtype self.elemtype = elemtype
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[List[Any], str]:
# Simple arrays of bytes don't need commas # Simple arrays of bytes don't need commas
if self.elemtype.name == 'byte': if self.elemtype.name == 'byte':
a, b = split_field(s) a, b = split_field(s)
@ -30,20 +34,20 @@ wants an array of some type.
s = s[1:] s = s[1:]
return ret, s[1:] return ret, s[1:]
def val_to_str(self, v, otherfields): def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str:
if self.elemtype.name == 'byte': if self.elemtype.name == 'byte':
return bytes(v).hex() return bytes(v).hex()
s = ','.join(self.elemtype.val_to_str(i, otherfields) for i in v) s = ','.join(self.elemtype.val_to_str(i, otherfields) for i in v)
return '[' + s + ']' return '[' + s + ']'
def write(self, io_out, v, otherfields): def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
for i in v: for i in v:
self.elemtype.write(io_out, i, otherfields) self.elemtype.write(io_out, i, otherfields)
def read_arr(self, io_in, otherfields, arraysize): def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]:
"""arraysize None means take rest of io entirely and exactly""" """arraysize None means take rest of io entirely and exactly"""
vals = [] vals: List[Any] = []
while arraysize is None or len(vals) < arraysize: while arraysize is None or len(vals) < arraysize:
# Throws an exception on partial read, so None means completely empty. # Throws an exception on partial read, so None means completely empty.
val = self.elemtype.read(io_in, otherfields) val = self.elemtype.read(io_in, otherfields)
@ -60,65 +64,65 @@ wants an array of some type.
class SizedArrayType(ArrayType): class SizedArrayType(ArrayType):
"""A fixed-size array""" """A fixed-size array"""
def __init__(self, outer, name, elemtype, arraysize): def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, arraysize: int):
super().__init__(outer, name, elemtype) super().__init__(outer, name, elemtype)
self.arraysize = arraysize self.arraysize = arraysize
def val_to_str(self, v, otherfields): def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str:
if len(v) != self.arraysize: if len(v) != self.arraysize:
raise ValueError("Length of {} != {}", v, self.arraysize) raise ValueError("Length of {} != {}", v, self.arraysize)
return super().val_to_str(v, otherfields) return super().val_to_str(v, otherfields)
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[List[Any], str]:
a, b = super().val_from_str(s) a, b = super().val_from_str(s)
if len(a) != self.arraysize: if len(a) != self.arraysize:
raise ValueError("Length of {} != {}", s, self.arraysize) raise ValueError("Length of {} != {}", s, self.arraysize)
return a, b return a, b
def write(self, io_out, v, otherfields): def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
if len(v) != self.arraysize: if len(v) != self.arraysize:
raise ValueError("Length of {} != {}", v, self.arraysize) raise ValueError("Length of {} != {}", v, self.arraysize)
return super().write(io_out, v, otherfields) return super().write(io_out, v, otherfields)
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
return super().read_arr(io_in, otherfields, self.arraysize) return super().read_arr(io_in, otherfields, self.arraysize)
class EllipsisArrayType(ArrayType): class EllipsisArrayType(ArrayType):
"""This is used for ... fields at the end of a tlv: the array ends """This is used for ... fields at the end of a tlv: the array ends
when the tlv ends""" when the tlv ends"""
def __init__(self, tlv, name, elemtype): def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType):
super().__init__(tlv, name, elemtype) super().__init__(tlv, name, elemtype)
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
"""Takes rest of bytestream""" """Takes rest of bytestream"""
return super().read_arr(io_in, otherfields, None) return super().read_arr(io_in, otherfields, None)
def only_at_tlv_end(self): def only_at_tlv_end(self) -> bool:
"""These only make sense at the end of a TLV""" """These only make sense at the end of a TLV"""
return True return True
class LengthFieldType(FieldType): class LengthFieldType(FieldType):
"""Special type to indicate this serves as a length field for others""" """Special type to indicate this serves as a length field for others"""
def __init__(self, inttype): def __init__(self, inttype: IntegerType):
if type(inttype) is not IntegerType: if type(inttype) is not IntegerType:
raise ValueError("{} cannot be a length; not an integer!" raise ValueError("{} cannot be a length; not an integer!"
.format(self.name)) .format(self.name))
super().__init__(inttype.name) super().__init__(inttype.name)
self.underlying_type = inttype self.underlying_type = inttype
# You can be length for more than one field! # You can be length for more than one field!
self.len_for = [] self.len_for: List[DynamicArrayType] = []
def is_optional(self): def is_optional(self) -> bool:
"""This field value is always implies, never specified directly""" """This field value is always implies, never specified directly"""
return True return True
def add_length_for(self, field): def add_length_for(self, field: 'DynamicArrayType') -> None:
assert isinstance(field.fieldtype, DynamicArrayType) assert isinstance(field.fieldtype, DynamicArrayType)
self.len_for.append(field) self.len_for.append(field)
def calc_value(self, otherfields): def calc_value(self, otherfields: Dict[str, Any]) -> int:
"""Calculate length value from field(s) themselves""" """Calculate length value from field(s) themselves"""
if self.len_fields_bad('', otherfields): if self.len_fields_bad('', otherfields):
raise ValueError("Lengths of fields {} not equal!" raise ValueError("Lengths of fields {} not equal!"
@ -126,7 +130,7 @@ class LengthFieldType(FieldType):
return len(otherfields[self.len_for[0].name]) return len(otherfields[self.len_for[0].name])
def _maybe_calc_value(self, fieldname, otherfields): def _maybe_calc_value(self, fieldname: str, otherfields: Dict[str, Any]) -> int:
# Perhaps we're just demarshalling from binary now, so we actually # Perhaps we're just demarshalling from binary now, so we actually
# stored it. Remove, and we'll calc from now on. # stored it. Remove, and we'll calc from now on.
if fieldname in otherfields: if fieldname in otherfields:
@ -135,27 +139,27 @@ class LengthFieldType(FieldType):
return v return v
return self.calc_value(otherfields) return self.calc_value(otherfields)
def val_to_str(self, _, otherfields): def val_to_str(self, _, otherfields: Dict[str, Any]) -> str:
return self.underlying_type.val_to_str(self.calc_value(otherfields), return self.underlying_type.val_to_str(self.calc_value(otherfields),
otherfields) otherfields)
def name_and_val(self, name, v): def name_and_val(self, name: str, v: int) -> str:
"""We don't print out length fields when printing out messages: """We don't print out length fields when printing out messages:
they're implied by the length of other fields""" they're implied by the length of other fields"""
return '' return ''
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)""" """We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
return self.underlying_type.read(io_in, otherfields) return self.underlying_type.read(io_in, otherfields)
def write(self, io_out, _, otherfields): def write(self, io_out: BufferedIOBase, _, otherfields: Dict[str, Any]) -> None:
self.underlying_type.write(io_out, self.calc_value(otherfields), self.underlying_type.write(io_out, self.calc_value(otherfields),
otherfields) otherfields)
def val_from_str(self, s): def val_from_str(self, s: str):
raise ValueError('{} is implied, cannot be specified'.format(self)) raise ValueError('{} is implied, cannot be specified'.format(self))
def len_fields_bad(self, fieldname, otherfields): def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]:
"""fieldname is the name to return if this length is bad""" """fieldname is the name to return if this length is bad"""
mylen = None mylen = None
for lens in self.len_for: for lens in self.len_for:
@ -170,11 +174,11 @@ they're implied by the length of other fields"""
class DynamicArrayType(ArrayType): class DynamicArrayType(ArrayType):
"""This is used for arrays where another field controls the size""" """This is used for arrays where another field controls the size"""
def __init__(self, outer, name, elemtype, lenfield): def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType):
super().__init__(outer, name, elemtype) super().__init__(outer, name, elemtype)
assert type(lenfield.fieldtype) is LengthFieldType assert type(lenfield.fieldtype) is LengthFieldType
self.lenfield = lenfield self.lenfield = lenfield
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
return super().read_arr(io_in, otherfields, return super().read_arr(io_in, otherfields,
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields)) self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))

71
contrib/pyln-proto/pyln/proto/message/fundamental_types.py

@ -1,11 +1,11 @@
import struct import struct
import io from io import BufferedIOBase
import sys import sys
from typing import Optional from typing import Dict, Optional, Tuple, List, Any
def try_unpack(name: str, def try_unpack(name: str,
io_out: io.BufferedIOBase, io_out: BufferedIOBase,
structfmt: str, structfmt: str,
empty_ok: bool) -> Optional[int]: empty_ok: bool) -> Optional[int]:
"""Unpack a single value using struct.unpack. """Unpack a single value using struct.unpack.
@ -20,7 +20,7 @@ If need_all, never return None, otherwise returns None if EOF."""
return struct.unpack(structfmt, b)[0] return struct.unpack(structfmt, b)[0]
def split_field(s): def split_field(s: str) -> Tuple[str, str]:
"""Helper to split string into first part and remainder""" """Helper to split string into first part and remainder"""
def len_without(s, delim): def len_without(s, delim):
pos = s.find(delim) pos = s.find(delim)
@ -37,25 +37,28 @@ class FieldType(object):
These are further specialized. These are further specialized.
""" """
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name = name
def only_at_tlv_end(self): def only_at_tlv_end(self) -> bool:
"""Some types only make sense inside a tlv, at the end""" """Some types only make sense inside a tlv, at the end"""
return False return False
def name_and_val(self, name, v): def name_and_val(self, name: str, v: Any) -> str:
"""This is overridden by LengthFieldType to return nothing""" """This is overridden by LengthFieldType to return nothing"""
return " {}={}".format(name, self.val_to_str(v, None)) return " {}={}".format(name, self.val_to_str(v, {}))
def is_optional(self): def is_optional(self) -> bool:
"""Overridden for tlv fields and optional fields""" """Overridden for tlv fields and optional fields"""
return False return False
def len_fields_bad(self, fieldname, fieldvals): def len_fields_bad(self, fieldname: str, fieldvals: Dict[str, Any]) -> List[str]:
"""Overridden by length fields for arrays""" """Overridden by length fields for arrays"""
return [] return []
def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str:
raise NotImplementedError()
def __str__(self): def __str__(self):
return self.name return self.name
@ -64,22 +67,22 @@ These are further specialized.
class IntegerType(FieldType): class IntegerType(FieldType):
def __init__(self, name, bytelen, structfmt): def __init__(self, name: str, bytelen: int, structfmt: str):
super().__init__(name) super().__init__(name)
self.bytelen = bytelen self.bytelen = bytelen
self.structfmt = structfmt self.structfmt = structfmt
def val_to_str(self, v, otherfields): def val_to_str(self, v: int, otherfields: Dict[str, Any]):
return "{}".format(int(v)) return "{}".format(int(v))
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[int, str]:
a, b = split_field(s) a, b = split_field(s)
return int(a), b return int(a), b
def write(self, io_out, v, otherfields): def write(self, io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any]) -> None:
io_out.write(struct.pack(self.structfmt, v)) io_out.write(struct.pack(self.structfmt, v))
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]:
return try_unpack(self.name, io_in, self.structfmt, empty_ok=True) return try_unpack(self.name, io_in, self.structfmt, empty_ok=True)
@ -91,11 +94,11 @@ basically a u64.
def __init__(self, name): def __init__(self, name):
super().__init__(name, 8, '>Q') super().__init__(name, 8, '>Q')
def val_to_str(self, v, otherfields): def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str:
# See BOLT #7: ## Definition of `short_channel_id` # See BOLT #7: ## Definition of `short_channel_id`
return "{}x{}x{}".format(v >> 40, (v >> 16) & 0xFFFFFF, v & 0xFFFF) return "{}x{}x{}".format(v >> 40, (v >> 16) & 0xFFFFFF, v & 0xFFFF)
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[int, str]:
a, b = split_field(s) a, b = split_field(s)
parts = a.split('x') parts = a.split('x')
if len(parts) != 3: if len(parts) != 3:
@ -107,25 +110,25 @@ basically a u64.
class TruncatedIntType(FieldType): class TruncatedIntType(FieldType):
"""Truncated integer types""" """Truncated integer types"""
def __init__(self, name, maxbytes): def __init__(self, name: str, maxbytes: int):
super().__init__(name) super().__init__(name)
self.maxbytes = maxbytes self.maxbytes = maxbytes
def val_to_str(self, v, otherfields): def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str:
return "{}".format(int(v)) return "{}".format(int(v))
def only_at_tlv_end(self): def only_at_tlv_end(self) -> bool:
"""These only make sense at the end of a TLV""" """These only make sense at the end of a TLV"""
return True return True
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[int, str]:
a, b = split_field(s) a, b = split_field(s)
if int(a) >= (1 << (self.maxbytes * 8)): if int(a) >= (1 << (self.maxbytes * 8)):
raise ValueError('{} exceeds maximum {} capacity' raise ValueError('{} exceeds maximum {} capacity'
.format(a, self.name)) .format(a, self.name))
return int(a), b return int(a), b
def write(self, io_out, v, otherfields): def write(self, io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any]) -> None:
binval = struct.pack('>Q', v) binval = struct.pack('>Q', v)
while len(binval) != 0 and binval[0] == 0: while len(binval) != 0 and binval[0] == 0:
binval = binval[1:] binval = binval[1:]
@ -134,41 +137,41 @@ class TruncatedIntType(FieldType):
.format(v, self.name)) .format(v, self.name))
io_out.write(binval) io_out.write(binval)
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
binval = io_in.read() binval = io_in.read()
if len(binval) > self.maxbytes: if len(binval) > self.maxbytes:
raise ValueError('{} is too long for {}'.format(binval, self.name)) raise ValueError('{} is too long for {}'.format(binval.hex(), self.name))
if len(binval) > 0 and binval[0] == 0: if len(binval) > 0 and binval[0] == 0:
raise ValueError('{} encoding is not minimal: {}' raise ValueError('{} encoding is not minimal: {}'
.format(self.name, binval)) .format(self.name, binval.hex()))
# Pad with zeroes and convert as u64 # Pad with zeroes and convert as u64
return struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0] return struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0]
class FundamentalHexType(FieldType): class FundamentalHexType(FieldType):
"""The remaining fundamental types are simply represented as hex strings""" """The remaining fundamental types are simply represented as hex strings"""
def __init__(self, name, bytelen): def __init__(self, name: str, bytelen: int):
super().__init__(name) super().__init__(name)
self.bytelen = bytelen self.bytelen = bytelen
def val_to_str(self, v, otherfields): def val_to_str(self, v: bytes, otherfields: Dict[str, Any]) -> str:
if len(bytes(v)) != self.bytelen: if len(bytes(v)) != self.bytelen:
raise ValueError("Length of {} != {}", v, self.bytelen) raise ValueError("Length of {} != {}", v, self.bytelen)
return v.hex() return v.hex()
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[bytes, str]:
a, b = split_field(s) a, b = split_field(s)
ret = bytes.fromhex(a) ret = bytes.fromhex(a)
if len(ret) != self.bytelen: if len(ret) != self.bytelen:
raise ValueError("Length of {} != {}", a, self.bytelen) raise ValueError("Length of {} != {}", a, self.bytelen)
return ret, b return ret, b
def write(self, io_out, v, otherfields): def write(self, io_out: BufferedIOBase, v: bytes, otherfields: Dict[str, Any]) -> None:
if len(bytes(v)) != self.bytelen: if len(bytes(v)) != self.bytelen:
raise ValueError("Length of {} != {}", v, self.bytelen) raise ValueError("Length of {} != {}", v, self.bytelen)
io_out.write(v) io_out.write(v)
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[bytes]:
val = io_in.read(self.bytelen) val = io_in.read(self.bytelen)
if len(val) == 0: if len(val) == 0:
return None return None
@ -182,13 +185,13 @@ class BigSizeType(FieldType):
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[int, str]:
a, b = split_field(s) a, b = split_field(s)
return int(a), b return int(a), b
# For the convenience of TLV header parsing # For the convenience of TLV header parsing
@staticmethod @staticmethod
def write(io_out, v, otherfields=None): def write(io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any] = {}) -> None:
if v < 253: if v < 253:
io_out.write(bytes([v])) io_out.write(bytes([v]))
elif v < 2**16: elif v < 2**16:
@ -199,7 +202,7 @@ class BigSizeType(FieldType):
io_out.write(bytes([255]) + struct.pack('>Q', v)) io_out.write(bytes([255]) + struct.pack('>Q', v))
@staticmethod @staticmethod
def read(io_in, otherfields=None): def read(io_in: BufferedIOBase, otherfields: Dict[str, Any] = {}) -> Optional[int]:
"Returns value, or None on EOF" "Returns value, or None on EOF"
b = io_in.read(1) b = io_in.read(1)
if len(b) == 0: if len(b) == 0:
@ -213,7 +216,7 @@ class BigSizeType(FieldType):
else: else:
return try_unpack('BigSize', io_in, '>Q', empty_ok=False) return try_unpack('BigSize', io_in, '>Q', empty_ok=False)
def val_to_str(self, v, otherfields): def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str:
return "{}".format(int(v)) return "{}".format(int(v))

175
contrib/pyln-proto/pyln/proto/message/message.py

@ -1,19 +1,20 @@
import struct import struct
import io from io import BufferedIOBase, BytesIO
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType
from .array_types import ( from .array_types import (
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
) )
from typing import Dict, List, Optional, Tuple, Any, cast
class MessageNamespace(object): class MessageNamespace(object):
"""A class which contains all FieldTypes and Messages in a particular """A class which contains all FieldTypes and Messages in a particular
domain, such as within a given BOLT""" domain, such as within a given BOLT"""
def __init__(self, csv_lines=[]): def __init__(self, csv_lines: List[str] = []):
self.subtypes = {} self.subtypes: Dict[str, SubtypeType] = {}
self.fundamentaltypes = {} self.fundamentaltypes: Dict[str, SubtypeType] = {}
self.tlvtypes = {} self.tlvtypes: Dict[str, TlvStreamType] = {}
self.messagetypes = {} self.messagetypes: Dict[str, MessageType] = {}
# For convenience, basic types go in every namespace # For convenience, basic types go in every namespace
for t in fundamental_types(): for t in fundamental_types():
@ -21,7 +22,7 @@ domain, such as within a given BOLT"""
self.load_csv(csv_lines) self.load_csv(csv_lines)
def __add__(self, other): def __add__(self, other: 'MessageNamespace'):
ret = MessageNamespace() ret = MessageNamespace()
ret.subtypes = self.subtypes.copy() ret.subtypes = self.subtypes.copy()
for v in other.subtypes.values(): for v in other.subtypes.values():
@ -34,57 +35,57 @@ domain, such as within a given BOLT"""
ret.add_messagetype(v) ret.add_messagetype(v)
return ret return ret
def add_subtype(self, t): def add_subtype(self, t: 'SubtypeType') -> None:
prev = self.get_type(t.name) prev = self.get_type(t.name)
if prev: if prev:
return ValueError('Already have {}'.format(prev)) raise ValueError('Already have {}'.format(prev))
self.subtypes[t.name] = t self.subtypes[t.name] = t
def add_fundamentaltype(self, t): def add_fundamentaltype(self, t: 'SubtypeType') -> None:
assert not self.get_type(t.name) assert not self.get_type(t.name)
self.fundamentaltypes[t.name] = t self.fundamentaltypes[t.name] = t
def add_tlvtype(self, t): def add_tlvtype(self, t: 'TlvStreamType') -> None:
prev = self.get_type(t.name) prev = self.get_type(t.name)
if prev: if prev:
return ValueError('Already have {}'.format(prev)) raise ValueError('Already have {}'.format(prev))
self.tlvtypes[t.name] = t self.tlvtypes[t.name] = t
def add_messagetype(self, m): def add_messagetype(self, m: 'MessageType') -> None:
if self.get_msgtype(m.name): if self.get_msgtype(m.name):
return ValueError('{}: message already exists'.format(m.name)) raise ValueError('{}: message already exists'.format(m.name))
if self.get_msgtype_by_number(m.number): if self.get_msgtype_by_number(m.number):
return ValueError('{}: message {} already number {}'.format( raise ValueError('{}: message {} already number {}'.format(
m.name, self.get_msg_by_number(m.number), m.number)) m.name, self.get_msgtype_by_number(m.number), m.number))
self.messagetypes[m.name] = m self.messagetypes[m.name] = m
def get_msgtype(self, name): def get_msgtype(self, name: str) -> Optional['MessageType']:
if name in self.messagetypes: if name in self.messagetypes:
return self.messagetypes[name] return self.messagetypes[name]
return None return None
def get_msgtype_by_number(self, num): def get_msgtype_by_number(self, num: int) -> Optional['MessageType']:
for m in self.messagetypes.values(): for m in self.messagetypes.values():
if m.number == num: if m.number == num:
return m return m
return None return None
def get_fundamentaltype(self, name): def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']:
if name in self.fundamentaltypes: if name in self.fundamentaltypes:
return self.fundamentaltypes[name] return self.fundamentaltypes[name]
return None return None
def get_subtype(self, name): def get_subtype(self, name: str) -> Optional['SubtypeType']:
if name in self.subtypes: if name in self.subtypes:
return self.subtypes[name] return self.subtypes[name]
return None return None
def get_tlvtype(self, name): def get_tlvtype(self, name: str) -> Optional['TlvStreamType']:
if name in self.tlvtypes: if name in self.tlvtypes:
return self.tlvtypes[name] return self.tlvtypes[name]
return None return None
def get_type(self, name): def get_type(self, name: str) -> Optional['SubtypeType']:
t = self.get_fundamentaltype(name) t = self.get_fundamentaltype(name)
if t is None: if t is None:
t = self.get_subtype(name) t = self.get_subtype(name)
@ -92,15 +93,9 @@ domain, such as within a given BOLT"""
t = self.get_tlvtype(name) t = self.get_tlvtype(name)
return t return t
def get_tlv_by_number(self, num): def load_csv(self, lines: List[str]) -> None:
for t in self.tlvtypes:
if t.number == num:
return t
return None
def load_csv(self, lines):
"""Load a series of comma-separate-value lines into the namespace""" """Load a series of comma-separate-value lines into the namespace"""
vals = {'msgtype': [], vals: Dict[str, List[List[str]]] = {'msgtype': [],
'msgdata': [], 'msgdata': [],
'tlvtype': [], 'tlvtype': [],
'tlvdata': [], 'tlvdata': [],
@ -114,39 +109,39 @@ domain, such as within a given BOLT"""
# Types can refer to other types, so add data last. # Types can refer to other types, so add data last.
for parts in vals['msgtype']: for parts in vals['msgtype']:
self.add_messagetype(MessageType.type_from_csv(parts)) self.add_messagetype(MessageType.msgtype_from_csv(parts))
for parts in vals['subtype']: for parts in vals['subtype']:
self.add_subtype(SubtypeType.type_from_csv(parts)) self.add_subtype(SubtypeType.subtype_from_csv(parts))
for parts in vals['tlvtype']: for parts in vals['tlvtype']:
TlvStreamType.type_from_csv(self, parts) TlvStreamType.tlvtype_from_csv(self, parts)
for parts in vals['msgdata']: for parts in vals['msgdata']:
MessageType.field_from_csv(self, parts) MessageType.msgfield_from_csv(self, parts)
for parts in vals['subtypedata']: for parts in vals['subtypedata']:
SubtypeType.field_from_csv(self, parts) SubtypeType.subfield_from_csv(self, parts)
for parts in vals['tlvdata']: for parts in vals['tlvdata']:
TlvStreamType.field_from_csv(self, parts) TlvStreamType.tlvfield_from_csv(self, parts)
class MessageTypeField(object): class MessageTypeField(object):
"""A field within a particular message type or subtype""" """A field within a particular message type or subtype"""
def __init__(self, ownername, name, fieldtype, option=None): def __init__(self, ownername: str, name: str, fieldtype: FieldType, option: Optional[str] = None):
self.full_name = "{}.{}".format(ownername, name) self.full_name = "{}.{}".format(ownername, name)
self.name = name self.name = name
self.fieldtype = fieldtype self.fieldtype = fieldtype
self.option = option self.option = option
def missing_fields(self, fields): def missing_fields(self, fieldvals: Dict[str, Any]):
"""Return this field if it's not in fields""" """Return this field if it's not in fields"""
if self.name not in fields and not self.option and not self.fieldtype.is_optional(): if self.name not in fieldvals and not self.option and not self.fieldtype.is_optional():
return [self] return [self]
return [] return []
def len_fields_bad(self, fieldname, otherfields): def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]:
return self.fieldtype.len_fields_bad(fieldname, otherfields) return self.fieldtype.len_fields_bad(fieldname, otherfields)
def __str__(self): def __str__(self):
@ -163,17 +158,17 @@ other types. Since 'msgtype' and 'tlvtype' are almost identical, they
inherit from this too. inherit from this too.
""" """
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name = name
self.fields = [] self.fields: List[FieldType] = []
def find_field(self, fieldname): def find_field(self, fieldname: str):
for f in self.fields: for f in self.fields:
if f.name == fieldname: if f.name == fieldname:
return f return f
return None return None
def add_field(self, field): def add_field(self, field: FieldType):
if self.find_field(field.name): if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field)) raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field) self.fields.append(field)
@ -181,8 +176,8 @@ inherit from this too.
def __str__(self): def __str__(self):
return "subtype-{}".format(self.name) return "subtype-{}".format(self.name)
def len_fields_bad(self, fieldname, otherfields): def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]:
bad_fields = [] bad_fields: List[str] = []
for f in self.fields: for f in self.fields:
bad_fields += f.len_fields_bad('{}.{}'.format(fieldname, f.name), bad_fields += f.len_fields_bad('{}.{}'.format(fieldname, f.name),
otherfields) otherfields)
@ -190,14 +185,14 @@ inherit from this too.
return bad_fields return bad_fields
@staticmethod @staticmethod
def type_from_csv(parts): def subtype_from_csv(parts: List[str]) -> 'SubtypeType':
"""e.g subtype,channel_update_timestamps""" """e.g subtype,channel_update_timestamps"""
if len(parts) != 1: if len(parts) != 1:
raise ValueError("subtype expected 2 CSV parts, not {}" raise ValueError("subtype expected 2 CSV parts, not {}"
.format(parts)) .format(parts))
return SubtypeType(parts[0]) return SubtypeType(parts[0])
def _field_from_csv(self, namespace, parts, ellipsisok=False, option=None): def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField:
"""Takes msgdata/subtypedata after first two fields """Takes msgdata/subtypedata after first two fields
e.g. [...]timestamp_node_id_1,u32, e.g. [...]timestamp_node_id_1,u32,
@ -236,12 +231,12 @@ inherit from this too.
return field return field
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
if not s.startswith('{'): if not s.startswith('{'):
raise ValueError("subtype {} must be wrapped in '{{}}': bad {}" raise ValueError("subtype {} must be wrapped in '{{}}': bad {}"
.format(self, s)) .format(self, s))
s = s[1:] s = s[1:]
ret = {} ret: Dict[str, Any] = {}
# FIXME: perhaps allow unlabelled fields to imply assign fields in order? # FIXME: perhaps allow unlabelled fields to imply assign fields in order?
while not s.startswith('}'): while not s.startswith('}'):
fieldname, s = s.split('=', 1) fieldname, s = s.split('=', 1)
@ -259,7 +254,7 @@ inherit from this too.
return ret, s[1:] return ret, s[1:]
def _raise_if_badvals(self, v): def _raise_if_badvals(self, v: Dict[str, Any]) -> None:
# Every non-optional value must be specified, and no others. # Every non-optional value must be specified, and no others.
defined = set([f.name for f in self.fields]) defined = set([f.name for f in self.fields])
have = set(v) have = set(v)
@ -272,7 +267,7 @@ inherit from this too.
if not f.fieldtype.is_optional(): if not f.fieldtype.is_optional():
raise ValueError("Missing value for {}".format(f)) raise ValueError("Missing value for {}".format(f))
def val_to_str(self, v, otherfields): def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
self._raise_if_badvals(v) self._raise_if_badvals(v)
s = '' s = ''
sep = '' sep = ''
@ -283,13 +278,13 @@ inherit from this too.
return '{' + s + '}' return '{' + s + '}'
def write(self, io_out, v, otherfields): def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
self._raise_if_badvals(v) self._raise_if_badvals(v)
for fname, val in v.items(): for fname, val in v.items():
field = self.find_field(fname) field = self.find_field(fname)
field.fieldtype.write(io_out, val, otherfields) field.fieldtype.write(io_out, val, otherfields)
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
vals = {} vals = {}
for field in self.fields: for field in self.fields:
val = field.fieldtype.read(io_in, otherfields) val = field.fieldtype.read(io_in, otherfields)
@ -302,7 +297,7 @@ inherit from this too.
return vals return vals
@staticmethod @staticmethod
def field_from_csv(namespace, parts): def subfield_from_csv(namespace: MessageNamespace, parts: List[str]) -> None:
"""e.g """e.g
subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,""" subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,"""
if len(parts) != 4: if len(parts) != 4:
@ -330,12 +325,12 @@ class MessageType(SubtypeType):
'NODE': 0x2000, 'NODE': 0x2000,
'UPDATE': 0x1000} 'UPDATE': 0x1000}
def __init__(self, name, value, option=None): def __init__(self, name: str, value: str, option: Optional[str] = None):
super().__init__(name) super().__init__(name)
self.number = self.parse_value(value) self.number = self.parse_value(value)
self.option = option self.option = option
def parse_value(self, value): def parse_value(self, value: str) -> int:
result = 0 result = 0
for token in value.split('|'): for token in value.split('|'):
if token in self.onion_types.keys(): if token in self.onion_types.keys():
@ -349,7 +344,7 @@ class MessageType(SubtypeType):
return "msgtype-{}".format(self.name) return "msgtype-{}".format(self.name)
@staticmethod @staticmethod
def type_from_csv(parts): def msgtype_from_csv(parts: List[str]) -> 'MessageType':
"""e.g msgtype,open_channel,32,option_foo""" """e.g msgtype,open_channel,32,option_foo"""
option = None option = None
if len(parts) == 3: if len(parts) == 3:
@ -360,7 +355,7 @@ class MessageType(SubtypeType):
return MessageType(parts[0], parts[1], option) return MessageType(parts[0], parts[1], option)
@staticmethod @staticmethod
def field_from_csv(namespace, parts): def msgfield_from_csv(namespace: MessageNamespace, parts: List[str]) -> None:
"""e.g msgdata,open_channel,temporary_channel_id,byte,32[,opt]""" """e.g msgdata,open_channel,temporary_channel_id,byte,32[,opt]"""
option = None option = None
if len(parts) == 5: if len(parts) == 5:
@ -390,18 +385,18 @@ confusingly) refers to them.
def __str__(self): def __str__(self):
return "tlvstreamtype-{}".format(self.name) return "tlvstreamtype-{}".format(self.name)
def find_field_by_number(self, num): def find_field_by_number(self, num: int) -> Optional['TlvMessageType']:
for f in self.fields: for f in self.fields:
if f.number == num: if f.number == num:
return f return f
return None return None
def is_optional(self): def is_optional(self) -> bool:
"""You can omit a tlvstream= altogether""" """You can omit a tlvstream= altogether"""
return True return True
@staticmethod @staticmethod
def type_from_csv(namespace, parts): def tlvtype_from_csv(namespace: MessageNamespace, parts: List[str]) -> None:
"""e.g tlvtype,reply_channel_range_tlvs,timestamps_tlv,1""" """e.g tlvtype,reply_channel_range_tlvs,timestamps_tlv,1"""
if len(parts) != 3: if len(parts) != 3:
raise ValueError("tlvtype expected 4 CSV parts, not {}" raise ValueError("tlvtype expected 4 CSV parts, not {}"
@ -414,7 +409,7 @@ confusingly) refers to them.
tlvstream.add_field(TlvMessageType(parts[1], parts[2])) tlvstream.add_field(TlvMessageType(parts[1], parts[2]))
@staticmethod @staticmethod
def field_from_csv(namespace, parts): def tlvfield_from_csv(namespace: MessageNamespace, parts: List[str]) -> None:
"""e.g """e.g
tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
@ -435,20 +430,21 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True) subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True)
field.add_field(subfield) field.add_field(subfield)
def val_from_str(self, s): def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
"""{fieldname={...},...}. Returns dict of fieldname->val""" """{fieldname={...},...}. Returns dict of fieldname->val"""
if not s.startswith('{'): if not s.startswith('{'):
raise ValueError("tlvtype {} must be wrapped in '{{}}': bad {}" raise ValueError("tlvtype {} must be wrapped in '{{}}': bad {}"
.format(self, s)) .format(self, s))
s = s[1:] s = s[1:]
ret = {} ret: Dict[str, Any] = {}
while not s.startswith('}'): while not s.startswith('}'):
fieldname, s = s.split('=', 1) fieldname, s = s.split('=', 1)
f = self.find_field(fieldname) f = self.find_field(fieldname)
if f is None: if f is None:
# Unknown fields are number=hexstring # Unknown fields are number=hexstring
hexstring, s = split_field(s) hexstring, s = split_field(s)
ret[int(fieldname)] = bytes.fromhex(hexstring) # Make sure it is actually a valid int!
ret[str(int(fieldname))] = bytes.fromhex(hexstring)
else: else:
ret[fieldname], s = f.val_from_str(s) ret[fieldname], s = f.val_from_str(s)
if s[0] == ',': if s[0] == ',':
@ -456,7 +452,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
return ret, s[1:] return ret, s[1:]
def val_to_str(self, v, otherfields): def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
s = '' s = ''
sep = '' sep = ''
for fieldname in v: for fieldname in v:
@ -470,14 +466,14 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
return '{' + s + '}' return '{' + s + '}'
def write(self, iobuf, v, otherfields): def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None:
# If they didn't specify this tlvstream, it's empty. # If they didn't specify this tlvstream, it's empty.
if v is None: if v is None:
return return
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into # Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# ascending order as TLV spec requires. # ascending order as TLV spec requires.
def write_raw_val(iobuf, val, otherfields): def write_raw_val(iobuf, val, otherfields: Dict[str, Any]):
iobuf.write(val) iobuf.write(val)
def get_value(tup): def get_value(tup):
@ -496,14 +492,14 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
ordered.sort(key=get_value) ordered.sort(key=get_value)
for typenum, writefunc, val in ordered: for typenum, writefunc, val in ordered:
buf = io.BytesIO() buf = BytesIO()
writefunc(buf, val, otherfields) writefunc(buf, val, otherfields)
BigSizeType.write(iobuf, typenum) BigSizeType.write(io_out, typenum)
BigSizeType.write(iobuf, len(buf.getvalue())) BigSizeType.write(io_out, len(buf.getvalue()))
iobuf.write(buf.getvalue()) io_out.write(buf.getvalue())
def read(self, io_in, otherfields): def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
vals = {} vals: Dict[str, Any] = {}
while True: while True:
tlv_type = BigSizeType.read(io_in) tlv_type = BigSizeType.read(io_in)
@ -522,17 +518,18 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
# Raw fields are allowed, just index by number. # Raw fields are allowed, just index by number.
vals[tlv_type] = binval vals[tlv_type] = binval
else: else:
vals[f.name] = f.read(io.BytesIO(binval), otherfields) # FIXME: Why doesn't mypy think BytesIO is a valid BufferedIOBase?
vals[f.name] = f.read(cast(BufferedIOBase, BytesIO(binval)), otherfields)
def name_and_val(self, name, v): def name_and_val(self, name: str, v: Dict[str, Any]) -> str:
"""This is overridden by LengthFieldType to return nothing""" """This is overridden by LengthFieldType to return nothing"""
return " {}={}".format(name, self.val_to_str(v, None)) return " {}={}".format(name, self.val_to_str(v, {}))
class TlvMessageType(MessageType): class TlvMessageType(MessageType):
"""A 'tlvtype' in BOLT-speak""" """A 'tlvtype' in BOLT-speak"""
def __init__(self, name, value): def __init__(self, name: str, value: str):
super().__init__(name, value) super().__init__(name, value)
def __str__(self): def __str__(self):
@ -541,10 +538,10 @@ class TlvMessageType(MessageType):
class Message(object): class Message(object):
"""A particular message instance""" """A particular message instance"""
def __init__(self, messagetype, **kwargs): def __init__(self, messagetype: MessageType, **kwargs):
"""MessageType is the type of this msg, with fields. Fields can either be valid values for the type, or if they are strings they are converted according to the field type""" """MessageType is the type of this msg, with fields. Fields can either be valid values for the type, or if they are strings they are converted according to the field type"""
self.messagetype = messagetype self.messagetype = messagetype
self.fields = {} self.fields: Dict[str, Any] = {}
# Convert arguments from strings to values if necessary. # Convert arguments from strings to values if necessary.
for field in kwargs: for field in kwargs:
@ -564,16 +561,16 @@ class Message(object):
if bad_lens: if bad_lens:
raise ValueError("Inconsistent length fields: {}".format(bad_lens)) raise ValueError("Inconsistent length fields: {}".format(bad_lens))
def missing_fields(self): def missing_fields(self) -> List[str]:
"""Are any required fields missing?""" """Are any required fields missing?"""
missing = [] missing: List[str] = []
for ftype in self.messagetype.fields: for ftype in self.messagetype.fields:
missing += ftype.missing_fields(self.fields) missing += ftype.missing_fields(self.fields)
return missing return missing
@staticmethod @staticmethod
def read(namespace, io_in): def read(namespace: MessageNamespace, io_in: BufferedIOBase) -> Optional['Message']:
"""Read and decode a Message within that namespace. """Read and decode a Message within that namespace.
Returns None on EOF Returns None on EOF
@ -587,7 +584,7 @@ Returns None on EOF
if mtype is None: if mtype is None:
raise ValueError('Unknown message type number {}'.format(typenum)) raise ValueError('Unknown message type number {}'.format(typenum))
fields = {} fields: Dict[str, Any] = {}
for f in mtype.fields: for f in mtype.fields:
fields[f.name] = f.fieldtype.read(io_in, fields) fields[f.name] = f.fieldtype.read(io_in, fields)
if fields[f.name] is None: if fields[f.name] is None:
@ -598,7 +595,7 @@ Returns None on EOF
return Message(mtype, **fields) return Message(mtype, **fields)
@staticmethod @staticmethod
def from_str(namespace, s, incomplete_ok=False): def from_str(namespace: MessageNamespace, s: str, incomplete_ok=False) -> 'Message':
"""Decode a string to a Message within that namespace. """Decode a string to a Message within that namespace.
Format is msgname [ field=...]*. Format is msgname [ field=...]*.
@ -624,7 +621,7 @@ Format is msgname [ field=...]*.
return m return m
def write(self, io_out): def write(self, io_out: BufferedIOBase) -> None:
"""Write a Message into its wire format. """Write a Message into its wire format.
Must not have missing fields. Must not have missing fields.
@ -645,7 +642,7 @@ Must not have missing fields.
val = None val = None
f.fieldtype.write(io_out, val, self.fields) f.fieldtype.write(io_out, val, self.fields)
def to_str(self): def to_str(self) -> str:
"""Encode a Message into a string""" """Encode a Message into a string"""
ret = "{}".format(self.messagetype.name) ret = "{}".format(self.messagetype.name)
for f in self.messagetype.fields: for f in self.messagetype.fields:

Loading…
Cancel
Save