diff --git a/contrib/pyln-proto/pyln/proto/message/__init__.py b/contrib/pyln-proto/pyln/proto/message/__init__.py index 7d6e602b6..f095f6f60 100644 --- a/contrib/pyln-proto/pyln/proto/message/__init__.py +++ b/contrib/pyln-proto/pyln/proto/message/__init__.py @@ -7,4 +7,20 @@ __all__ = [ "MessageType", "Message", "SubtypeType", + + # fundamental_types + 'byte', + 'u16', + 'u32', + 'u64', + 'tu16', + 'tu32', + 'tu64', + 'chain_hash', + 'channel_id', + 'sha256', + 'point', + 'short_channel_id', + 'signature', + 'bigsize', ] diff --git a/contrib/pyln-proto/pyln/proto/message/fundamental_types.py b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py index 344a48ad8..972f18664 100644 --- a/contrib/pyln-proto/pyln/proto/message/fundamental_types.py +++ b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py @@ -1,5 +1,6 @@ import struct import io +import sys from typing import Optional @@ -235,3 +236,9 @@ def fundamental_types(): # FIXME: See https://github.com/lightningnetwork/lightning-rfc/pull/778 BigSizeType('varint'), ] + + +# Expose these as native types. +mod = sys.modules[FieldType.__module__] +for m in fundamental_types(): + setattr(mod, m.name, m) diff --git a/contrib/pyln-proto/tests/test_array_types.py b/contrib/pyln-proto/tests/test_array_types.py index caf1a4bf3..2c1c9f227 100644 --- a/contrib/pyln-proto/tests/test_array_types.py +++ b/contrib/pyln-proto/tests/test_array_types.py @@ -1,18 +1,10 @@ #! /usr/bin/python3 -from pyln.proto.message.fundamental_types import fundamental_types +from pyln.proto.message.fundamental_types import byte, u16, short_channel_id from pyln.proto.message.array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType, LengthFieldType import io def test_sized_array(): - # Steal two fundamental types for testing - for t in fundamental_types(): - if t.name == 'byte': - byte = t - if t.name == 'u16': - u16 = t - if t.name == 'short_channel_id': - scid = t # Simple class to make outer work. class dummy: @@ -25,7 +17,7 @@ def test_sized_array(): [SizedArrayType(dummy("test2"), "test_arr", u16, 4), "[0,1,2,256]", bytes([0, 0, 0, 1, 0, 2, 1, 0])], - [SizedArrayType(dummy("test3"), "test_arr", scid, 4), + [SizedArrayType(dummy("test3"), "test_arr", short_channel_id, 4), "[1x2x3,4x5x6,7x8x9,10x11x12]", bytes([0, 0, 1, 0, 0, 2, 0, 3] + [0, 0, 4, 0, 0, 5, 0, 6] @@ -41,15 +33,6 @@ def test_sized_array(): def test_ellipsis_array(): - # Steal two fundamental types for testing - for t in fundamental_types(): - if t.name == 'byte': - byte = t - if t.name == 'u16': - u16 = t - if t.name == 'short_channel_id': - scid = t - # Simple class to make outer work. class dummy: def __init__(self, name): @@ -61,7 +44,7 @@ def test_ellipsis_array(): [EllipsisArrayType(dummy("test2"), "test_arr", u16), "[0,1,2,256]", bytes([0, 0, 0, 1, 0, 2, 1, 0])], - [EllipsisArrayType(dummy("test3"), "test_arr", scid), + [EllipsisArrayType(dummy("test3"), "test_arr", short_channel_id), "[1x2x3,4x5x6,7x8x9,10x11x12]", bytes([0, 0, 1, 0, 0, 2, 0, 3] + [0, 0, 4, 0, 0, 5, 0, 6] @@ -77,15 +60,6 @@ def test_ellipsis_array(): def test_dynamic_array(): - # Steal two fundamental types for testing - for t in fundamental_types(): - if t.name == 'byte': - byte = t - if t.name == 'u16': - u16 = t - if t.name == 'short_channel_id': - scid = t - # Simple class to make outer. class dummy: def __init__(self, name): @@ -106,7 +80,7 @@ def test_dynamic_array(): lenfield), "[0,1,2,256]", bytes([0, 0, 0, 1, 0, 2, 1, 0])], - [DynamicArrayType(dummy("test3"), "test_arr", scid, + [DynamicArrayType(dummy("test3"), "test_arr", short_channel_id, lenfield), "[1x2x3,4x5x6,7x8x9,10x11x12]", bytes([0, 0, 1, 0, 0, 2, 0, 3]