Browse Source

pyln: Add type-annotations to plugin.py

This should help users that have type-checking enabled.
travis-experimental
Christian Decker 5 years ago
committed by Rusty Russell
parent
commit
49ec800a07
  1. 262
      contrib/pyln-client/pyln/client/plugin.py
  2. 79
      contrib/pyln-proto/pyln/proto/primitives.py
  3. 71
      contrib/pyln-proto/pyln/proto/wire.py

262
contrib/pyln-client/pyln/client/plugin.py

@ -1,10 +1,12 @@
from .lightning import LightningRpc, Millisatoshi
from binascii import hexlify from binascii import hexlify
from collections import OrderedDict from collections import OrderedDict
from enum import Enum from enum import Enum
from .lightning import LightningRpc, Millisatoshi
from threading import RLock from threading import RLock
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import inspect import inspect
import io
import json import json
import math import math
import os import os
@ -12,6 +14,16 @@ import re
import sys import sys
import traceback import traceback
# Notice that this definition is incomplete as it only checks the
# top-level. Arrays and Dicts could contain types that aren't encodeable. This
# limitation stems from the fact that recursive types are not really supported
# yet.
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
# Yes, decorators are weird...
NoneDecoratorType = Callable[..., Callable[..., None]]
JsonDecoratorType = Callable[..., Callable[..., JSONType]]
class MethodType(Enum): class MethodType(Enum):
RPCMETHOD = 0 RPCMETHOD = 0
@ -32,8 +44,10 @@ class Method(object):
- RPC exposed by RPC passthrough - RPC exposed by RPC passthrough
- HOOK registered to be called synchronously by lightningd - HOOK registered to be called synchronously by lightningd
""" """
def __init__(self, name, func, mtype=MethodType.RPCMETHOD, category=None, def __init__(self, name: str, func: Callable[..., JSONType],
desc=None, long_desc=None, deprecated=False): mtype: MethodType = MethodType.RPCMETHOD,
category: str = None, desc: str = None,
long_desc: str = None, deprecated: bool = False):
self.name = name self.name = name
self.func = func self.func = func
self.mtype = mtype self.mtype = mtype
@ -47,7 +61,8 @@ class Method(object):
class Request(dict): class Request(dict):
"""A request object that wraps params and allows async return """A request object that wraps params and allows async return
""" """
def __init__(self, plugin, req_id, method, params, background=False): def __init__(self, plugin: 'Plugin', req_id: Optional[int], method: str,
params: Any, background: bool = False):
self.method = method self.method = method
self.params = params self.params = params
self.background = background self.background = background
@ -55,15 +70,19 @@ class Request(dict):
self.state = RequestState.PENDING self.state = RequestState.PENDING
self.id = req_id self.id = req_id
def getattr(self, key): def getattr(self, key: str) -> Union[Method, Any, int]:
if key == "params": if key == "params":
return self.params return self.params
elif key == "id": elif key == "id":
return self.id return self.id
elif key == "method": elif key == "method":
return self.method return self.method
else:
raise ValueError(
'Cannot get attribute "{key}" on Request'.format(key=key)
)
def set_result(self, result): def set_result(self, result: Any) -> None:
if self.state != RequestState.PENDING: if self.state != RequestState.PENDING:
raise ValueError( raise ValueError(
"Cannot set the result of a request that is not pending, " "Cannot set the result of a request that is not pending, "
@ -75,7 +94,7 @@ class Request(dict):
'result': self.result 'result': self.result
}) })
def set_exception(self, exc): def set_exception(self, exc: Exception) -> None:
if self.state != RequestState.PENDING: if self.state != RequestState.PENDING:
raise ValueError( raise ValueError(
"Cannot set the exception of a request that is not pending, " "Cannot set the exception of a request that is not pending, "
@ -93,7 +112,7 @@ class Request(dict):
}, },
}) })
def _write_result(self, result): def _write_result(self, result: dict) -> None:
self.plugin._write_locked(result) self.plugin._write_locked(result)
@ -126,12 +145,20 @@ class Plugin(object):
""" """
def __init__(self, stdout=None, stdin=None, autopatch=True, dynamic=True, def __init__(self, stdout: Optional[io.TextIOBase] = None,
init_features=None, node_features=None, invoice_features=None): stdin: Optional[io.TextIOBase] = None, autopatch: bool = True,
self.methods = {'init': Method('init', self._init, MethodType.RPCMETHOD)} dynamic: bool = True,
self.options = {} init_features: Optional[Union[int, str, bytes]] = None,
node_features: Optional[Union[int, str, bytes]] = None,
invoice_features: Optional[Union[int, str, bytes]] = None):
self.methods = {
'init': Method('init', self._init, MethodType.RPCMETHOD)
}
self.options: Dict[str, Dict[str, Any]] = {}
def convert_featurebits(bits): def convert_featurebits(
bits: Optional[Union[int, str, bytes]]) -> Optional[str]:
"""Convert the featurebits into the bytes required to hexencode. """Convert the featurebits into the bytes required to hexencode.
""" """
if bits is None: if bits is None:
@ -149,7 +176,9 @@ class Plugin(object):
return hexlify(bits).decode('ASCII') return hexlify(bits).decode('ASCII')
else: else:
raise ValueError("Could not convert featurebits to hex-encoded string") raise ValueError(
"Could not convert featurebits to hex-encoded string"
)
self.featurebits = { self.featurebits = {
'init': convert_featurebits(init_features), 'init': convert_featurebits(init_features),
@ -158,7 +187,7 @@ class Plugin(object):
} }
# A dict from topics to handler functions # A dict from topics to handler functions
self.subscriptions = {} self.subscriptions: Dict[str, Callable[..., None]] = {}
if not stdout: if not stdout:
self.stdout = sys.stdout self.stdout = sys.stdout
@ -172,17 +201,21 @@ class Plugin(object):
monkey_patch(self, stdout=True, stderr=True) monkey_patch(self, stdout=True, stderr=True)
self.add_method("getmanifest", self._getmanifest, background=False) self.add_method("getmanifest", self._getmanifest, background=False)
self.rpc_filename = None self.rpc_filename: Optional[str] = None
self.lightning_dir = None self.lightning_dir: Optional[str] = None
self.rpc = None self.rpc: Optional[LightningRpc] = None
self.startup = True self.startup = True
self.dynamic = dynamic self.dynamic = dynamic
self.child_init = None self.child_init: Optional[Callable[..., None]] = None
self.write_lock = RLock() self.write_lock = RLock()
def add_method(self, name, func, background=False, category=None, desc=None, def add_method(self, name: str, func: Callable[..., Any],
long_desc=None, deprecated=False): background: bool = False,
category: Optional[str] = None,
desc: Optional[str] = None,
long_desc: Optional[str] = None,
deprecated: bool = False) -> None:
"""Add a plugin method to the dispatch table. """Add a plugin method to the dispatch table.
The function will be expected at call time (see `_dispatch`) The function will be expected at call time (see `_dispatch`)
@ -221,11 +254,15 @@ class Plugin(object):
) )
# Register the function with the name # Register the function with the name
method = Method(name, func, MethodType.RPCMETHOD, category, desc, long_desc, deprecated) method = Method(
name, func, MethodType.RPCMETHOD, category, desc, long_desc,
deprecated
)
method.background = background method.background = background
self.methods[name] = method self.methods[name] = method
def add_subscription(self, topic, func): def add_subscription(self, topic: str, func: Callable[..., None]) -> None:
"""Add a subscription to our list of subscriptions. """Add a subscription to our list of subscriptions.
A subscription is an association between a topic and a handler A subscription is an association between a topic and a handler
@ -243,9 +280,9 @@ class Plugin(object):
"Topic {} already has a handler".format(topic) "Topic {} already has a handler".format(topic)
) )
# Make sure the notification callback has a **kwargs argument so that it # Make sure the notification callback has a **kwargs argument so that
# doesn't break if we add more arguments to the call later on. Issue a # it doesn't break if we add more arguments to the call later
# warning if it does not. # on. Issue a warning if it does not.
s = inspect.signature(func) s = inspect.signature(func)
kinds = [p.kind for p in s.parameters.values()] kinds = [p.kind for p in s.parameters.values()]
if inspect.Parameter.VAR_KEYWORD not in kinds: if inspect.Parameter.VAR_KEYWORD not in kinds:
@ -257,16 +294,20 @@ class Plugin(object):
self.subscriptions[topic] = func self.subscriptions[topic] = func
def subscribe(self, topic): def subscribe(self, topic: str) -> NoneDecoratorType:
"""Function decorator to register a notification handler. """Function decorator to register a notification handler.
""" """
def decorator(f): # Yes, decorator type annotations are just weird, don't think too much
# about it...
def decorator(f: Callable[..., None]) -> Callable[..., None]:
self.add_subscription(topic, f) self.add_subscription(topic, f)
return f return f
return decorator return decorator
def add_option(self, name, default, description, opt_type="string", def add_option(self, name: str, default: Optional[str],
deprecated=False): description: Optional[str],
opt_type: str = "string", deprecated: bool = False) -> None:
"""Add an option that we'd like to register with lightningd. """Add an option that we'd like to register with lightningd.
Needs to be called before `Plugin.run`, otherwise we might not Needs to be called before `Plugin.run`, otherwise we might not
@ -279,7 +320,9 @@ class Plugin(object):
) )
if opt_type not in ["string", "int", "bool", "flag"]: if opt_type not in ["string", "int", "bool", "flag"]:
raise ValueError('{} not in supported type set (string, int, bool, flag)') raise ValueError(
'{} not in supported type set (string, int, bool, flag)'
)
self.options[name] = { self.options[name] = {
'name': name, 'name': name,
@ -290,7 +333,8 @@ class Plugin(object):
'deprecated': deprecated, 'deprecated': deprecated,
} }
def add_flag_option(self, name, description, deprecated=False): def add_flag_option(self, name: str, description: str,
deprecated: bool = False) -> None:
"""Add a flag option that we'd like to register with lightningd. """Add a flag option that we'd like to register with lightningd.
Needs to be called before `Plugin.run`, otherwise we might not Needs to be called before `Plugin.run`, otherwise we might not
@ -300,7 +344,7 @@ class Plugin(object):
self.add_option(name, None, description, opt_type="flag", self.add_option(name, None, description, opt_type="flag",
deprecated=deprecated) deprecated=deprecated)
def get_option(self, name): def get_option(self, name: str) -> str:
if name not in self.options: if name not in self.options:
raise ValueError("No option with name {} registered".format(name)) raise ValueError("No option with name {} registered".format(name))
@ -309,31 +353,42 @@ class Plugin(object):
else: else:
return self.options[name]['default'] return self.options[name]['default']
def async_method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False): def async_method(self, method_name: str, category: Optional[str] = None,
desc: Optional[str] = None,
long_desc: Optional[str] = None,
deprecated: bool = False) -> NoneDecoratorType:
"""Decorator to add an async plugin method to the dispatch table. """Decorator to add an async plugin method to the dispatch table.
Internally uses add_method. Internally uses add_method.
""" """
def decorator(f): def decorator(f: Callable[..., None]) -> Callable[..., None]:
self.add_method(method_name, f, background=True, category=category, self.add_method(method_name, f, background=True, category=category,
desc=desc, long_desc=long_desc, desc=desc, long_desc=long_desc,
deprecated=deprecated) deprecated=deprecated)
return f return f
return decorator return decorator
def method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False): def method(self, method_name: str, category: Optional[str] = None,
desc: Optional[str] = None,
long_desc: Optional[str] = None,
deprecated: bool = False) -> JsonDecoratorType:
"""Decorator to add a plugin method to the dispatch table. """Decorator to add a plugin method to the dispatch table.
Internally uses add_method. Internally uses add_method.
""" """
def decorator(f): def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
self.add_method(method_name, f, background=False, category=category, self.add_method(method_name,
desc=desc, long_desc=long_desc, f,
background=False,
category=category,
desc=desc,
long_desc=long_desc,
deprecated=deprecated) deprecated=deprecated)
return f return f
return decorator return decorator
def add_hook(self, name, func, background=False): def add_hook(self, name: str, func: Callable[..., JSONType],
background: bool = False) -> None:
"""Register a hook that is called synchronously by lightningd on events """Register a hook that is called synchronously by lightningd on events
""" """
if name in self.methods: if name in self.methods:
@ -357,40 +412,47 @@ class Plugin(object):
method.background = background method.background = background
self.methods[name] = method self.methods[name] = method
def hook(self, method_name): def hook(self, method_name: str) -> JsonDecoratorType:
"""Decorator to add a plugin hook to the dispatch table. """Decorator to add a plugin hook to the dispatch table.
Internally uses add_hook. Internally uses add_hook.
""" """
def decorator(f): def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
self.add_hook(method_name, f, background=False) self.add_hook(method_name, f, background=False)
return f return f
return decorator return decorator
def async_hook(self, method_name): def async_hook(self, method_name: str) -> NoneDecoratorType:
"""Decorator to add an async plugin hook to the dispatch table. """Decorator to add an async plugin hook to the dispatch table.
Internally uses add_hook. Internally uses add_hook.
""" """
def decorator(f): def decorator(f: Callable[..., None]) -> Callable[..., None]:
self.add_hook(method_name, f, background=True) self.add_hook(method_name, f, background=True)
return f return f
return decorator return decorator
def init(self, *args, **kwargs): def init(self) -> NoneDecoratorType:
"""Decorator to add a function called after plugin initialization """Decorator to add a function called after plugin initialization
""" """
def decorator(f): def decorator(f: Callable[..., None]) -> Callable[..., None]:
if self.child_init is not None: if self.child_init is not None:
raise ValueError('The @plugin.init decorator should only be used once') raise ValueError(
'The @plugin.init decorator should only be used once'
)
self.child_init = f self.child_init = f
return f return f
return decorator return decorator
@staticmethod @staticmethod
def _coerce_arguments(func, ba): def _coerce_arguments(
func: Callable[..., Any],
ba: inspect.BoundArguments) -> inspect.BoundArguments:
args = OrderedDict() args = OrderedDict()
annotations = func.__annotations__ if hasattr(func, "__annotations__") else {} annotations = {}
if hasattr(func, "__annotations__"):
annotations = func.__annotations__
for key, val in ba.arguments.items(): for key, val in ba.arguments.items():
annotation = annotations.get(key, None) annotation = annotations.get(key, None)
if annotation is not None and annotation == Millisatoshi: if annotation is not None and annotation == Millisatoshi:
@ -400,7 +462,8 @@ class Plugin(object):
ba.arguments = args ba.arguments = args
return ba return ba
def _bind_pos(self, func, params, request): def _bind_pos(self, func: Callable[..., Any], params: List[str],
request: Request) -> inspect.BoundArguments:
"""Positional binding of parameters """Positional binding of parameters
""" """
assert(isinstance(params, list)) assert(isinstance(params, list))
@ -409,7 +472,7 @@ class Plugin(object):
# Collect injections so we can sort them and insert them in the right # Collect injections so we can sort them and insert them in the right
# order later. If we don't apply inject them in increasing order we # order later. If we don't apply inject them in increasing order we
# might shift away an earlier injection. # might shift away an earlier injection.
injections = [] injections: List[Tuple[int, Any]] = []
if 'plugin' in sig.parameters: if 'plugin' in sig.parameters:
pos = list(sig.parameters.keys()).index('plugin') pos = list(sig.parameters.keys()).index('plugin')
injections.append((pos, self)) injections.append((pos, self))
@ -425,7 +488,8 @@ class Plugin(object):
ba.apply_defaults() ba.apply_defaults()
return ba return ba
def _bind_kwargs(self, func, params, request): def _bind_kwargs(self, func: Callable[..., Any], params: Dict[str, Any],
request: Request) -> inspect.BoundArguments:
"""Keyword based binding of parameters """Keyword based binding of parameters
""" """
assert(isinstance(params, dict)) assert(isinstance(params, dict))
@ -445,7 +509,8 @@ class Plugin(object):
self._coerce_arguments(func, ba) self._coerce_arguments(func, ba)
return ba return ba
def _exec_func(self, func, request): def _exec_func(self, func: Callable[..., Any],
request: Request) -> JSONType:
params = request.params params = request.params
if isinstance(params, list): if isinstance(params, list):
ba = self._bind_pos(func, params, request) ba = self._bind_pos(func, params, request)
@ -454,9 +519,11 @@ class Plugin(object):
ba = self._bind_kwargs(func, params, request) ba = self._bind_kwargs(func, params, request)
return func(*ba.args, **ba.kwargs) return func(*ba.args, **ba.kwargs)
else: else:
raise TypeError("Parameters to function call must be either a dict or a list.") raise TypeError(
"Parameters to function call must be either a dict or a list."
)
def _dispatch_request(self, request): def _dispatch_request(self, request: Request) -> None:
name = request.method name = request.method
if name not in self.methods: if name not in self.methods:
@ -487,7 +554,7 @@ class Plugin(object):
request.set_exception(e) request.set_exception(e)
self.log(traceback.format_exc()) self.log(traceback.format_exc())
def _dispatch_notification(self, request): def _dispatch_notification(self, request: Request) -> None:
if request.method not in self.subscriptions: if request.method not in self.subscriptions:
raise ValueError("No subscription for {name} found.".format( raise ValueError("No subscription for {name} found.".format(
name=request.method)) name=request.method))
@ -498,15 +565,19 @@ class Plugin(object):
except Exception: except Exception:
self.log(traceback.format_exc()) self.log(traceback.format_exc())
def _write_locked(self, obj): def _write_locked(self, obj: JSONType) -> None:
# ensure_ascii turns UTF-8 into \uXXXX so we need to suppress that, # ensure_ascii turns UTF-8 into \uXXXX so we need to suppress that,
# then utf8 ourselves. # then utf8 ourselves.
s = bytes(json.dumps(obj, cls=LightningRpc.LightningJSONEncoder, ensure_ascii=False) + "\n\n", encoding='utf-8') s = bytes(json.dumps(
obj,
cls=LightningRpc.LightningJSONEncoder,
ensure_ascii=False
) + "\n\n", encoding='utf-8')
with self.write_lock: with self.write_lock:
self.stdout.buffer.write(s) self.stdout.buffer.write(s)
self.stdout.flush() self.stdout.flush()
def notify(self, method, params): def notify(self, method: str, params: JSONType) -> None:
payload = { payload = {
'jsonrpc': '2.0', 'jsonrpc': '2.0',
'method': method, 'method': method,
@ -514,30 +585,35 @@ class Plugin(object):
} }
self._write_locked(payload) self._write_locked(payload)
def log(self, message, level='info'): def log(self, message: str, level: str = 'info') -> None:
# Split the log into multiple lines and print them # Split the log into multiple lines and print them
# individually. Makes tracebacks much easier to read. # individually. Makes tracebacks much easier to read.
for line in message.split('\n'): for line in message.split('\n'):
self.notify('log', {'level': level, 'message': line}) self.notify('log', {'level': level, 'message': line})
def _parse_request(self, jsrequest): def _parse_request(self, jsrequest: Dict[str, JSONType]) -> Request:
i = jsrequest.get('id', None)
if not isinstance(i, int) and i is not None:
raise ValueError('Non-integer request id "{i}"'.format(i=i))
request = Request( request = Request(
plugin=self, plugin=self,
req_id=jsrequest.get('id', None), req_id=i,
method=jsrequest['method'], method=str(jsrequest['method']),
params=jsrequest['params'], params=jsrequest['params'],
background=False, background=False,
) )
return request return request
def _multi_dispatch(self, msgs): def _multi_dispatch(self, msgs: List[bytes]) -> bytes:
"""We received a couple of messages, now try to dispatch them all. """We received a couple of messages, now try to dispatch them all.
Returns the last partial message that was not complete yet. Returns the last partial message that was not complete yet.
""" """
for payload in msgs[:-1]: for payload in msgs[:-1]:
# Note that we use function annotations to do Millisatoshi conversions # Note that we use function annotations to do Millisatoshi
# in _exec_func, so we don't use LightningJSONDecoder here. # conversions in _exec_func, so we don't use LightningJSONDecoder
# here.
request = self._parse_request(json.loads(payload.decode('utf8'))) request = self._parse_request(json.loads(payload.decode('utf8')))
# If this has an 'id'-field, it's a request and returns a # If this has an 'id'-field, it's a request and returns a
@ -550,7 +626,7 @@ class Plugin(object):
return msgs[-1] return msgs[-1]
def run(self): def run(self) -> None:
partial = b"" partial = b""
for l in self.stdin.buffer: for l in self.stdin.buffer:
partial += l partial += l
@ -561,7 +637,7 @@ class Plugin(object):
partial = self._multi_dispatch(msgs) partial = self._multi_dispatch(msgs)
def _getmanifest(self, **kwargs): def _getmanifest(self, **kwargs) -> JSONType:
if 'allow-deprecated-apis' in kwargs: if 'allow-deprecated-apis' in kwargs:
self.deprecated_apis = kwargs['allow-deprecated-apis'] self.deprecated_apis = kwargs['allow-deprecated-apis']
else: else:
@ -582,13 +658,21 @@ class Plugin(object):
doc = inspect.getdoc(method.func) doc = inspect.getdoc(method.func)
if not doc: if not doc:
self.log( self.log(
'RPC method \'{}\' does not have a docstring.'.format(method.name) 'RPC method \'{}\' does not have a docstring.'.format(
method.name
)
) )
doc = "Undocumented RPC method from a plugin." doc = "Undocumented RPC method from a plugin."
doc = re.sub('\n+', ' ', doc) doc = re.sub('\n+', ' ', doc)
# Handles out-of-order use of parameters like: # Handles out-of-order use of parameters like:
# def hello_obfus(arg1, arg2, plugin, thing3, request=None, thing5='at', thing6=21) #
# ```python3
#
# def hello_obfus(arg1, arg2, plugin, thing3, request=None,
# thing5='at', thing6=21)
#
# ```
argspec = inspect.getfullargspec(method.func) argspec = inspect.getfullargspec(method.func)
defaults = argspec.defaults defaults = argspec.defaults
num_defaults = len(defaults) if defaults else 0 num_defaults = len(defaults) if defaults else 0
@ -611,7 +695,8 @@ class Plugin(object):
'description': doc if not method.desc else method.desc 'description': doc if not method.desc else method.desc
}) })
if method.long_desc: if method.long_desc:
methods[len(methods) - 1]["long_description"] = method.long_desc m = methods[len(methods) - 1]
m["long_description"] = method.long_desc
manifest = { manifest = {
'options': list(self.options.values()), 'options': list(self.options.values()),
@ -628,12 +713,30 @@ class Plugin(object):
return manifest return manifest
def _init(self, options, configuration, request): def _init(self, options: Dict[str, JSONType],
self.rpc_filename = configuration['rpc-file'] configuration: Dict[str, JSONType],
self.lightning_dir = configuration['lightning-dir'] request: Request) -> JSONType:
def verify_str(d: Dict[str, JSONType], key: str) -> str:
v = d.get(key)
if not isinstance(v, str):
raise ValueError("Wrong argument to init: expected {key} to be"
" a string, got {v}".format(key=key, v=v))
return v
def verify_bool(d: Dict[str, JSONType], key: str) -> bool:
v = d.get(key)
if not isinstance(v, bool):
raise ValueError("Wrong argument to init: expected {key} to be"
" a bool, got {v}".format(key=key, v=v))
return v
self.rpc_filename = verify_str(configuration, 'rpc-file')
self.lightning_dir = verify_str(configuration, 'lightning-dir')
path = os.path.join(self.lightning_dir, self.rpc_filename) path = os.path.join(self.lightning_dir, self.rpc_filename)
self.rpc = LightningRpc(path) self.rpc = LightningRpc(path)
self.startup = configuration['startup'] self.startup = verify_bool(configuration, 'startup')
for name, value in options.items(): for name, value in options.items():
self.options[name]['value'] = value self.options[name]['value'] = value
@ -647,18 +750,18 @@ class PluginStream(object):
"""Sink that turns everything that is written to it into a notification. """Sink that turns everything that is written to it into a notification.
""" """
def __init__(self, plugin, level="info"): def __init__(self, plugin: Plugin, level: str = "info"):
self.plugin = plugin self.plugin = plugin
self.level = level self.level = level
self.buff = '' self.buff = ''
def write(self, payload): def write(self, payload: str) -> None:
self.buff += payload self.buff += payload
if len(payload) > 0 and payload[-1] == '\n': if len(payload) > 0 and payload[-1] == '\n':
self.flush() self.flush()
def flush(self): def flush(self) -> None:
lines = self.buff.split('\n') lines = self.buff.split('\n')
if len(lines) < 2: if len(lines) < 2:
return return
@ -670,7 +773,8 @@ class PluginStream(object):
self.buff = lines[-1] self.buff = lines[-1]
def monkey_patch(plugin, stdout=True, stderr=False): def monkey_patch(plugin: Plugin, stdout: bool = True,
stderr: bool = False) -> None:
"""Monkey patch stderr and stdout so we use notifications instead. """Monkey patch stderr and stdout so we use notifications instead.
A plugin commonly communicates with lightningd over its stdout and A plugin commonly communicates with lightningd over its stdout and

79
contrib/pyln-proto/pyln/proto/primitives.py

@ -1,3 +1,4 @@
import coincurve
import struct import struct
@ -66,5 +67,79 @@ class ShortChannelId(object):
def __str__(self): def __str__(self):
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self) return "{self.block}x{self.txnum}x{self.outnum}".format(self=self)
def __eq__(self, other): def __eq__(self, other: object) -> bool:
return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum if not isinstance(other, ShortChannelId):
return False
return (
self.block == other.block
and self.txnum == other.txnum
and self.outnum == other.outnum
)
class Secret(object):
def __init__(self, data: bytes) -> None:
assert(len(data) == 32)
self.data = data
def to_bytes(self) -> bytes:
return self.data
def __eq__(self, other: object) -> bool:
return isinstance(other, Secret) and self.data == other.data
def __str__(self):
return "Secret[0x{}]".format(self.data.hex())
class PrivateKey(object):
def __init__(self, rawkey) -> None:
if not isinstance(rawkey, bytes):
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
elif len(rawkey) != 32:
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
self.rawkey = rawkey
self.key = coincurve.PrivateKey(rawkey)
def serializeCompressed(self):
return self.key.secret
def public_key(self):
return PublicKey(self.key.public_key)
class PublicKey(object):
def __init__(self, innerkey):
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
# by coincurve
if isinstance(innerkey, bytes):
if innerkey[0] in [2, 3] and len(innerkey) == 33:
innerkey = coincurve.PublicKey(innerkey)
else:
raise ValueError(
"Byte keys must be 33-byte long starting from either 02 or 03"
)
elif not isinstance(innerkey, coincurve.keys.PublicKey):
raise ValueError(
"Key must either be bytes or coincurve.keys.PublicKey"
)
self.key = innerkey
def serializeCompressed(self):
return self.key.format(compressed=True)
def to_bytes(self) -> bytes:
return self.serializeCompressed()
def __str__(self):
return "PublicKey[0x{}]".format(
self.serializeCompressed().hex()
)
def Keypair(object):
def __init__(self, priv, pub):
self.priv, self.pub = priv, pub

71
contrib/pyln-proto/pyln/proto/wire.py

@ -4,6 +4,7 @@ from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from .primitives import Secret, PrivateKey, PublicKey
from hashlib import sha256 from hashlib import sha256
import coincurve import coincurve
import os import os
@ -55,64 +56,6 @@ def decryptWithAD(k, n, ad, ciphertext):
return chacha.decrypt(n, ciphertext, ad) return chacha.decrypt(n, ciphertext, ad)
class PrivateKey(object):
def __init__(self, rawkey):
if not isinstance(rawkey, bytes):
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
elif len(rawkey) != 32:
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
self.rawkey = rawkey
self.key = coincurve.PrivateKey(rawkey)
def serializeCompressed(self):
return self.key.secret
def public_key(self):
return PublicKey(self.key.public_key)
class Secret(object):
def __init__(self, raw):
assert(len(raw) == 32)
self.raw = raw
def __str__(self):
return "Secret[0x{}]".format(self.raw.hex())
class PublicKey(object):
def __init__(self, innerkey):
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
# by coincurve
if isinstance(innerkey, bytes):
if innerkey[0] in [2, 3] and len(innerkey) == 33:
innerkey = coincurve.PublicKey(innerkey)
else:
raise ValueError(
"Byte keys must be 33-byte long starting from either 02 or 03"
)
elif not isinstance(innerkey, coincurve.keys.PublicKey):
raise ValueError(
"Key must either be bytes or coincurve.keys.PublicKey"
)
self.key = innerkey
def serializeCompressed(self):
return self.key.format(compressed=True)
def __str__(self):
return "PublicKey[0x{}]".format(
self.serializeCompressed().hex()
)
def Keypair(object):
def __init__(self, priv, pub):
self.priv, self.pub = priv, pub
class Sha256Mixer(object): class Sha256Mixer(object):
def __init__(self, base): def __init__(self, base):
self.hash = sha256(base).digest() self.hash = sha256(base).digest()
@ -174,7 +117,7 @@ class LightningConnection(object):
h.hash = self.handshake['h'] h.hash = self.handshake['h']
h.update(self.handshake['e'].public_key().serializeCompressed()) h.update(self.handshake['e'].public_key().serializeCompressed())
es = ecdh(self.handshake['e'], self.remote_pubkey) es = ecdh(self.handshake['e'], self.remote_pubkey)
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'') t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
assert(len(t) == 64) assert(len(t) == 64)
self.chaining_key, temp_k1 = t[:32], t[32:] self.chaining_key, temp_k1 = t[:32], t[32:]
c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'') c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'')
@ -194,7 +137,7 @@ class LightningConnection(object):
h.update(re.serializeCompressed()) h.update(re.serializeCompressed())
es = ecdh(self.local_privkey, re) es = ecdh(self.local_privkey, re)
self.handshake['re'] = re self.handshake['re'] = re
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'') t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
self.chaining_key, temp_k1 = t[:32], t[32:] self.chaining_key, temp_k1 = t[:32], t[32:]
try: try:
@ -210,7 +153,7 @@ class LightningConnection(object):
h.hash = self.handshake['h'] h.hash = self.handshake['h']
h.update(self.handshake['e'].public_key().serializeCompressed()) h.update(self.handshake['e'].public_key().serializeCompressed())
ee = ecdh(self.handshake['e'], self.handshake['re']) ee = ecdh(self.handshake['e'], self.handshake['re'])
t = hkdf(salt=self.chaining_key, ikm=ee.raw, info=b'') t = hkdf(salt=self.chaining_key, ikm=ee.data, info=b'')
assert(len(t) == 64) assert(len(t) == 64)
self.chaining_key, self.temp_k2 = t[:32], t[32:] self.chaining_key, self.temp_k2 = t[:32], t[32:]
c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'') c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'')
@ -231,7 +174,7 @@ class LightningConnection(object):
h.update(re.serializeCompressed()) h.update(re.serializeCompressed())
ee = ecdh(self.handshake['e'], re) ee = ecdh(self.handshake['e'], re)
self.chaining_key, self.temp_k2 = hkdf_two_keys( self.chaining_key, self.temp_k2 = hkdf_two_keys(
salt=self.chaining_key, ikm=ee.raw salt=self.chaining_key, ikm=ee.data
) )
try: try:
decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c) decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c)
@ -249,7 +192,7 @@ class LightningConnection(object):
se = ecdh(self.local_privkey, self.re) se = ecdh(self.local_privkey, self.re)
self.chaining_key, self.temp_k3 = hkdf_two_keys( self.chaining_key, self.temp_k3 = hkdf_two_keys(
salt=self.chaining_key, ikm=se.raw salt=self.chaining_key, ikm=se.data
) )
t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'') t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'')
m = b'\x00' + c + t m = b'\x00' + c + t
@ -272,7 +215,7 @@ class LightningConnection(object):
se = ecdh(self.handshake['e'], self.remote_pubkey) se = ecdh(self.handshake['e'], self.remote_pubkey)
self.chaining_key, self.temp_k3 = hkdf_two_keys( self.chaining_key, self.temp_k3 = hkdf_two_keys(
se.raw, self.chaining_key se.data, self.chaining_key
) )
decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t) decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t)
self.rn, self.sn = 0, 0 self.rn, self.sn = 0, 0

Loading…
Cancel
Save