diff --git a/contrib/pylightning/lightning/plugin.py b/contrib/pylightning/lightning/plugin.py index d0d7cec41..f953af518 100644 --- a/contrib/pylightning/lightning/plugin.py +++ b/contrib/pylightning/lightning/plugin.py @@ -1,6 +1,7 @@ from collections import OrderedDict -from lightning import LightningRpc from enum import Enum +from lightning import LightningRpc +from threading import RLock import inspect import json @@ -15,6 +16,12 @@ class MethodType(Enum): HOOK = 1 +class RequestState(Enum): + PENDING = 'pending' + FINISHED = 'finished' + FAILED = 'failed' + + class Method(object): """Description of methods that are registered with the plugin. @@ -27,6 +34,59 @@ class Method(object): self.name = name self.func = func self.mtype = mtype + self.background = False + + +class Request(dict): + """A request object that wraps params and allows async return + """ + def __init__(self, plugin, req_id, method, params, background=False): + self.method = method + self.params = params + self.background = background + self.plugin = plugin + self.state = RequestState.PENDING + self.id = req_id + + def getattr(self, key): + if key == "params": + return self.params + elif key == "id": + return self.id + elif key == "method": + return self.method + + def set_result(self, result): + if self.state != RequestState.PENDING: + raise ValueError( + "Cannot set the result of a request that is not pending, " + "current state is {state}".format(self.state)) + self.result = result + self._write_result({ + 'jsonrpc': '2.0', + 'id': self.id, + 'result': self.result + }) + + def set_exception(self, exc): + if self.state != RequestState.PENDING: + raise ValueError( + "Cannot set the exception of a request that is not pending, " + "current state is {state}".format(self.state)) + self.exc = exc + self._write_result({ + 'jsonrpc': '2.0', + 'id': self.id, + "error": "Error while processing {method}: {exc}".format( + method=self.method, exc=repr(exc) + ), + }) + + def _write_result(self, result): + with self.plugin.write_lock: + json.dump(result, fp=self.plugin.stdout) + self.plugin.stdout.write('\n\n') + self.plugin.stdout.flush() class Plugin(object): @@ -59,6 +119,8 @@ class Plugin(object): self.rpc = None self.child_init = None + self.write_lock = RLock() + def add_method(self, name, func): """Add a plugin method to the dispatch table. @@ -185,7 +247,7 @@ class Plugin(object): return decorator def _exec_func(self, func, request): - params = request['params'] + params = request.params sig = inspect.signature(func) arguments = OrderedDict() @@ -223,36 +285,30 @@ class Plugin(object): return func(*ba.args, **ba.kwargs) def _dispatch_request(self, request): - name = request['method'] + name = request.method if name not in self.methods: raise ValueError("No method {} found.".format(name)) method = self.methods[name] + request.background = method.background try: - result = { - 'jsonrpc': '2.0', - 'id': request['id'], - 'result': self._exec_func(method.func, request) - } + result = self._exec_func(method.func, request) + if not method.background: + # Only if this is not an async (background) call do we need to + # return the result, otherwise the callee will eventually need + # to call request.set_result or request.set_exception to + # return a result or raise an exception. + request.set_result(result) except Exception as e: - result = { - 'jsonrpc': '2.0', - 'id': request['id'], - "error": "Error while processing {}: {}".format( - request['method'], repr(e) - ), - } + request.set_exception(e) self.log(traceback.format_exc()) - json.dump(result, fp=self.stdout) - self.stdout.write('\n\n') - self.stdout.flush() def _dispatch_notification(self, request): - name = request['method'] - if name not in self.subscriptions: - raise ValueError("No subscription for {} found.".format(name)) - func = self.subscriptions[name] + if request.method not in self.subscriptions: + raise ValueError("No subscription for {name} found.".format( + name=request.method)) + func = self.subscriptions[request.method] try: self._exec_func(func, request) @@ -265,9 +321,10 @@ class Plugin(object): 'method': method, 'params': params, } - json.dump(payload, self.stdout) - self.stdout.write("\n\n") - self.stdout.flush() + with self.write_lock: + json.dump(payload, self.stdout) + self.stdout.write("\n\n") + self.stdout.flush() def log(self, message, level='info'): # Split the log into multiple lines and print them @@ -282,11 +339,18 @@ class Plugin(object): """ for payload in msgs[:-1]: request = json.loads(payload) + request = Request( + plugin=self, + req_id=request.get('id', None), + method=request['method'], + params=request['params'], + background=False, + ) # If this has an 'id'-field, it's a request and returns a # result. Otherwise it's a notification and it doesn't # return anything. - if 'id' in request: + if request.id is not None: self._dispatch_request(request) else: self._dispatch_notification(request) diff --git a/contrib/pylightning/lightning/test_plugin.py b/contrib/pylightning/lightning/test_plugin.py index 07d2c8ca2..0c90dea9f 100644 --- a/contrib/pylightning/lightning/test_plugin.py +++ b/contrib/pylightning/lightning/test_plugin.py @@ -1,22 +1,21 @@ -from .plugin import Plugin +from .plugin import Plugin, Request import itertools def test_positional_inject(): p = Plugin() - rdict = { - 'id': 1, - 'jsonrpc': - '2.0', - 'method': 'func', - 'params': {'a': 1, 'b': 2, 'kwa': 3, 'kwb': 4} - } - rarr = { - 'id': 1, - 'jsonrpc': '2.0', - 'method': 'func', - 'params': [1, 2, 3, 4] - } + rdict = Request( + plugin=p, + req_id=1, + method='func', + params={'a': 1, 'b': 2, 'kwa': 3, 'kwb': 4} + ) + rarr = Request( + plugin=p, + req_id=1, + method='func', + params=[1, 2, 3, 4], + ) def pre_args(plugin, a, b, kwa=3, kwb=4): assert (plugin, a, b, kwa, kwb) == (p, 1, 2, 3, 4)