From 46de76dcf2863fe71d71512268851e52d0078d51 Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Thu, 29 Aug 2019 17:44:12 +0100 Subject: [PATCH] daemon.py: share a single session/connection across requests This should scale better --- electrumx/server/controller.py | 70 +++++------ electrumx/server/daemon.py | 39 ++++--- tests/server/test_daemon.py | 206 +++++++++++++++------------------ 3 files changed, 152 insertions(+), 163 deletions(-) diff --git a/electrumx/server/controller.py b/electrumx/server/controller.py index 2e7dc01..3581409 100644 --- a/electrumx/server/controller.py +++ b/electrumx/server/controller.py @@ -97,38 +97,38 @@ class Controller(ServerBase): Daemon = env.coin.DAEMON BlockProcessor = env.coin.BLOCK_PROCESSOR - daemon = Daemon(env.coin, env.daemon_url) - db = DB(env) - bp = BlockProcessor(env, db, daemon, notifications) - - # Set notifications up to implement the MemPoolAPI - def get_db_height(): - return db.db_height - notifications.height = daemon.height - notifications.db_height = get_db_height - notifications.cached_height = daemon.cached_height - notifications.mempool_hashes = daemon.mempool_hashes - notifications.raw_transactions = daemon.getrawtransactions - notifications.lookup_utxos = db.lookup_utxos - MemPoolAPI.register(Notifications) - mempool = MemPool(env.coin, notifications) - - session_mgr = SessionManager(env, db, bp, daemon, mempool, - shutdown_event) - - # Test daemon authentication, and also ensure it has a cached - # height. Do this before entering the task group. - await daemon.height() - - caught_up_event = Event() - mempool_event = Event() - - async def wait_for_catchup(): - await caught_up_event.wait() - await group.spawn(db.populate_header_merkle_cache()) - await group.spawn(mempool.keep_synchronized(mempool_event)) - - async with TaskGroup() as group: - await group.spawn(session_mgr.serve(notifications, mempool_event)) - await group.spawn(bp.fetch_and_process_blocks(caught_up_event)) - await group.spawn(wait_for_catchup()) + async with Daemon(env.coin, env.daemon_url) as daemon: + db = DB(env) + bp = BlockProcessor(env, db, daemon, notifications) + + # Set notifications up to implement the MemPoolAPI + def get_db_height(): + return db.db_height + notifications.height = daemon.height + notifications.db_height = get_db_height + notifications.cached_height = daemon.cached_height + notifications.mempool_hashes = daemon.mempool_hashes + notifications.raw_transactions = daemon.getrawtransactions + notifications.lookup_utxos = db.lookup_utxos + MemPoolAPI.register(Notifications) + mempool = MemPool(env.coin, notifications) + + session_mgr = SessionManager(env, db, bp, daemon, mempool, + shutdown_event) + + # Test daemon authentication, and also ensure it has a cached + # height. Do this before entering the task group. + await daemon.height() + + caught_up_event = Event() + mempool_event = Event() + + async def wait_for_catchup(): + await caught_up_event.wait() + await group.spawn(db.populate_header_merkle_cache()) + await group.spawn(mempool.keep_synchronized(mempool_event)) + + async with TaskGroup() as group: + await group.spawn(session_mgr.serve(notifications, mempool_event)) + await group.spawn(bp.fetch_and_process_blocks(caught_up_event)) + await group.spawn(wait_for_catchup()) diff --git a/electrumx/server/daemon.py b/electrumx/server/daemon.py index dae0133..ff6366a 100644 --- a/electrumx/server/daemon.py +++ b/electrumx/server/daemon.py @@ -43,7 +43,7 @@ class Daemon(object): WARMING_UP = -28 id_counter = itertools.count() - def __init__(self, coin, url, max_workqueue=10, init_retry=0.25, max_retry=4.0): + def __init__(self, coin, url, *, max_workqueue=10, init_retry=0.25, max_retry=4.0): self.coin = coin self.logger = class_logger(__name__, self.__class__.__name__) self.url_index = None @@ -56,6 +56,18 @@ class Daemon(object): self.max_retry = max_retry self._height = None self.available_rpcs = {} + self.session = None + + async def __aenter__(self): + self.session = aiohttp.ClientSession(connector=self.connector()) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.session.close() + self.session = None + + def connector(self): + return None def set_url(self, url): '''Set the URLS to the given list, and switch to the first one.''' @@ -88,21 +100,15 @@ class Daemon(object): return True return False - def client_session(self): - '''An aiohttp client session.''' - return aiohttp.ClientSession() - async def _send_data(self, data): async with self.workqueue_semaphore: - async with self.client_session() as session: - async with session.post(self.current_url(), data=data) as resp: - kind = resp.headers.get('Content-Type', None) - if kind == 'application/json': - return await resp.json() - # bitcoind's HTTP protocol "handling" is a bad joke - text = await resp.text() - text = text.strip() or resp.reason - raise ServiceRefusedError(text) + async with self.session.post(self.current_url(), data=data) as resp: + kind = resp.headers.get('Content-Type', None) + if kind == 'application/json': + return await resp.json() + text = await resp.text() + text = text.strip() or resp.reason + raise ServiceRefusedError(text) async def _send(self, payload, processor): '''Send a payload to be converted to JSON. @@ -446,10 +452,9 @@ class DecredDaemon(Daemon): mempool += tip.get('stx', []) return mempool - def client_session(self): + def connector(self): # FIXME allow self signed certificates - connector = aiohttp.TCPConnector(verify_ssl=False) - return aiohttp.ClientSession(connector=connector) + return aiohttp.TCPConnector(verify_ssl=False) class PreLegacyRPCDaemon(LegacyRPCDaemon): diff --git a/tests/server/test_daemon.py b/tests/server/test_daemon.py index 4e29751..74c701a 100644 --- a/tests/server/test_daemon.py +++ b/tests/server/test_daemon.py @@ -76,23 +76,7 @@ class HTMLResponse(ResponseBase): return self._text -class ClientSessionBase(object): - - def __enter__(self): - self.prior_class = aiohttp.ClientSession - aiohttp.ClientSession = lambda: self - - def __exit__(self, exc_type, exc_value, traceback): - aiohttp.ClientSession = self.prior_class - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - pass - - -class ClientSessionGood(ClientSessionBase): +class ClientSessionGood: '''Imitate aiohttp for testing purposes.''' def __init__(self, *triples): @@ -120,7 +104,7 @@ class ClientSessionGood(ClientSessionBase): return JSONResponse(result, request_ids) -class ClientSessionBadAuth(ClientSessionBase): +class ClientSessionBadAuth: def post(self, url, data=""): return HTMLResponse('', 'Unauthorized', 401) @@ -134,25 +118,18 @@ class ClientSessionWorkQueueFull(ClientSessionGood): 'Internal server error', 500) -class ClientSessionNoConnection(ClientSessionGood): - - def __init__(self, *args): - self.args = args - - async def __aenter__(self): - aiohttp.ClientSession = lambda: ClientSessionGood(*self.args) - raise aiohttp.ClientConnectionError - - class ClientSessionPostError(ClientSessionGood): def __init__(self, exception, *args): + super().__init__(*args) self.exception = exception - self.args = args + self.n = 0 def post(self, url, data=""): - aiohttp.ClientSession = lambda: ClientSessionGood(*self.args) - raise self.exception + self.n += 1 + if self.n == 1: + raise self.exception + return super().post(url, data) class ClientSessionFailover(ClientSessionGood): @@ -174,35 +151,39 @@ def in_caplog(caplog, message, count=1): # Tests # -def test_set_urls_bad(): +@pytest.mark.asyncio +async def test_set_urls_bad(): with pytest.raises(CoinError): Daemon(coin, '') with pytest.raises(CoinError): Daemon(coin, 'a') -def test_set_urls_one(caplog): +@pytest.mark.asyncio +async def test_set_urls_one(caplog): with caplog.at_level(logging.INFO): daemon = Daemon(coin, urls[0]) - assert daemon.current_url() == urls[0] - assert len(daemon.urls) == 1 - logged_url = daemon.logged_url() - assert logged_url == '127.0.0.1:8332/' - assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)') + assert daemon.current_url() == urls[0] + assert len(daemon.urls) == 1 + logged_url = daemon.logged_url() + assert logged_url == '127.0.0.1:8332/' + assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)') -def test_set_urls_two(caplog): +@pytest.mark.asyncio +async def test_set_urls_two(caplog): with caplog.at_level(logging.INFO): daemon = Daemon(coin, ','.join(urls)) - assert daemon.current_url() == urls[0] - assert len(daemon.urls) == 2 - logged_url = daemon.logged_url() - assert logged_url == '127.0.0.1:8332/' - assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)') - assert in_caplog(caplog, 'daemon #2 at 192.168.0.1:8332') + assert daemon.current_url() == urls[0] + assert len(daemon.urls) == 2 + logged_url = daemon.logged_url() + assert logged_url == '127.0.0.1:8332/' + assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)') + assert in_caplog(caplog, 'daemon #2 at 192.168.0.1:8332') -def test_set_urls_short(): +@pytest.mark.asyncio +async def test_set_urls_short(): no_prefix_urls = ['/'.join(part for part in url.split('/')[2:]) for url in urls] daemon = Daemon(coin, ','.join(no_prefix_urls)) @@ -220,7 +201,8 @@ def test_set_urls_short(): assert len(daemon.urls) == 2 -def test_failover_good(caplog): +@pytest.mark.asyncio +async def test_failover_good(caplog): daemon = Daemon(coin, ','.join(urls)) with caplog.at_level(logging.INFO): result = daemon.failover() @@ -234,7 +216,8 @@ def test_failover_good(caplog): assert daemon.current_url() == urls[0] -def test_failover_fail(caplog): +@pytest.mark.asyncio +async def test_failover_fail(caplog): daemon = Daemon(coin, urls[0]) with caplog.at_level(logging.INFO): result = daemon.failover() @@ -247,8 +230,8 @@ def test_failover_fail(caplog): async def test_height(daemon): assert daemon.cached_height() is None height = 300 - with ClientSessionGood(('getblockcount', [], height)): - assert await daemon.height() == height + daemon.session = ClientSessionGood(('getblockcount', [], height)) + assert await daemon.height() == height assert daemon.cached_height() == height @@ -256,15 +239,15 @@ async def test_height(daemon): async def test_broadcast_transaction(daemon): raw_tx = 'deadbeef' tx_hash = 'hash' - with ClientSessionGood(('sendrawtransaction', [raw_tx], tx_hash)): - assert await daemon.broadcast_transaction(raw_tx) == tx_hash + daemon.session = ClientSessionGood(('sendrawtransaction', [raw_tx], tx_hash)) + assert await daemon.broadcast_transaction(raw_tx) == tx_hash @pytest.mark.asyncio async def test_relayfee(daemon): response = {"relayfee": sats, "other:": "cruft"} - with ClientSessionGood(('getnetworkinfo', [], response)): - assert await daemon.getnetworkinfo() == response + daemon.session = ClientSessionGood(('getnetworkinfo', [], response)) + assert await daemon.getnetworkinfo() == response @pytest.mark.asyncio @@ -274,23 +257,23 @@ async def test_relayfee(daemon): else: sats = 2 response = {"relayfee": sats, "other:": "cruft"} - with ClientSessionGood(('getnetworkinfo', [], response)): - assert await daemon.relayfee() == sats + daemon.session = ClientSessionGood(('getnetworkinfo', [], response)) + assert await daemon.relayfee() == sats @pytest.mark.asyncio async def test_mempool_hashes(daemon): hashes = ['hex_hash1', 'hex_hash2'] - with ClientSessionGood(('getrawmempool', [], hashes)): - assert await daemon.mempool_hashes() == hashes + daemon.session = ClientSessionGood(('getrawmempool', [], hashes)) + assert await daemon.mempool_hashes() == hashes @pytest.mark.asyncio async def test_deserialised_block(daemon): block_hash = 'block_hash' result = {'some': 'mess'} - with ClientSessionGood(('getblock', [block_hash, True], result)): - assert await daemon.deserialised_block(block_hash) == result + daemon.session = ClientSessionGood(('getblock', [block_hash, True], result)) + assert await daemon.deserialised_block(block_hash) == result @pytest.mark.asyncio @@ -300,11 +283,11 @@ async def test_estimatefee(daemon): result = daemon.coin.ESTIMATE_FEE else: result = -1 - with ClientSessionGood( + daemon.session = ClientSessionGood( ('estimatesmartfee', [], method_not_found), ('estimatefee', [2], result) - ): - assert await daemon.estimatefee(2) == result + ) + assert await daemon.estimatefee(2) == result @pytest.mark.asyncio @@ -314,15 +297,15 @@ async def test_estimatefee_smart(daemon): return rate = 0.0002 result = {'feerate': rate} - with ClientSessionGood( - ('estimatesmartfee', [], bad_args), - ('estimatesmartfee', [2], result) - ): - assert await daemon.estimatefee(2) == rate + daemon.session = ClientSessionGood( + ('estimatesmartfee', [], bad_args), + ('estimatesmartfee', [2], result) + ) + assert await daemon.estimatefee(2) == rate # Test the rpc_available_cache is used - with ClientSessionGood(('estimatesmartfee', [2], result)): - assert await daemon.estimatefee(2) == rate + daemon.session = ClientSessionGood(('estimatesmartfee', [2], result)) + assert await daemon.estimatefee(2) == rate @pytest.mark.asyncio @@ -331,20 +314,20 @@ async def test_getrawtransaction(daemon): simple = 'tx_in_hex' verbose = {'hex': hex_hash, 'other': 'cruft'} # Test False is converted to 0 - old daemon's reject False - with ClientSessionGood(('getrawtransaction', [hex_hash, 0], simple)): - assert await daemon.getrawtransaction(hex_hash) == simple + daemon.session = ClientSessionGood(('getrawtransaction', [hex_hash, 0], simple)) + assert await daemon.getrawtransaction(hex_hash) == simple # Test True is converted to 1 - with ClientSessionGood(('getrawtransaction', [hex_hash, 1], verbose)): - assert await daemon.getrawtransaction( - hex_hash, True) == verbose + daemon.session = ClientSessionGood(('getrawtransaction', [hex_hash, 1], verbose)) + assert await daemon.getrawtransaction( + hex_hash, True) == verbose @pytest.mark.asyncio async def test_protx(dash_daemon): protx_hash = 'deadbeaf' - with ClientSessionGood(('protx', ['info', protx_hash], {})): - assert await dash_daemon.protx(['info', protx_hash]) == {} + dash_daemon.session = ClientSessionGood(('protx', ['info', protx_hash], {})) + assert await dash_daemon.protx(['info', protx_hash]) == {} # Batch tests @@ -353,8 +336,8 @@ async def test_protx(dash_daemon): async def test_empty_send(daemon): first = 5 count = 0 - with ClientSessionGood(('getblockhash', [], [])): - assert await daemon.block_hex_hashes(first, count) == [] + daemon.session = ClientSessionGood(('getblockhash', [], [])) + assert await daemon.block_hex_hashes(first, count) == [] @pytest.mark.asyncio @@ -362,10 +345,10 @@ async def test_block_hex_hashes(daemon): first = 5 count = 3 hashes = [f'hex_hash{n}' for n in range(count)] - with ClientSessionGood(('getblockhash', - [[n] for n in range(first, first + count)], - hashes)): - assert await daemon.block_hex_hashes(first, count) == hashes + daemon.session = ClientSessionGood(('getblockhash', + [[n] for n in range(first, first + count)], + hashes)) + assert await daemon.block_hex_hashes(first, count) == hashes @pytest.mark.asyncio @@ -376,8 +359,8 @@ async def test_raw_blocks(daemon): iterable = (hex_hash for hex_hash in hex_hashes) blocks = ["00", "019a", "02fe"] blocks_raw = [bytes.fromhex(block) for block in blocks] - with ClientSessionGood(('getblock', args_list, blocks)): - assert await daemon.raw_blocks(iterable) == blocks_raw + daemon.session = ClientSessionGood(('getblock', args_list, blocks)) + assert await daemon.raw_blocks(iterable) == blocks_raw @pytest.mark.asyncio @@ -387,15 +370,15 @@ async def test_get_raw_transactions(daemon): raw_txs_hex = ['fffefdfc', '0a0b0c0d'] raw_txs = [bytes.fromhex(raw_tx) for raw_tx in raw_txs_hex] # Test 0 - old daemon's reject False - with ClientSessionGood(('getrawtransaction', args_list, raw_txs_hex)): - assert await daemon.getrawtransactions(hex_hashes) == raw_txs + daemon.session = ClientSessionGood(('getrawtransaction', args_list, raw_txs_hex)) + assert await daemon.getrawtransactions(hex_hashes) == raw_txs # Test one error tx_not_found = RPCError(-1, 'some error message') results = ['ff0b7d', tx_not_found] raw_txs = [bytes.fromhex(results[0]), None] - with ClientSessionGood(('getrawtransaction', args_list, results)): - assert await daemon.getrawtransactions(hex_hashes) == raw_txs + daemon.session = ClientSessionGood(('getrawtransaction', args_list, results)) + assert await daemon.getrawtransactions(hex_hashes) == raw_txs # Other tests @@ -403,8 +386,8 @@ async def test_get_raw_transactions(daemon): @pytest.mark.asyncio async def test_bad_auth(daemon, caplog): async with ignore_after(0.1): - with ClientSessionBadAuth(): - await daemon.height() + daemon.session = ClientSessionBadAuth() + await daemon.height() assert in_caplog(caplog, "daemon service refused") assert in_caplog(caplog, "Unauthorized") @@ -415,8 +398,8 @@ async def test_workqueue_depth(daemon, caplog): daemon.init_retry = 0.01 height = 125 with caplog.at_level(logging.INFO): - with ClientSessionWorkQueueFull(('getblockcount', [], height)): - await daemon.height() == height + daemon.session = ClientSessionWorkQueueFull(('getblockcount', [], height)) + await daemon.height() == height assert in_caplog(caplog, "Work queue depth exceeded") assert in_caplog(caplog, "running normally") @@ -427,8 +410,9 @@ async def test_connection_error(daemon, caplog): height = 100 daemon.init_retry = 0.01 with caplog.at_level(logging.INFO): - with ClientSessionNoConnection(('getblockcount', [], height)): - await daemon.height() == height + daemon.session = ClientSessionPostError(aiohttp.ClientConnectionError, + ('getblockcount', [], height)) + await daemon.height() == height assert in_caplog(caplog, "connection problem - check your daemon is running") assert in_caplog(caplog, "connection restored") @@ -439,9 +423,9 @@ async def test_timeout_error(daemon, caplog): height = 100 daemon.init_retry = 0.01 with caplog.at_level(logging.INFO): - with ClientSessionPostError(asyncio.TimeoutError, - ('getblockcount', [], height)): - await daemon.height() == height + daemon.session = ClientSessionPostError(asyncio.TimeoutError, + ('getblockcount', [], height)) + await daemon.height() == height assert in_caplog(caplog, "timeout error") @@ -451,9 +435,9 @@ async def test_disconnected(daemon, caplog): height = 100 daemon.init_retry = 0.01 with caplog.at_level(logging.INFO): - with ClientSessionPostError(aiohttp.ServerDisconnectedError, - ('getblockcount', [], height)): - await daemon.height() == height + daemon.session = ClientSessionPostError(aiohttp.ServerDisconnectedError, + ('getblockcount', [], height)) + await daemon.height() == height assert in_caplog(caplog, "disconnected") assert in_caplog(caplog, "connection restored") @@ -465,11 +449,11 @@ async def test_warming_up(daemon, caplog): height = 100 daemon.init_retry = 0.01 with caplog.at_level(logging.INFO): - with ClientSessionGood( - ('getblockcount', [], warming_up), - ('getblockcount', [], height) - ): - assert await daemon.height() == height + daemon.session = ClientSessionGood( + ('getblockcount', [], warming_up), + ('getblockcount', [], height) + ) + assert await daemon.height() == height assert in_caplog(caplog, "starting up checking blocks") assert in_caplog(caplog, "running normally") @@ -483,9 +467,9 @@ async def test_warming_up_batch(daemon, caplog): daemon.init_retry = 0.01 hashes = ['hex_hash5'] with caplog.at_level(logging.INFO): - with ClientSessionGood(('getblockhash', [[first]], [warming_up]), - ('getblockhash', [[first]], hashes)): - assert await daemon.block_hex_hashes(first, count) == hashes + daemon.session = ClientSessionGood(('getblockhash', [[first]], [warming_up]), + ('getblockhash', [[first]], hashes)) + assert await daemon.block_hex_hashes(first, count) == hashes assert in_caplog(caplog, "starting up checking blocks") assert in_caplog(caplog, "running normally") @@ -497,8 +481,8 @@ async def test_failover(daemon, caplog): daemon.init_retry = 0.01 daemon.max_retry = 0.04 with caplog.at_level(logging.INFO): - with ClientSessionFailover(('getblockcount', [], height)): - await daemon.height() == height + daemon.session = ClientSessionFailover(('getblockcount', [], height)) + await daemon.height() == height assert in_caplog(caplog, "disconnected", 1) assert in_caplog(caplog, "failing over")