Browse Source

pylightning: Wrap request in an object

We well need this in the next commit to be able to return from an
asynchronous call. We also guard stdout access with a reentrant lock
since we are no longer guaranteed that all communication happens on
the same thread.

Signed-off-by: Christian Decker <decker.christian@gmail.com>
pr-2355-addendum
Christian Decker 6 years ago
committed by Rusty Russell
parent
commit
81fa247d07
  1. 116
      contrib/pylightning/lightning/plugin.py
  2. 27
      contrib/pylightning/lightning/test_plugin.py

116
contrib/pylightning/lightning/plugin.py

@ -1,6 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from lightning import LightningRpc
from enum import Enum from enum import Enum
from lightning import LightningRpc
from threading import RLock
import inspect import inspect
import json import json
@ -15,6 +16,12 @@ class MethodType(Enum):
HOOK = 1 HOOK = 1
class RequestState(Enum):
PENDING = 'pending'
FINISHED = 'finished'
FAILED = 'failed'
class Method(object): class Method(object):
"""Description of methods that are registered with the plugin. """Description of methods that are registered with the plugin.
@ -27,6 +34,59 @@ class Method(object):
self.name = name self.name = name
self.func = func self.func = func
self.mtype = mtype self.mtype = mtype
self.background = False
class Request(dict):
"""A request object that wraps params and allows async return
"""
def __init__(self, plugin, req_id, method, params, background=False):
self.method = method
self.params = params
self.background = background
self.plugin = plugin
self.state = RequestState.PENDING
self.id = req_id
def getattr(self, key):
if key == "params":
return self.params
elif key == "id":
return self.id
elif key == "method":
return self.method
def set_result(self, result):
if self.state != RequestState.PENDING:
raise ValueError(
"Cannot set the result of a request that is not pending, "
"current state is {state}".format(self.state))
self.result = result
self._write_result({
'jsonrpc': '2.0',
'id': self.id,
'result': self.result
})
def set_exception(self, exc):
if self.state != RequestState.PENDING:
raise ValueError(
"Cannot set the exception of a request that is not pending, "
"current state is {state}".format(self.state))
self.exc = exc
self._write_result({
'jsonrpc': '2.0',
'id': self.id,
"error": "Error while processing {method}: {exc}".format(
method=self.method, exc=repr(exc)
),
})
def _write_result(self, result):
with self.plugin.write_lock:
json.dump(result, fp=self.plugin.stdout)
self.plugin.stdout.write('\n\n')
self.plugin.stdout.flush()
class Plugin(object): class Plugin(object):
@ -59,6 +119,8 @@ class Plugin(object):
self.rpc = None self.rpc = None
self.child_init = None self.child_init = None
self.write_lock = RLock()
def add_method(self, name, func): def add_method(self, name, func):
"""Add a plugin method to the dispatch table. """Add a plugin method to the dispatch table.
@ -185,7 +247,7 @@ class Plugin(object):
return decorator return decorator
def _exec_func(self, func, request): def _exec_func(self, func, request):
params = request['params'] params = request.params
sig = inspect.signature(func) sig = inspect.signature(func)
arguments = OrderedDict() arguments = OrderedDict()
@ -223,36 +285,30 @@ class Plugin(object):
return func(*ba.args, **ba.kwargs) return func(*ba.args, **ba.kwargs)
def _dispatch_request(self, request): def _dispatch_request(self, request):
name = request['method'] name = request.method
if name not in self.methods: if name not in self.methods:
raise ValueError("No method {} found.".format(name)) raise ValueError("No method {} found.".format(name))
method = self.methods[name] method = self.methods[name]
request.background = method.background
try: try:
result = { result = self._exec_func(method.func, request)
'jsonrpc': '2.0', if not method.background:
'id': request['id'], # Only if this is not an async (background) call do we need to
'result': self._exec_func(method.func, request) # return the result, otherwise the callee will eventually need
} # to call request.set_result or request.set_exception to
# return a result or raise an exception.
request.set_result(result)
except Exception as e: except Exception as e:
result = { request.set_exception(e)
'jsonrpc': '2.0',
'id': request['id'],
"error": "Error while processing {}: {}".format(
request['method'], repr(e)
),
}
self.log(traceback.format_exc()) self.log(traceback.format_exc())
json.dump(result, fp=self.stdout)
self.stdout.write('\n\n')
self.stdout.flush()
def _dispatch_notification(self, request): def _dispatch_notification(self, request):
name = request['method'] if request.method not in self.subscriptions:
if name not in self.subscriptions: raise ValueError("No subscription for {name} found.".format(
raise ValueError("No subscription for {} found.".format(name)) name=request.method))
func = self.subscriptions[name] func = self.subscriptions[request.method]
try: try:
self._exec_func(func, request) self._exec_func(func, request)
@ -265,9 +321,10 @@ class Plugin(object):
'method': method, 'method': method,
'params': params, 'params': params,
} }
json.dump(payload, self.stdout) with self.write_lock:
self.stdout.write("\n\n") json.dump(payload, self.stdout)
self.stdout.flush() self.stdout.write("\n\n")
self.stdout.flush()
def log(self, message, level='info'): def log(self, message, level='info'):
# Split the log into multiple lines and print them # Split the log into multiple lines and print them
@ -282,11 +339,18 @@ class Plugin(object):
""" """
for payload in msgs[:-1]: for payload in msgs[:-1]:
request = json.loads(payload) request = json.loads(payload)
request = Request(
plugin=self,
req_id=request.get('id', None),
method=request['method'],
params=request['params'],
background=False,
)
# If this has an 'id'-field, it's a request and returns a # If this has an 'id'-field, it's a request and returns a
# result. Otherwise it's a notification and it doesn't # result. Otherwise it's a notification and it doesn't
# return anything. # return anything.
if 'id' in request: if request.id is not None:
self._dispatch_request(request) self._dispatch_request(request)
else: else:
self._dispatch_notification(request) self._dispatch_notification(request)

27
contrib/pylightning/lightning/test_plugin.py

@ -1,22 +1,21 @@
from .plugin import Plugin from .plugin import Plugin, Request
import itertools import itertools
def test_positional_inject(): def test_positional_inject():
p = Plugin() p = Plugin()
rdict = { rdict = Request(
'id': 1, plugin=p,
'jsonrpc': req_id=1,
'2.0', method='func',
'method': 'func', params={'a': 1, 'b': 2, 'kwa': 3, 'kwb': 4}
'params': {'a': 1, 'b': 2, 'kwa': 3, 'kwb': 4} )
} rarr = Request(
rarr = { plugin=p,
'id': 1, req_id=1,
'jsonrpc': '2.0', method='func',
'method': 'func', params=[1, 2, 3, 4],
'params': [1, 2, 3, 4] )
}
def pre_args(plugin, a, b, kwa=3, kwb=4): def pre_args(plugin, a, b, kwa=3, kwb=4):
assert (plugin, a, b, kwa, kwb) == (p, 1, 2, 3, 4) assert (plugin, a, b, kwa, kwb) == (p, 1, 2, 3, 4)

Loading…
Cancel
Save