
Rework the DB API a bit

Branch: master
Neil Booth, 8 years ago
parent commit c0568daec3
  1. lib/coins.py (12 changed lines)
  2. lib/util.py (9 changed lines)
  3. server/block_processor.py (77 changed lines)
  4. server/storage.py (126 changed lines)

lib/coins.py (12 changed lines)

@@ -19,6 +19,7 @@ import sys
 from lib.hash import Base58, hash160, double_sha256, hash_to_str
 from lib.script import ScriptPubKey
 from lib.tx import Deserializer
+from lib.util import subclasses


 class CoinError(Exception):
@@ -34,21 +35,12 @@ class Coin(object):
     VALUE_PER_COIN = 100000000
     CHUNK_SIZE=2016

-    @staticmethod
-    def coin_classes():
-        '''Return a list of coin classes in declaration order.'''
-        is_coin = lambda obj: (inspect.isclass(obj)
-                               and issubclass(obj, Coin)
-                               and obj != Coin)
-        pairs = inspect.getmembers(sys.modules[__name__], is_coin)
-        return [pair[1] for pair in pairs]
-
     @classmethod
     def lookup_coin_class(cls, name, net):
         '''Return a coin class given name and network.

         Raise an exception if unrecognised.'''
-        for coin in cls.coin_classes():
+        for coin in subclasses(Coin):
             if (coin.NAME.lower() == name.lower()
                     and coin.NET.lower() == net.lower()):
                 return coin
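
With coin_classes() gone, callers go through lookup_coin_class(), which now walks
subclasses(Coin) from lib.util.  A minimal sketch of the lookup, assuming (as in
ElectrumX) that Bitcoin subclasses of Coin are declared in lib/coins.py; the exact
NAME/NET strings used here are illustrative:

    from lib.coins import Coin

    # Matching is case-insensitive on both the coin name and the network.
    coin_class = Coin.lookup_coin_class('Bitcoin', 'mainnet')
    testnet_class = Coin.lookup_coin_class('bitcoin', 'TESTNET')

The commit only changes how candidate classes are discovered; the matching rules
themselves (lower-cased NAME and NET comparison) are unchanged.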

lib/util.py (9 changed lines)

@@ -9,6 +9,7 @@
 import array
+import inspect
 import logging
 import sys
 from collections import Container, Mapping
@@ -77,6 +78,14 @@ def deep_getsizeof(obj):
     return size(obj)


+def subclasses(base_class, strict=True):
+    '''Return a list of subclasses of base_class in its module.'''
+    def select(obj):
+        return (inspect.isclass(obj) and issubclass(obj, base_class)
+                and (not strict or obj != base_class))
+    pairs = inspect.getmembers(sys.modules[base_class.__module__], select)
+    return [pair[1] for pair in pairs]
+
+
 def chunks(items, size):
     '''Break up items, an iterable, into chunks of length size.'''
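
The new subclasses() helper generalises the removed Coin.coin_classes(): it inspects
the module that defines base_class and returns every class derived from it (excluding
base_class itself unless strict=False).  Note that inspect.getmembers() returns members
sorted by name, so the result is in alphabetical rather than declaration order.  A
self-contained sketch with hypothetical classes, not taken from the repository:

    import inspect
    import sys

    def subclasses(base_class, strict=True):
        '''Return a list of subclasses of base_class in its module.'''
        def select(obj):
            return (inspect.isclass(obj) and issubclass(obj, base_class)
                    and (not strict or obj != base_class))
        pairs = inspect.getmembers(sys.modules[base_class.__module__], select)
        return [pair[1] for pair in pairs]

    class Base: pass
    class Alpha(Base): pass
    class Beta(Base): pass

    print(subclasses(Base))                 # the Alpha and Beta classes
    print(subclasses(Base, strict=False))   # Alpha, Base and Beta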

server/block_processor.py (77 changed lines)

@@ -22,7 +22,7 @@ from server.daemon import DaemonError
 from lib.hash import hash_to_str
 from lib.script import ScriptPubKey
 from lib.util import chunks, LoggedClass
-from server.storage import LMDB, RocksDB, LevelDB, NoDatabaseException
+from server.storage import open_db


 def formatted_time(t):
@@ -161,17 +161,17 @@ class BlockProcessor(LoggedClass):
         self.coin = env.coin
         self.reorg_limit = env.reorg_limit

-        # Chain state (initialize to genesis in case of new DB)
-        self.db_height = -1
-        self.db_tx_count = 0
-        self.db_tip = b'\0' * 32
-        self.flush_count = 0
-        self.utxo_flush_count = 0
-        self.wall_time = 0
-        self.first_sync = True
-
         # Open DB and metadata files. Record some of its state.
-        self.db = self.open_db(self.coin, env.db_engine)
+        db_name = '{}-{}'.format(self.coin.NAME, self.coin.NET)
+        self.db = open_db(db_name, env.db_engine)
+        if self.db.is_new:
+            self.logger.info('created new {} database {}'
+                             .format(env.db_engine, db_name))
+        else:
+            self.logger.info('successfully opened {} database {}'
+                             .format(env.db_engine, db_name))
+        self.init_state()
         self.tx_count = self.db_tx_count
         self.height = self.db_height
         self.tip = self.db_tip
@@ -313,40 +313,29 @@ class BlockProcessor(LoggedClass):
         return self.fs_cache.block_hashes(start, count)

-    def open_db(self, coin, db_engine):
-        db_name = '{}-{}'.format(coin.NAME, coin.NET)
-        db_engine_class = {
-            "leveldb": LevelDB,
-            "rocksdb": RocksDB,
-            "lmdb": LMDB
-        }[db_engine.lower()]
-        try:
-            db = db_engine_class(db_name, create_if_missing=False,
-                                 error_if_exists=False, compression=None)
-        except NoDatabaseException:
-            db = db_engine_class(db_name, create_if_missing=True,
-                                 error_if_exists=True, compression=None)
-            self.logger.info('created new {} database {}'.format(db_engine, db_name))
-        else:
-            self.logger.info('successfully opened {} database {}'.format(db_engine, db_name))
-            self.read_state(db)
-
-        return db
-
-    def read_state(self, db):
-        state = db.get(b'state')
-        state = ast.literal_eval(state.decode())
-        if state['genesis'] != self.coin.GENESIS_HASH:
-            raise ChainError('DB genesis hash {} does not match coin {}'
-                             .format(state['genesis_hash'],
-                                     self.coin.GENESIS_HASH))
-        self.db_height = state['height']
-        self.db_tx_count = state['tx_count']
-        self.db_tip = state['tip']
-        self.flush_count = state['flush_count']
-        self.utxo_flush_count = state['utxo_flush_count']
-        self.wall_time = state['wall_time']
-        self.first_sync = state.get('first_sync', True)
+    def init_state(self):
+        if self.db.is_new:
+            self.db_height = -1
+            self.db_tx_count = 0
+            self.db_tip = b'\0' * 32
+            self.flush_count = 0
+            self.utxo_flush_count = 0
+            self.wall_time = 0
+            self.first_sync = True
+        else:
+            state = self.db.get(b'state')
+            state = ast.literal_eval(state.decode())
+            if state['genesis'] != self.coin.GENESIS_HASH:
+                raise ChainError('DB genesis hash {} does not match coin {}'
+                                 .format(state['genesis_hash'],
+                                         self.coin.GENESIS_HASH))
+            self.db_height = state['height']
+            self.db_tx_count = state['tx_count']
+            self.db_tip = state['tip']
+            self.flush_count = state['flush_count']
+            self.utxo_flush_count = state['utxo_flush_count']
+            self.wall_time = state['wall_time']
+            self.first_sync = state.get('first_sync', True)

     def clean_db(self):
         '''Clean out stale DB items.
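
init_state() now branches on db.is_new instead of catching a missing-database
exception: a fresh DB gets genesis defaults, while an existing one is restored from
the b'state' record, a Python dict literal parsed with ast.literal_eval().  A rough
sketch of that round trip; the writer side is assumed to serialise the dict with
repr(), and the genesis value is truncated for illustration:

    import ast

    state = {
        'genesis': '000000000019d6...',   # truncated, illustrative only
        'height': -1,
        'tx_count': 0,
        'tip': b'\0' * 32,
        'flush_count': 0,
        'utxo_flush_count': 0,
        'wall_time': 0,
        'first_sync': True,
    }

    blob = repr(state).encode()           # what a flush would store
    restored = ast.literal_eval(blob.decode())
    assert restored == state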

server/storage.py (126 changed lines)

@@ -1,43 +1,83 @@
+# Copyright (c) 2016, the ElectrumX authors
+#
+# All rights reserved.
+#
+# See the file "LICENCE" for information about the copyright
+# and warranty status of this software.
+
+'''Backend database abstraction.
+
+The abstraction needs to be improved to not heavily penalise LMDB.
+'''
+
 import os
 from functools import partial

+from lib.util import subclasses
+
+
+def open_db(name, db_engine):
+    '''Returns a database handle.'''
+    for db_class in subclasses(Storage):
+        if db_class.__name__.lower() == db_engine.lower():
+            db_class.import_module()
+            return db_class(name)
+
+    raise RuntimeError('unrecognised DB engine "{}"'.format(db_engine))
+

 class Storage(object):
-    def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None):
-        if not create_if_missing and not os.path.exists(name):
-            raise NoDatabaseException
+    '''Abstract base class of the DB backend abstraction.'''
+
+    def __init__(self, name):
+        self.is_new = not os.path.exists(name)
+        self.open(name, create=self.is_new)
+
+    @classmethod
+    def import_module(cls):
+        '''Import the DB engine module.'''
+        raise NotImplementedError
+
+    def open(self, name, create):
+        '''Open an existing database or create a new one.'''
+        raise NotImplementedError

     def get(self, key):
-        raise NotImplementedError()
+        raise NotImplementedError

     def put(self, key, value):
-        raise NotImplementedError()
+        raise NotImplementedError

     def write_batch(self):
-        """
-        Returns a context manager that provides `put` and `delete`.
-        Changes should only be committed when the context manager closes without an exception.
-        """
-        raise NotImplementedError()
+        '''Return a context manager that provides `put` and `delete`.
+
+        Changes should only be committed when the context manager
+        closes without an exception.
+        '''
+        raise NotImplementedError

     def iterator(self, prefix=b'', reverse=False):
-        """
-        Returns an iterator that yields (key, value) pairs from the database sorted by key.
-        If `prefix` is set, only keys starting with `prefix` will be included.
-        """
-        raise NotImplementedError()
-
-
-class NoDatabaseException(Exception):
-    pass
+        '''Return an iterator that yields (key, value) pairs from the
+        database sorted by key.
+
+        If `prefix` is set, only keys starting with `prefix` will be
+        included.  If `reverse` is True the items are returned in
+        reverse order.
+        '''
+        raise NotImplementedError


 class LevelDB(Storage):
-    def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None):
-        super().__init__(name, create_if_missing, error_if_exists, compression)
+    '''LevelDB database engine.'''
+
+    @classmethod
+    def import_module(cls):
         import plyvel
-        self.db = plyvel.DB(name, create_if_missing=create_if_missing,
-                            error_if_exists=error_if_exists, compression=compression)
+        cls.module = plyvel
+
+    def open(self, name, create):
+        self.db = self.module.DB(name, create_if_missing=create,
+                                 compression=None)
         self.get = self.db.get
         self.put = self.db.put
         self.iterator = self.db.iterator
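
Under the new API a backend is just another Storage subclass in this module: open_db()
matches the engine name against the subclass name, calls its import_module() hook, and
then instantiates it, with Storage.__init__ deciding is_new from the filesystem and
delegating to open().  A hypothetical in-memory engine, purely to illustrate the hook
points (not part of this commit):

    class MemoryDB(Storage):
        '''Toy in-memory engine; selectable as db_engine="memorydb".'''

        @classmethod
        def import_module(cls):
            pass                    # nothing external to import

        def open(self, name, create):
            self.store = {}         # is_new was already set by __init__

        def get(self, key):
            return self.store.get(key)

        def put(self, key, value):
            self.store[key] = value
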
@@ -45,25 +85,28 @@ class LevelDB(Storage):

 class RocksDB(Storage):
-    rocksdb = None
+    '''RocksDB database engine.'''

-    def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None):
-        super().__init__(name, create_if_missing, error_if_exists, compression)
+    @classmethod
+    def import_module(cls):
         import rocksdb
-        RocksDB.rocksdb = rocksdb
-        if not compression:
-            compression = "no"
-        compression = getattr(rocksdb.CompressionType, compression + "_compression")
-        self.db = rocksdb.DB(name, rocksdb.Options(create_if_missing=create_if_missing,
-                                                   compression=compression,
-                                                   target_file_size_base=33554432,
-                                                   max_open_files=1024))
+        cls.module = rocksdb
+
+    def open(self, name, create):
+        compression = "no"
+        compression = getattr(self.module.CompressionType,
+                              compression + "_compression")
+        options = self.module.Options(create_if_missing=create,
+                                      compression=compression,
+                                      target_file_size_base=33554432,
+                                      max_open_files=1024)
+        self.db = self.module.DB(name, options)
         self.get = self.db.get
         self.put = self.db.put

     class WriteBatch(object):
         def __init__(self, db):
-            self.batch = RocksDB.rocksdb.WriteBatch()
+            self.batch = RocksDB.module.WriteBatch()
             self.db = db

         def __enter__(self):
@@ -99,14 +142,17 @@ class RocksDB(Storage):

 class LMDB(Storage):
-    lmdb = None
+    '''LMDB database engine.'''

-    def __init__(self, name, create_if_missing=False, error_if_exists=False, compression=None):
-        super().__init__(name, create_if_missing, error_if_exists, compression)
+    @classmethod
+    def import_module(cls):
         import lmdb
-        LMDB.lmdb = lmdb
-        self.env = lmdb.Environment(".", subdir=True, create=create_if_missing, max_dbs=32, map_size=5 * 10 ** 10)
-        self.db = self.env.open_db(create=create_if_missing)
+        cls.module = lmdb
+
+    def open(self, name, create):
+        self.env = self.module.Environment('.', subdir=True, create=create,
+                                           max_dbs=32, map_size=5 * 10 ** 10)
+        self.db = self.env.open_db(create=create)

     def get(self, key):
         with self.env.begin(db=self.db) as tx:
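
Putting it together, callers now name an engine and a database directory and get a
ready handle back.  A short usage sketch; the engine string would normally come from
the server's environment configuration, and the key/value calls shown are the generic
Storage operations rather than anything specific to one backend:

    from server.storage import open_db

    # Engine name is matched case-insensitively against the Storage
    # subclasses: 'leveldb', 'rocksdb' or 'lmdb'.
    db = open_db('Bitcoin-mainnet', 'leveldb')

    if db.is_new:
        print('new database; starting from genesis')

    db.put(b'k1', b'v1')
    print(db.get(b'k1'))

    with db.write_batch() as batch:
        batch.put(b'k2', b'v2')

    for key, value in db.iterator(prefix=b'k'):
        print(key, value)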
