diff --git a/contrib/pylightning/lightning/plugin.py b/contrib/pylightning/lightning/plugin.py index d7b52dcd8..3da175d70 100644 --- a/contrib/pylightning/lightning/plugin.py +++ b/contrib/pylightning/lightning/plugin.py @@ -1,6 +1,6 @@ from collections import OrderedDict from enum import Enum -from lightning import LightningRpc +from lightning import LightningRpc, Millisatoshi from threading import RLock import inspect @@ -293,7 +293,11 @@ class Plugin(object): if isinstance(params, dict): for k, v in params.items(): if k in arguments: - arguments[k] = v + # Explicitly (try to) interpret as Millisatoshi if annotated + if func.__annotations__.get(k) == Millisatoshi: + arguments[k] = Millisatoshi(v) + else: + arguments[k] = v else: kwargs[k] = v else: @@ -305,7 +309,10 @@ class Plugin(object): if pos < len(params): # Apply positional args if we have them - arguments[k] = params[pos] + if func.__annotations__.get(k) == Millisatoshi: + arguments[k] = Millisatoshi(params[pos]) + else: + arguments[k] = params[pos] elif sig.parameters[k].default is inspect.Signature.empty: # This is a positional arg with no value passed raise TypeError("Missing required parameter: %s" % sig.parameters[k]) @@ -406,6 +413,8 @@ class Plugin(object): Returns the last partial message that was not complete yet. """ for payload in msgs[:-1]: + # Note that we use function annotations to do Millisatoshi conversions + # in _exec_func, so we don't use LightningJSONDecoder here. request = self._parse_request(json.loads(payload)) # If this has an 'id'-field, it's a request and returns a diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 5df0c47bb..5a5331dbf 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -34,7 +34,6 @@ def test_option_passthrough(node_factory): n.stop() -@pytest.mark.xfail(strict=True) def test_millisatoshi_passthrough(node_factory): """ Ensure that Millisatoshi arguments and return work. """