diff --git a/contrib/pyln-client/pyln/client/lightning.py b/contrib/pyln-client/pyln/client/lightning.py index 1ac25530a..096809e38 100644 --- a/contrib/pyln-client/pyln/client/lightning.py +++ b/contrib/pyln-client/pyln/client/lightning.py @@ -5,6 +5,21 @@ import logging import os import socket import warnings +from json import JSONEncoder + + +def _patched_default(self, obj): + return getattr(obj.__class__, "to_json", _patched_default.default)(obj) + + +def monkey_patch_json(patch=True): + is_patched = JSONEncoder.default == _patched_default + + if patch and not is_patched: + _patched_default.default = JSONEncoder.default # Save unmodified + JSONEncoder.default = _patched_default # Replace it. + elif not patch and is_patched: + JSONEncoder.default = _patched_default.default class RpcError(ValueError): @@ -327,7 +342,10 @@ class LightningRpc(UnixDomainSocketRpc): return json.JSONEncoder.default(self, o) class LightningJSONDecoder(json.JSONDecoder): - def __init__(self, *, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, strict=True, object_pairs_hook=None): + def __init__(self, *, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, + strict=True, object_pairs_hook=None, + patch_json=True): self.object_hook_next = object_hook super().__init__(object_hook=self.millisatoshi_hook, parse_float=parse_float, parse_int=parse_int, parse_constant=parse_constant, strict=strict, object_pairs_hook=object_pairs_hook) @@ -357,8 +375,18 @@ class LightningRpc(UnixDomainSocketRpc): obj = self.object_hook_next(obj) return obj - def __init__(self, socket_path, executor=None, logger=logging): - super().__init__(socket_path, executor, logger, self.LightningJSONEncoder, self.LightningJSONDecoder()) + def __init__(self, socket_path, executor=None, logger=logging, + patch_json=True): + super().__init__( + socket_path, + executor, + logger, + self.LightningJSONEncoder, + self.LightningJSONDecoder() + ) + + if patch_json: + monkey_patch_json(patch=True) def autocleaninvoice(self, cycle_seconds=None, expired_by=None): """