From 53425ce585c6ce58cb12ac408ecfc38e05a3954c Mon Sep 17 00:00:00 2001
From: Neil Booth <kyuupichan@gmail.com>
Date: Wed, 18 Jul 2018 11:02:29 +0800
Subject: [PATCH] Move task logic to Tasks object

This helps to rationalize the inter-object
dependencies.
---
 electrumx/lib/tasks.py              | 67 +++++++++++++++++++++++++++++
 electrumx/server/block_processor.py | 27 ++++++++----
 electrumx/server/controller.py      | 63 ++++++++-------------------
 electrumx/server/mempool.py         | 15 ++++---
 electrumx/server/peers.py           | 15 ++++---
 electrumx/server/session.py         | 33 +++++++-------
 6 files changed, 136 insertions(+), 84 deletions(-)
 create mode 100644 electrumx/lib/tasks.py

diff --git a/electrumx/lib/tasks.py b/electrumx/lib/tasks.py
new file mode 100644
index 0000000..6b540a9
--- /dev/null
+++ b/electrumx/lib/tasks.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2018, Neil Booth
+#
+# All rights reserved.
+#
+# The MIT License (MIT)
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+# and warranty status of this software.
+
+'''Concurrency via tasks and threads.'''
+
+from concurrent.futures import ThreadPoolExecutor
+
+from aiorpcx import TaskSet
+
+import electrumx.lib.util as util
+
+
+class Tasks(object):
+    # Functionality here will be incorporated into aiorpcX's TaskSet
+    # after experience is gained.
+
+    def __init__(self, *, loop=None):
+        self.tasks = TaskSet(loop=loop)
+        self.logger = util.class_logger(__name__, self.__class__.__name__)
+        # FIXME: is the executor still needed?
+        self.executor = ThreadPoolExecutor()
+        self.tasks.loop.set_default_executor(self.executor)
+        # Pass through until integrated
+        self.loop = self.tasks.loop
+        self.cancel_all = self.tasks.cancel_all
+        self.wait = self.tasks.wait
+
+    async def run_in_thread(self, func, *args):
+        '''Run a function in a separate thread, and await its completion.'''
+        return await self.loop.run_in_executor(None, func, *args)
+
+    def create_task(self, coro, callback=None):
+        '''Schedule the coro to be run.'''
+        task = self.tasks.create_task(coro)
+        task.add_done_callback(callback or self._check_task_exception)
+        return task
+
+    def _check_task_exception(self, task):
+        '''Check a task for exceptions.'''
+        try:
+            if not task.cancelled():
+                task.result()
+        except Exception as e:
+            self.logger.exception(f'uncaught task exception: {e}')
diff --git a/electrumx/server/block_processor.py b/electrumx/server/block_processor.py
index 4b4985e..55c6ef2 100644
--- a/electrumx/server/block_processor.py
+++ b/electrumx/server/block_processor.py
@@ -146,15 +146,15 @@ class BlockProcessor(electrumx.server.db.DB):
     Coordinate backing up in case of chain reorganisations.
     '''
 
-    def __init__(self, env, controller, daemon):
+    def __init__(self, env, tasks, daemon):
         super().__init__(env)
 
         # An incomplete compaction needs to be cancelled otherwise
         # restarting it will corrupt the history
         self.history.cancel_compaction()
 
+        self.tasks = tasks
         self.daemon = daemon
-        self.controller = controller
 
         # These are our state as we move ahead of DB state
         self.fs_height = self.db_height
@@ -172,6 +172,7 @@ class BlockProcessor(electrumx.server.db.DB):
         self.last_flush = time.time()
         self.last_flush_tx_count = self.tx_count
         self.touched = set()
+        self.callbacks = []
 
         # Header merkle cache
         self.merkle = Merkle()
@@ -204,9 +205,18 @@ class BlockProcessor(electrumx.server.db.DB):
         '''Called by the prefetcher when it first catches up.'''
         self.add_task(self.first_caught_up)
 
+    def add_new_block_callback(self, callback):
+        '''Add a function called when a new block is found.
+
+        If several blocks are processed simultaneously, only called
+        once.  The callback is passed a set of hashXs touched by the
+        block(s), which is cleared on return.
+        '''
+        self.callbacks.append(callback)
+
     async def main_loop(self):
         '''Main loop for block processing.'''
-        self.controller.create_task(self.prefetcher.main_loop())
+        self.tasks.create_task(self.prefetcher.main_loop())
         await self.prefetcher.reset_height()
 
         while True:
@@ -226,7 +236,7 @@ class BlockProcessor(electrumx.server.db.DB):
         '''Called when first caught up to daemon after starting.'''
         # Flush everything with updated first_sync->False state.
         self.first_sync = False
-        await self.controller.run_in_executor(self.flush, True)
+        await self.tasks.run_in_thread(self.flush, True)
         if self.utxo_db.for_sync:
             self.logger.info(f'{electrumx.version} synced to '
                              f'height {self.height:,d}')
@@ -261,13 +271,14 @@ class BlockProcessor(electrumx.server.db.DB):
 
         if hprevs == chain:
             start = time.time()
-            await self.controller.run_in_executor(self.advance_blocks, blocks)
+            await self.tasks.run_in_thread(self.advance_blocks, blocks)
             if not self.first_sync:
                 s = '' if len(blocks) == 1 else 's'
                 self.logger.info('processed {:,d} block{} in {:.1f}s'
                                  .format(len(blocks), s,
                                          time.time() - start))
-                self.controller.mempool.on_new_block(self.touched)
+                for callback in self.callbacks:
+                    callback(self.touched)
             self.touched.clear()
         elif hprevs[0] != chain[0]:
             await self.reorg_chain()
@@ -300,14 +311,14 @@ class BlockProcessor(electrumx.server.db.DB):
             self.logger.info('chain reorg detected')
         else:
             self.logger.info('faking a reorg of {:,d} blocks'.format(count))
-        await self.controller.run_in_executor(self.flush, True)
+        await self.tasks.run_in_thread(self.flush, True)
 
         hashes = await self.reorg_hashes(count)
         # Reverse and convert to hex strings.
         hashes = [hash_to_hex_str(hash) for hash in reversed(hashes)]
         for hex_hashes in chunks(hashes, 50):
             blocks = await self.daemon.raw_blocks(hex_hashes)
-            await self.controller.run_in_executor(self.backup_blocks, blocks)
+            await self.tasks.run_in_thread(self.backup_blocks, blocks)
         # Truncate header_mc: header count is 1 more than the height
         self.header_mc.truncate(self.height + 1)
         await self.prefetcher.reset_height()
diff --git a/electrumx/server/controller.py b/electrumx/server/controller.py
index deb013a..b7accab 100644
--- a/electrumx/server/controller.py
+++ b/electrumx/server/controller.py
@@ -5,12 +5,10 @@
 # See the file "LICENCE" for information about the copyright
 # and warranty status of this software.
 
-import asyncio
-from concurrent.futures import ThreadPoolExecutor
-
-from aiorpcx import TaskSet, _version as aiorpcx_version
+from aiorpcx import _version as aiorpcx_version
 import electrumx
 from electrumx.lib.server_base import ServerBase
+from electrumx.lib.tasks import Tasks
 from electrumx.lib.util import version_string
 from electrumx.server.mempool import MemPool
 from electrumx.server.peers import PeerManager
@@ -40,28 +38,23 @@ class Controller(ServerBase):
         self.logger.info(f'supported protocol versions: {min_str}-{max_str}')
         self.logger.info(f'event loop policy: {env.loop_policy}')
 
-        self.coin = env.coin
-        self.tasks = TaskSet()
         env.max_send = max(350000, env.max_send)
 
-        self.loop = asyncio.get_event_loop()
-        self.executor = ThreadPoolExecutor()
-        self.loop.set_default_executor(self.executor)
-
-        # The complex objects.  Note PeerManager references self.loop (ugh)
-        self.session_mgr = SessionManager(env, self)
-        self.daemon = self.coin.DAEMON(env)
-        self.bp = self.coin.BLOCK_PROCESSOR(env, self, self.daemon)
-        self.mempool = MemPool(self.bp, self)
-        self.peer_mgr = PeerManager(env, self)
+        self.tasks = Tasks()
+        self.session_mgr = SessionManager(env, self.tasks, self)
+        self.daemon = env.coin.DAEMON(env)
+        self.bp = env.coin.BLOCK_PROCESSOR(env, self.tasks, self.daemon)
+        self.mempool = MemPool(self.bp, self.daemon, self.tasks,
+                               self.session_mgr.notify_sessions)
+        self.peer_mgr = PeerManager(env, self.tasks, self.session_mgr, self.bp)
 
     async def start_servers(self):
         '''Start the RPC server and schedule the external servers to be
         started once the block processor has caught up.
         '''
         await self.session_mgr.start_rpc_server()
-        self.create_task(self.bp.main_loop())
-        self.create_task(self.wait_for_bp_catchup())
+        self.tasks.create_task(self.bp.main_loop())
+        self.tasks.create_task(self.wait_for_bp_catchup())
 
     async def shutdown(self):
         '''Perform the shutdown sequence.'''
@@ -69,8 +62,8 @@ class Controller(ServerBase):
         self.tasks.cancel_all()
         await self.session_mgr.shutdown()
         await self.tasks.wait()
-        # Finally shut down the block processor and executor
-        self.bp.shutdown(self.executor)
+        # Finally shut down the block processor and executor (FIXME)
+        self.bp.shutdown(self.tasks.executor)
 
     async def mempool_transactions(self, hashX):
         '''Generate (hex_hash, tx_fee, unconfirmed) tuples for mempool
@@ -87,34 +80,12 @@ class Controller(ServerBase):
         '''
         return self.mempool.value(hashX)
 
-    async def run_in_executor(self, func, *args):
-        '''Wait whilst running func in the executor.'''
-        return await self.loop.run_in_executor(None, func, *args)
-
-    def schedule_executor(self, func, *args):
-        '''Schedule running func in the executor, return a task.'''
-        return self.create_task(self.run_in_executor(func, *args))
-
-    def create_task(self, coro, callback=None):
-        '''Schedule the coro to be run.'''
-        task = self.tasks.create_task(coro)
-        task.add_done_callback(callback or self.check_task_exception)
-        return task
-
-    def check_task_exception(self, task):
-        '''Check a task for exceptions.'''
-        try:
-            if not task.cancelled():
-                task.result()
-        except Exception as e:
-            self.logger.exception(f'uncaught task exception: {e}')
-
     async def wait_for_bp_catchup(self):
         '''Wait for the block processor to catch up, and for the mempool to
         synchronize, then kick off server background processes.'''
         await self.bp.caught_up_event.wait()
-        self.create_task(self.mempool.main_loop())
+        self.tasks.create_task(self.mempool.main_loop())
         await self.mempool.synchronized_event.wait()
-        self.create_task(self.peer_mgr.main_loop())
-        self.create_task(self.session_mgr.start_serving())
-        self.create_task(self.session_mgr.housekeeping())
+        self.tasks.create_task(self.peer_mgr.main_loop())
+        self.tasks.create_task(self.session_mgr.start_serving())
+        self.tasks.create_task(self.session_mgr.housekeeping())
diff --git a/electrumx/server/mempool.py b/electrumx/server/mempool.py
index 2ff9209..b08de09 100644
--- a/electrumx/server/mempool.py
+++ b/electrumx/server/mempool.py
@@ -32,13 +32,13 @@ class MemPool(object):
     A pair is a (hashX, value) tuple.  tx hashes are hex strings.
     '''
 
-    def __init__(self, bp, controller):
+    def __init__(self, db, daemon, tasks, notify_sessions):
         self.logger = class_logger(__name__, self.__class__.__name__)
-        self.daemon = bp.daemon
-        self.controller = controller
-        self.notify_sessions = controller.session_mgr.notify_sessions
-        self.coin = bp.coin
-        self.db = bp
+        self.db = db
+        self.daemon = daemon
+        self.tasks = tasks
+        self.notify_sessions = notify_sessions
+        self.coin = db.coin
         self.touched = set()
         self.stop = False
         self.txs = {}
@@ -47,6 +47,7 @@ class MemPool(object):
         self.fee_histogram = defaultdict(int)
         self.compact_fee_histogram = []
         self.histogram_time = 0
+        db.add_new_block_callback(self.on_new_block)
 
     def _resync_daemon_hashes(self, unprocessed, unfetched):
         '''Re-sync self.txs with the list of hashes in the daemon's mempool.
@@ -165,7 +166,7 @@ class MemPool(object):
                 deferred = pending
                 pending = []
 
-            result, deferred = await self.controller.run_in_executor(
+            result, deferred = await self.tasks.run_in_thread(
                 self.process_raw_txs, raw_txs, deferred)
 
             pending.extend(deferred)
diff --git a/electrumx/server/peers.py b/electrumx/server/peers.py
index a131b06..d6bb6a0 100644
--- a/electrumx/server/peers.py
+++ b/electrumx/server/peers.py
@@ -141,8 +141,7 @@ class PeerSession(ClientSession):
             return
 
         result = request.result()
-        controller = self.peer_mgr.controller
-        our_height = controller.bp.db_height
+        our_height = self.peer_mgr.bp.db_height
         if self.ptuple < (1, 3):
             their_height = result.get('block_height')
         else:
@@ -156,7 +155,7 @@ class PeerSession(ClientSession):
             return
         # Check prior header too in case of hard fork.
         check_height = min(our_height, their_height)
-        raw_header = controller.session_mgr.raw_header(check_height)
+        raw_header = self.peer_mgr.session_mgr.raw_header(check_height)
         if self.ptuple >= (1, 4):
             self.send_request('blockchain.block.header', [check_height],
                               partial(self.on_header, raw_header.hex()),
@@ -241,13 +240,15 @@ class PeerManager(object):
     Attempts to maintain a connection with up to 8 peers.
     Issues a 'peers.subscribe' RPC to them and tells them our data.
     '''
-    def __init__(self, env, controller):
+    def __init__(self, env, tasks, session_mgr, bp):
         self.logger = class_logger(__name__, self.__class__.__name__)
         # Initialise the Peer class
         Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
         self.env = env
-        self.controller = controller
-        self.loop = controller.loop
+        self.tasks = tasks
+        self.session_mgr = session_mgr
+        self.bp = bp
+        self.loop = tasks.loop
 
         # Our clearnet and Tor Peers, if any
         sclass = env.coin.SESSIONCLS
@@ -572,7 +573,7 @@ class PeerManager(object):
 
         session = PeerSession(peer, self, kind, peer.host, port, **kwargs)
         callback = partial(self.on_connected, peer, port_pairs)
-        self.controller.create_task(session.create_connection(), callback)
+        self.tasks.create_task(session.create_connection(), callback)
 
     def on_connected(self, peer, port_pairs, task):
         '''Called when a connection attempt succeeds or fails.
diff --git a/electrumx/server/session.py b/electrumx/server/session.py
index 4fb7cfc..6ac8293 100644
--- a/electrumx/server/session.py
+++ b/electrumx/server/session.py
@@ -98,8 +98,9 @@ class SessionManager(object):
 
     CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4)
 
-    def __init__(self, env, controller):
+    def __init__(self, env, tasks, controller):
         self.env = env
+        self.tasks = tasks
         self.controller = controller
         self.logger = util.class_logger(__name__, self.__class__.__name__)
         self.servers = {}
@@ -416,13 +417,12 @@ class SessionManager(object):
 
         # Height notifications are synchronous.  Those sessions with
         # touched addresses are scheduled for asynchronous completion
-        create_task = self.controller.create_task
         for session in self.sessions:
             if isinstance(session, LocalRPC):
                 continue
             session_touched = session.notify(height, touched)
             if session_touched is not None:
-                create_task(session.notify_async(session_touched))
+                self.tasks.create_task(session.notify_async(session_touched))
 
     def raw_header(self, height):
         '''Return the binary header at the given height.'''
@@ -442,13 +442,19 @@ class SessionManager(object):
             # on bloated history requests, and uses a smaller divisor
             # so large requests are logged before refusing them.
             limit = self.env.max_send // 97
-            return list(controller.bp.get_history(hashX, limit=limit))
+            return list(self.controller.bp.get_history(hashX, limit=limit))
 
-        controller = self.controller
-        history = await controller.run_in_executor(job)
+        history = await self.tasks.run_in_thread(job)
         self.history_cache[hashX] = history
         return history
 
+    async def get_utxos(self, hashX):
+        '''Get UTXOs asynchronously to reduce latency.'''
+        def job():
+            return list(self.controller.bp.get_utxos(hashX, limit=None))
+
+        return await self.tasks.run_in_thread(job)
+
     async def housekeeping(self):
         '''Regular housekeeping checks.'''
         n = 0
@@ -776,17 +782,10 @@ class ElectrumX(SessionBase):
 
         return status
 
-    async def get_utxos(self, hashX):
-        '''Get UTXOs asynchronously to reduce latency.'''
-        def job():
-            return list(self.bp.get_utxos(hashX, limit=None))
-
-        return await self.controller.run_in_executor(job)
-
     async def hashX_listunspent(self, hashX):
         '''Return the list of UTXOs of a script hash, including mempool
         effects.'''
-        utxos = await self.get_utxos(hashX)
+        utxos = await self.session_mgr.get_utxos(hashX)
         utxos = sorted(utxos)
         utxos.extend(self.controller.mempool.get_utxos(hashX))
         spends = await self.controller.mempool.potential_spends(hashX)
@@ -843,7 +842,7 @@ class ElectrumX(SessionBase):
         return await self.hashX_subscribe(hashX, address)
 
     async def get_balance(self, hashX):
-        utxos = await self.get_utxos(hashX)
+        utxos = await self.session_mgr.get_utxos(hashX)
         confirmed = sum(utxo.value for utxo in utxos)
         unconfirmed = self.controller.mempool_value(hashX)
         return {'confirmed': confirmed, 'unconfirmed': unconfirmed}
@@ -1263,7 +1262,9 @@ class DashElectrumX(ElectrumX):
     def notify(self, height, touched):
         '''Notify the client about changes in masternode list.'''
         result = super().notify(height, touched)
-        self.controller.create_task(self.notify_masternodes_async())
+        # FIXME: the notifications should be done synchronously and the
+        # master node list fetched once asynchronously
+        self.session_mgr.tasks.create_task(self.notify_masternodes_async())
         return result
 
     # Masternode command handlers