From c0568daec37f286e15346f88abf8ff47338602ad Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Thu, 3 Nov 2016 13:30:29 +0900 Subject: [PATCH] Rework the DB API a bit --- lib/coins.py | 12 +--- lib/util.py | 9 +++ server/block_processor.py | 77 ++++++++++------------- server/storage.py | 126 ++++++++++++++++++++++++++------------ 4 files changed, 130 insertions(+), 94 deletions(-) diff --git a/lib/coins.py b/lib/coins.py index 5c89832..a87b652 100644 --- a/lib/coins.py +++ b/lib/coins.py @@ -19,6 +19,7 @@ import sys from lib.hash import Base58, hash160, double_sha256, hash_to_str from lib.script import ScriptPubKey from lib.tx import Deserializer +from lib.util import subclasses class CoinError(Exception): @@ -34,21 +35,12 @@ class Coin(object): VALUE_PER_COIN = 100000000 CHUNK_SIZE=2016 - @staticmethod - def coin_classes(): - '''Return a list of coin classes in declaration order.''' - is_coin = lambda obj: (inspect.isclass(obj) - and issubclass(obj, Coin) - and obj != Coin) - pairs = inspect.getmembers(sys.modules[__name__], is_coin) - return [pair[1] for pair in pairs] - @classmethod def lookup_coin_class(cls, name, net): '''Return a coin class given name and network. Raise an exception if unrecognised.''' - for coin in cls.coin_classes(): + for coin in subclasses(Coin): if (coin.NAME.lower() == name.lower() and coin.NET.lower() == net.lower()): return coin diff --git a/lib/util.py b/lib/util.py index f59439b..dd8187e 100644 --- a/lib/util.py +++ b/lib/util.py @@ -9,6 +9,7 @@ import array +import inspect import logging import sys from collections import Container, Mapping @@ -77,6 +78,14 @@ def deep_getsizeof(obj): return size(obj) +def subclasses(base_class, strict=True): + '''Return a list of subclasses of base_class in its module.''' + def select(obj): + return (inspect.isclass(obj) and issubclass(obj, base_class) + and (not strict or obj != base_class)) + + pairs = inspect.getmembers(sys.modules[base_class.__module__], select) + return [pair[1] for pair in pairs] def chunks(items, size): '''Break up items, an iterable, into chunks of length size.''' diff --git a/server/block_processor.py b/server/block_processor.py index 3887204..0bc48e2 100644 --- a/server/block_processor.py +++ b/server/block_processor.py @@ -22,7 +22,7 @@ from server.daemon import DaemonError from lib.hash import hash_to_str from lib.script import ScriptPubKey from lib.util import chunks, LoggedClass -from server.storage import LMDB, RocksDB, LevelDB, NoDatabaseException +from server.storage import open_db def formatted_time(t): @@ -161,17 +161,17 @@ class BlockProcessor(LoggedClass): self.coin = env.coin self.reorg_limit = env.reorg_limit - # Chain state (initialize to genesis in case of new DB) - self.db_height = -1 - self.db_tx_count = 0 - self.db_tip = b'\0' * 32 - self.flush_count = 0 - self.utxo_flush_count = 0 - self.wall_time = 0 - self.first_sync = True - # Open DB and metadata files. Record some of its state. - self.db = self.open_db(self.coin, env.db_engine) + db_name = '{}-{}'.format(self.coin.NAME, self.coin.NET) + self.db = open_db(db_name, env.db_engine) + if self.db.is_new: + self.logger.info('created new {} database {}' + .format(env.db_engine, db_name)) + else: + self.logger.info('successfully opened {} database {}' + .format(env.db_engine, db_name)) + + self.init_state() self.tx_count = self.db_tx_count self.height = self.db_height self.tip = self.db_tip @@ -313,40 +313,29 @@ class BlockProcessor(LoggedClass): return self.fs_cache.block_hashes(start, count) - def open_db(self, coin, db_engine): - db_name = '{}-{}'.format(coin.NAME, coin.NET) - db_engine_class = { - "leveldb": LevelDB, - "rocksdb": RocksDB, - "lmdb": LMDB - }[db_engine.lower()] - try: - db = db_engine_class(db_name, create_if_missing=False, - error_if_exists=False, compression=None) - except NoDatabaseException: - db = db_engine_class(db_name, create_if_missing=True, - error_if_exists=True, compression=None) - self.logger.info('created new {} database {}'.format(db_engine, db_name)) + def init_state(self): + if self.db.is_new: + self.db_height = -1 + self.db_tx_count = 0 + self.db_tip = b'\0' * 32 + self.flush_count = 0 + self.utxo_flush_count = 0 + self.wall_time = 0 + self.first_sync = True else: - self.logger.info('successfully opened {} database {}'.format(db_engine, db_name)) - self.read_state(db) - - return db - - def read_state(self, db): - state = db.get(b'state') - state = ast.literal_eval(state.decode()) - if state['genesis'] != self.coin.GENESIS_HASH: - raise ChainError('DB genesis hash {} does not match coin {}' - .format(state['genesis_hash'], - self.coin.GENESIS_HASH)) - self.db_height = state['height'] - self.db_tx_count = state['tx_count'] - self.db_tip = state['tip'] - self.flush_count = state['flush_count'] - self.utxo_flush_count = state['utxo_flush_count'] - self.wall_time = state['wall_time'] - self.first_sync = state.get('first_sync', True) + state = self.db.get(b'state') + state = ast.literal_eval(state.decode()) + if state['genesis'] != self.coin.GENESIS_HASH: + raise ChainError('DB genesis hash {} does not match coin {}' + .format(state['genesis_hash'], + self.coin.GENESIS_HASH)) + self.db_height = state['height'] + self.db_tx_count = state['tx_count'] + self.db_tip = state['tip'] + self.flush_count = state['flush_count'] + self.utxo_flush_count = state['utxo_flush_count'] + self.wall_time = state['wall_time'] + self.first_sync = state.get('first_sync', True) def clean_db(self): '''Clean out stale DB items. diff --git a/server/storage.py b/server/storage.py index 399ed69..d4557d4 100644 --- a/server/storage.py +++ b/server/storage.py @@ -1,43 +1,83 @@ +# Copyright (c) 2016, the ElectrumX authors +# +# All rights reserved. +# +# See the file "LICENCE" for information about the copyright +# and warranty status of this software. + +'''Backend database abstraction. + +The abstraction needs to be improved to not heavily penalise LMDB. +''' + import os from functools import partial +from lib.util import subclasses + + +def open_db(name, db_engine): + '''Returns a database handle.''' + for db_class in subclasses(Storage): + if db_class.__name__.lower() == db_engine.lower(): + db_class.import_module() + return db_class(name) + + raise RuntimeError('unrecognised DB engine "{}"'.format(db_engine)) + class Storage(object): - def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None): - if not create_if_missing and not os.path.exists(name): - raise NoDatabaseException + '''Abstract base class of the DB backend abstraction.''' + + def __init__(self, name): + self.is_new = not os.path.exists(name) + self.open(name, create=self.is_new) + + @classmethod + def import_module(cls): + '''Import the DB engine module.''' + raise NotImplementedError + + def open(self, name, create): + '''Open an existing database or create a new one.''' + raise NotImplementedError def get(self, key): - raise NotImplementedError() + raise NotImplementedError def put(self, key, value): - raise NotImplementedError() + raise NotImplementedError def write_batch(self): - """ - Returns a context manager that provides `put` and `delete`. - Changes should only be committed when the context manager closes without an exception. - """ - raise NotImplementedError() + '''Return a context manager that provides `put` and `delete`. - def iterator(self, prefix=b'', reverse=False): - """ - Returns an iterator that yields (key, value) pairs from the database sorted by key. - If `prefix` is set, only keys starting with `prefix` will be included. - """ - raise NotImplementedError() + Changes should only be committed when the context manager + closes without an exception. + ''' + raise NotImplementedError + def iterator(self, prefix=b'', reverse=False): + '''Return an iterator that yields (key, value) pairs from the + database sorted by key. -class NoDatabaseException(Exception): - pass + If `prefix` is set, only keys starting with `prefix` will be + included. If `reverse` is True the items are returned in + reverse order. + ''' + raise NotImplementedError class LevelDB(Storage): - def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None): - super().__init__(name, create_if_missing, error_if_exists, compression) + '''LevelDB database engine.''' + + @classmethod + def import_module(cls): import plyvel - self.db = plyvel.DB(name, create_if_missing=create_if_missing, - error_if_exists=error_if_exists, compression=compression) + cls.module = plyvel + + def open(self, name, create): + self.db = self.module.DB(name, create_if_missing=create, + compression=None) self.get = self.db.get self.put = self.db.put self.iterator = self.db.iterator @@ -45,25 +85,28 @@ class LevelDB(Storage): class RocksDB(Storage): - rocksdb = None + '''RocksDB database engine.''' - def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None): - super().__init__(name, create_if_missing, error_if_exists, compression) + @classmethod + def import_module(cls): import rocksdb - RocksDB.rocksdb = rocksdb - if not compression: - compression = "no" - compression = getattr(rocksdb.CompressionType, compression + "_compression") - self.db = rocksdb.DB(name, rocksdb.Options(create_if_missing=create_if_missing, - compression=compression, - target_file_size_base=33554432, - max_open_files=1024)) + cls.module = rocksdb + + def open(self, name, create): + compression = "no" + compression = getattr(self.module.CompressionType, + compression + "_compression") + options = self.module.Options(create_if_missing=create, + compression=compression, + target_file_size_base=33554432, + max_open_files=1024) + self.db = self.module.DB(name, options) self.get = self.db.get self.put = self.db.put class WriteBatch(object): def __init__(self, db): - self.batch = RocksDB.rocksdb.WriteBatch() + self.batch = RocksDB.module.WriteBatch() self.db = db def __enter__(self): @@ -99,14 +142,17 @@ class RocksDB(Storage): class LMDB(Storage): - lmdb = None + '''RocksDB database engine.''' - def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None): - super().__init__(name, create_if_missing, error_if_exists, compression) + @classmethod + def import_module(cls): import lmdb - LMDB.lmdb = lmdb - self.env = lmdb.Environment(".", subdir=True, create=create_if_missing, max_dbs=32, map_size=5 * 10 ** 10) - self.db = self.env.open_db(create=create_if_missing) + cls.module = lmdb + + def open(self, name, create): + self.env = cls.module.Environment('.', subdir=True, create=create, + max_dbs=32, map_size=5 * 10 ** 10) + self.db = self.env.open_db(create=create) def get(self, key): with self.env.begin(db=self.db) as tx: