# Copyright (c) 2016, the ElectrumX authors
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.

'''Backend database abstraction.

The abstraction needs to be improved to not heavily penalise LMDB.
'''

import os
from functools import partial

from lib.util import subclasses


def open_db(name, db_engine):
    '''Returns a database handle.

    `db_engine` is compared case-insensitively with the class names of the
    Storage subclasses.  The matching engine's module is imported on demand
    and an instance opened on `name` is returned.

    Raises RuntimeError if no Storage subclass matches `db_engine`.
    '''
    for db_class in subclasses(Storage):
        if db_class.__name__.lower() == db_engine.lower():
            db_class.import_module()
            return db_class(name)
    raise RuntimeError('unrecognised DB engine "{}"'.format(db_engine))


def _increment_byte_string(bs):
    '''Return the smallest byte string greater than every string that starts
    with `bs`, or None if there is no such string (`bs` is empty or consists
    only of 0xff bytes).

    This is the exclusive upper bound of the key range covered by prefix
    `bs`; reverse iterators seek to it and step back one key.
    '''
    stripped = bs.rstrip(b'\xff')
    if not stripped:
        return None
    return stripped[:-1] + bytes([stripped[-1] + 1])


class Storage(object):
    '''Abstract base class of the DB backend abstraction.'''

    def __init__(self, name):
        # A database is "new" if nothing exists at `name` yet; only then is
        # the backend asked to create it.
        self.is_new = not os.path.exists(name)
        self.open(name, create=self.is_new)

    @classmethod
    def import_module(cls):
        '''Import the DB engine module.'''
        raise NotImplementedError

    def open(self, name, create):
        '''Open an existing database or create a new one.'''
        raise NotImplementedError

    def get(self, key):
        '''Return the value stored under `key`.'''
        raise NotImplementedError

    def put(self, key, value):
        '''Store `value` under `key`.'''
        raise NotImplementedError

    def write_batch(self):
        '''Return a context manager that provides `put` and `delete`.

        Changes should only be committed when the context manager
        closes without an exception.
        '''
        raise NotImplementedError

    def iterator(self, prefix=b'', reverse=False):
        '''Return an iterator that yields (key, value) pairs from the
        database sorted by key.

        If `prefix` is set, only keys starting with `prefix` will be
        included.  If `reverse` is True the items are returned in
        reverse order.
        '''
        raise NotImplementedError


class LevelDB(Storage):
    '''LevelDB database engine.'''

    @classmethod
    def import_module(cls):
        import plyvel
        cls.module = plyvel

    def open(self, name, create):
        self.db = self.module.DB(name, create_if_missing=create,
                                 compression=None)
        # plyvel's own methods already satisfy the Storage interface, so they
        # are bound directly onto the instance.
        self.get = self.db.get
        self.put = self.db.put
        self.iterator = self.db.iterator
        # transaction=True: plyvel only writes the batch if the `with` block
        # exits without an exception.
        self.write_batch = partial(self.db.write_batch, transaction=True)


class RocksDB(Storage):
    '''RocksDB database engine.'''

    @classmethod
    def import_module(cls):
        import rocksdb
        cls.module = rocksdb

    def open(self, name, create):
        compression = self.module.CompressionType.no_compression
        options = self.module.Options(create_if_missing=create,
                                      compression=compression,
                                      target_file_size_base=33554432,
                                      max_open_files=1024)
        self.db = self.module.DB(name, options)
        self.get = self.db.get
        self.put = self.db.put

    class WriteBatch(object):
        '''Context manager yielding a rocksdb WriteBatch; the batch is
        written to the database only if the block exits without an
        exception.'''

        def __init__(self, db):
            self.batch = RocksDB.module.WriteBatch()
            self.db = db

        def __enter__(self):
            return self.batch

        def __exit__(self, exc_type, exc_val, exc_tb):
            if exc_type is None:
                self.db.write(self.batch)

    def write_batch(self):
        return RocksDB.WriteBatch(self.db)

    class Iterator(object):
        '''Yield (key, value) pairs whose keys start with `prefix`, in
        ascending key order, or descending if `reverse` is True.'''

        def __init__(self, db, prefix, reverse):
            self.prefix = prefix
            self.reverse = reverse
            self.it = db.iteritems()
            if reverse:
                self.it = reversed(self.it)

        def __iter__(self):
            self._seek()
            return self

        def _seek(self):
            '''Position the underlying iterator on the first item to yield.'''
            if not self.reverse:
                # First key >= prefix is the smallest key with that prefix.
                self.it.seek(self.prefix)
                return
            upper = _increment_byte_string(self.prefix)
            if upper is None:
                # Every key from the end of the database down is a candidate.
                self.it.seek_to_last()
                return
            self.it.seek(upper)
            try:
                # seek() lands on the first key >= upper, which lies outside
                # the prefix range; consume it so iteration continues from
                # the key just below it.
                next(self.it)
            except StopIteration:
                # No key >= upper exists, so the last key is the candidate.
                self.it.seek_to_last()

        def __next__(self):
            k, v = next(self.it)
            if not k.startswith(self.prefix):
                # We've moved past the keys with this prefix
                raise StopIteration
            return k, v

    def iterator(self, prefix=b'', reverse=False):
        return RocksDB.Iterator(self.db, prefix, reverse)


class LMDB(Storage):
    '''LMDB database engine.'''

    @classmethod
    def import_module(cls):
        import lmdb
        cls.module = lmdb

    def open(self, name, create):
        # subdir=True: `name` is a directory holding the LMDB files.
        self.env = self.module.Environment(name, subdir=True, create=create,
                                           max_dbs=32, map_size=5 * 10 ** 10)
        self.db = self.env.open_db(create=create)

    def get(self, key):
        with self.env.begin(db=self.db) as tx:
            return tx.get(key)

    def put(self, key, value):
        with self.env.begin(db=self.db, write=True) as tx:
            tx.put(key, value)

    def write_batch(self):
        # An lmdb write Transaction used as a context manager commits on a
        # clean exit and aborts if an exception escapes.
        return self.env.begin(db=self.db, write=True)

    def iterator(self, prefix=b'', reverse=False):
        return LMDB.Iterator(self.db, self.env, prefix, reverse)

    class Iterator(object):
        '''Yield (key, value) pairs whose keys start with `prefix`, in
        ascending key order, or descending if `reverse` is True.

        A read-only transaction is begun at construction and released
        once iteration finishes (or the iterator is garbage collected
        after iteration started).
        '''

        def __init__(self, db, env, prefix, reverse):
            self.transaction = env.begin(db=db)
            self.db = db
            self.prefix = prefix
            self.reverse = reverse
            self._items = None

        def __iter__(self):
            return self

        def __next__(self):
            if self._items is None:
                self._items = self._generate()
            return next(self._items)

        def _position(self, cursor):
            '''Move `cursor` onto the first item to yield; return False if
            the prefix range is empty.'''
            if not self.reverse:
                return cursor.set_range(self.prefix)
            upper = _increment_byte_string(self.prefix)
            if upper is not None and cursor.set_range(upper):
                # Cursor sits on the first key past the prefix range.
                return cursor.prev()
            return cursor.last()

        def _generate(self):
            '''Generate matching items, then release cursor and transaction.'''
            cursor = self.transaction.cursor(db=self.db)
            try:
                # An unpositioned cursor would make iternext/iterprev restart
                # from the first/last key, so bail out when nothing matches.
                if self._position(cursor):
                    items = (cursor.iterprev() if self.reverse
                             else cursor.iternext())
                    for key, value in items:
                        if not key.startswith(self.prefix):
                            # We've moved past the keys with this prefix
                            break
                        yield key, value
            finally:
                cursor.close()
                self.transaction.abort()