|
|
@ -29,16 +29,16 @@ import time |
|
|
|
import traceback |
|
|
|
import sys |
|
|
|
import threading |
|
|
|
from typing import Dict, Optional, Tuple, Iterable |
|
|
|
from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping |
|
|
|
from base64 import b64decode, b64encode |
|
|
|
from collections import defaultdict |
|
|
|
import concurrent |
|
|
|
from concurrent import futures |
|
|
|
import json |
|
|
|
|
|
|
|
import aiohttp |
|
|
|
from aiohttp import web, client_exceptions |
|
|
|
from aiorpcx import TaskGroup |
|
|
|
import json |
|
|
|
|
|
|
|
from . import util |
|
|
|
from .network import Network |
|
|
@ -151,6 +151,11 @@ class AuthenticatedServer(Logger): |
|
|
|
self.rpc_user = rpc_user |
|
|
|
self.rpc_password = rpc_password |
|
|
|
self.auth_lock = asyncio.Lock() |
|
|
|
self._methods = {} # type: Dict[str, Callable] |
|
|
|
|
|
|
|
def register_method(self, f): |
|
|
|
assert f.__name__ not in self._methods, f"name collision for {f.__name__}" |
|
|
|
self._methods[f.__name__] = f |
|
|
|
|
|
|
|
async def authenticate(self, headers): |
|
|
|
if self.rpc_password == '': |
|
|
@ -184,15 +189,21 @@ class AuthenticatedServer(Logger): |
|
|
|
request = json.loads(request) |
|
|
|
method = request['method'] |
|
|
|
_id = request['id'] |
|
|
|
params = request.get('params', []) |
|
|
|
f = getattr(self, method) |
|
|
|
assert f in self.methods |
|
|
|
except: |
|
|
|
params = request.get('params', []) # type: Union[Sequence, Mapping] |
|
|
|
if method not in self._methods: |
|
|
|
raise Exception(f"attempting to use unregistered method: {method}") |
|
|
|
f = self._methods[method] |
|
|
|
except Exception as e: |
|
|
|
self.logger.exception("invalid request") |
|
|
|
return web.Response(text='Invalid Request', status=500) |
|
|
|
response = {'id':_id} |
|
|
|
response = {'id': _id} |
|
|
|
try: |
|
|
|
response['result'] = await f(*params) |
|
|
|
if isinstance(params, dict): |
|
|
|
response['result'] = await f(**params) |
|
|
|
else: |
|
|
|
response['result'] = await f(*params) |
|
|
|
except BaseException as e: |
|
|
|
self.logger.exception("internal error while executing RPC") |
|
|
|
response['error'] = str(e) |
|
|
|
return web.json_response(response) |
|
|
|
|
|
|
@ -209,13 +220,12 @@ class CommandsServer(AuthenticatedServer): |
|
|
|
self.port = self.config.get('rpcport', 0) |
|
|
|
self.app = web.Application() |
|
|
|
self.app.router.add_post("/", self.handle) |
|
|
|
self.methods = set() |
|
|
|
self.methods.add(self.ping) |
|
|
|
self.methods.add(self.gui) |
|
|
|
self.register_method(self.ping) |
|
|
|
self.register_method(self.gui) |
|
|
|
self.cmd_runner = Commands(config=self.config, network=self.daemon.network, daemon=self.daemon) |
|
|
|
for cmdname in known_commands: |
|
|
|
self.methods.add(getattr(self.cmd_runner, cmdname)) |
|
|
|
self.methods.add(self.run_cmdline) |
|
|
|
self.register_method(getattr(self.cmd_runner, cmdname)) |
|
|
|
self.register_method(self.run_cmdline) |
|
|
|
|
|
|
|
async def run(self): |
|
|
|
self.runner = web.AppRunner(self.app) |
|
|
@ -277,9 +287,8 @@ class WatchTowerServer(AuthenticatedServer): |
|
|
|
self.lnwatcher = network.local_watchtower |
|
|
|
self.app = web.Application() |
|
|
|
self.app.router.add_post("/", self.handle) |
|
|
|
self.methods = set() |
|
|
|
self.methods.add(self.get_ctn) |
|
|
|
self.methods.add(self.add_sweep_tx) |
|
|
|
self.register_method(self.get_ctn) |
|
|
|
self.register_method(self.add_sweep_tx) |
|
|
|
|
|
|
|
async def run(self): |
|
|
|
self.runner = web.AppRunner(self.app) |
|
|
|