diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py
index 054a337ab..2b54b395d 100644
--- a/electrum/address_synchronizer.py
+++ b/electrum/address_synchronizer.py
@@ -168,7 +168,7 @@ class AddressSynchronizer(PrintError):
 
     def add_address(self, address):
         if not self.db.get_addr_history(address):
-            self.db.history[address] = []
+            #self.db.history[address] = []
             self.set_up_to_date(False)
         if self.synchronizer:
             self.synchronizer.add(address)
diff --git a/electrum/json_db.py b/electrum/json_db.py
index 84b3d6bdf..b6015fb40 100644
--- a/electrum/json_db.py
+++ b/electrum/json_db.py
@@ -115,9 +115,6 @@ class JsonDB(PrintError):
             return True
         return False
 
-    def commit(self):
-        pass
-
     @locked
     def dump(self):
         return json.dumps(self.data, indent=4, sort_keys=True, cls=JsonDBJsonEncoder)
diff --git a/electrum/sqlite_db.py b/electrum/sqlite_db.py
new file mode 100644
index 000000000..42097d3b1
--- /dev/null
+++ b/electrum/sqlite_db.py
@@ -0,0 +1,325 @@
+import os
+import sys
+import json
+from collections import defaultdict
+from operator import itemgetter
+import concurrent.futures
+import threading
+import queue
+import traceback
+
+from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
+from sqlalchemy.pool import StaticPool
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm.query import Query
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.sql import not_, or_
+from sqlalchemy.orm import scoped_session
+
+from .util import PrintError, standardize_path, TxMinedInfo
+from . import transaction
+
+Base = declarative_base()
+session_factory = sessionmaker()
+
+
+
+class Misc(Base):
+    __tablename__ = 'misc'
+    key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
+    value = Column(String())
+
+class Txi(Base):
+    __tablename__ = 'txi'
+    id = Column(Integer, primary_key=True)
+    txid = Column(String())
+    address = Column(String())
+    ser = Column(String())
+    v = Column(Integer())
+
+class Txo(Base):
+    __tablename__ = 'txo'
+    id = Column(Integer, primary_key=True)
+    txid = Column(String())
+    address = Column(String())
+    n = Column(Integer())
+    v = Column(Integer())
+    is_coinbase = Column(Boolean())
+
+class Transaction(Base):
+    __tablename__ = 'transaction'
+    txid = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
+    raw = Column(String())
+
+class Spent(Base):
+    __tablename__ = 'spent_outpoints'
+    id = Column(Integer, primary_key=True)
+    prev_txid = Column(String(66))
+    prev_n = Column(Integer())
+    ser = Column(String())
+
+class ImportedAddress(Base):
+    __tablename__ = 'imported_addresses'
+    id = Column(Integer, primary_key=True)
+    address = Column(String(66))
+
+class ReceivingAddress(Base):
+    __tablename__ = 'receiving_addresses'
+    id = Column(Integer, primary_key=True)
+    address = Column(String(66))
+
+class ChangeAddress(Base):
+    __tablename__ = 'change_addresses'
+    id = Column(Integer, primary_key=True)
+    address = Column(String(66))
+
+class AddressHistory(Base):
+    __tablename__ = 'address_history'
+    id = Column(Integer, primary_key=True)
+    address = Column(String(66))
+    txid = Column(String(66))
+    height = Column(Integer)
+
+class Keystore(Base):
+    __tablename__ = 'keystore'
+    name = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
+    type = Column(String(10))
+    pw_hash_version = Column(Integer())
+    seed = Column(String())
+    xpub = Column(String())
+    xprv = Column(String())
+
+class VerifiedTx(Base):
+    __tablename__ = 'verified_tx'
+    txid = Column(String(), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
+    height = Column(Integer)
+    timestamp = Column(Integer)
+    txpos = Column(Integer)
+    header_hash = Column(String())
+
+
+class SqliteDB(PrintError):
+
+    def __init__(self, raw, manual_upgrades=False):
+        # create db file in memory
+        self.tmp_path = '/dev/shm/xx'
+        self._modified = True
+        with open(self.tmp_path, 'wb') as f:
+            f.write(raw)
+        self.exists = bool(raw)
+        self.requests = queue.Queue()
+        threading.Thread(target=self.sql_thread, daemon=True).start()
+
+    def sql_thread(self):
+        # single thread that owns the engine/session; all DB access is
+        # serialized through self.requests
+        self.engine = create_engine('sqlite:///' + self.tmp_path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
+        self.DBSession = scoped_session(session_factory)
+        self.DBSession.configure(bind=self.engine, autoflush=False)
+        self.DBSession.remove()
+        if not self.exists:
+            Base.metadata.create_all(self.engine)
+        while True:
+            future, func, args, kwargs = self.requests.get()
+            try:
+                result = func(self, *args, **kwargs)
+            except BaseException as e:
+                future.set_exception(e)
+                continue
+            future.set_result(result)
+
+    def sql(func):
+        # decorator: enqueue the call for sql_thread and block on its result
+        def wrapper(self, *args, **kwargs):
+            f = concurrent.futures.Future()
+            self.requests.put((f, func, args, kwargs))
+            return f.result(timeout=3)
+        return wrapper
+
+    def requires_split(self):
+        return False
+
+    def requires_upgrade(self):
+        return False
+
+    def set_modified(self, b):
+        self._modified = b
+
+    def modified(self):
+        return self._modified
+
+    def dump(self):
+        with open(self.tmp_path, 'rb') as f:
+            return f.read()
+
+    def load_addresses(self, x):
+        pass
+
+    @sql
+    def put(self, key, v):
+        if key == 'keystore':
+            k = Keystore(name=key, type=v['type'], pw_hash_version=v['pw_hash_version'], seed=v.get('seed'), xpub=v.get('xpub'), xprv=v.get('xprv'))
+        else:
+            k = Misc(key=key, value=json.dumps(v))
+        self.DBSession.add(k)
+        self.DBSession.commit()
+
+    @sql
+    def get(self, key, default=None):
+        if key == 'keystore':
+            r = self.DBSession.query(Keystore).filter(Keystore.name == key).one_or_none()
+            if r:
+                r = {'type':r.type, 'pw_hash_version':r.pw_hash_version, 'seed':r.seed, 'xpub':r.xpub, 'xprv':r.xprv}
+        else:
+            r = self.DBSession.query(Misc).filter(Misc.key == key).one_or_none()
+            if r:
+                r = json.loads(r.value)
+        return r or default
+
+    @sql
+    def add_receiving_address(self, address):
+        self.DBSession.add(ReceivingAddress(address=address))
+        self.DBSession.commit()
+
+    @sql
+    def add_change_address(self, address):
+        self.DBSession.add(ChangeAddress(address=address))
+        self.DBSession.commit()
+
+    @sql
+    def get_receiving_addresses(self):
+        return list(map(itemgetter(0), self.DBSession.query(ReceivingAddress.address).all()))
+
+    @sql
+    def get_change_addresses(self):
+        return list(map(itemgetter(0), self.DBSession.query(ChangeAddress.address).all()))
+
+    @sql
+    def num_receiving_addresses(self):
+        return self.DBSession.query(ReceivingAddress).count()
+
+    @sql
+    def num_change_addresses(self):
+        return self.DBSession.query(ChangeAddress).count()
+
+    @sql
+    def get_address_index(self, address):
+        r = self.DBSession.query(ReceivingAddress).filter(ReceivingAddress.address==address).one_or_none()
+        if r:
+            return False, r.id - 1
+        r = self.DBSession.query(ChangeAddress).filter(ChangeAddress.address==address).one_or_none()
+        if r:
+            return True, r.id - 1
+        r = self.DBSession.query(ImportedAddress).filter(ImportedAddress.address==address).one_or_none()
+        if r:
+            return r.pubkey
+
+    @sql
+    def add_txi_addr(self, txid, address, ser, v):
+        if self.DBSession.query(Txi).filter(Txi.txid==txid, Txi.address==address, Txi.ser==ser, Txi.v==v).one_or_none():
+            return
+        self.DBSession.add(Txi(txid=txid, address=address, ser=ser, v=v))
+        self.DBSession.commit()
+
+    @sql
+    def add_txo_addr(self, txid, address, n, v, is_coinbase):
+        if self.DBSession.query(Txo).filter(Txo.txid==txid, Txo.address==address, Txo.n==n, Txo.v==v, Txo.is_coinbase==is_coinbase).one_or_none():
+            return
+        self.DBSession.add(Txo(txid=txid, address=address, n=n, v=v, is_coinbase=is_coinbase))
+        self.DBSession.commit()
+
+    @sql
+    def get_txi(self, txid):
+        return set(map(itemgetter(0), self.DBSession.query(Txi.address).filter(Txi.txid==txid).all()))
+
+    @sql
+    def get_txo(self, txid):
+        return set(map(itemgetter(0), self.DBSession.query(Txo.address).filter(Txo.txid==txid).all()))
+
+    @sql
+    def get_txi_addr(self, txid, address):
+        return [(r.ser, r.v) for r in self.DBSession.query(Txi).filter(Txi.txid==txid, Txi.address==address).all()]
+
+    @sql
+    def get_txo_addr(self, txid, address):
+        return [(r.n, r.v, r.is_coinbase) for r in self.DBSession.query(Txo).filter(Txo.txid==txid, Txo.address==address).all()]
+
+    @sql
+    def list_txi(self):
+        return set(map(itemgetter(0), self.DBSession.query(Txi.txid).all()))
+
+    @sql
+    def list_txo(self):
+        return set(map(itemgetter(0), self.DBSession.query(Txo.txid).all()))
+
+    @sql
+    def get_history(self):
+        return set(map(itemgetter(0), self.DBSession.query(AddressHistory.address).all()))
+
+    @sql
+    def get_addr_history(self, address):
+        return [(x.txid, x.height) for x in self.DBSession.query(AddressHistory).filter(AddressHistory.address==address).all()]
+
+    @sql
+    def set_addr_history(self, address, hist):
+        for txid, height in hist:
+            self.DBSession.add(AddressHistory(address=address, txid=txid, height=height))
+        self.DBSession.commit()
+
+    @sql
+    def is_in_verified_tx(self, txid):
+        return bool(self.DBSession.query(VerifiedTx).filter(VerifiedTx.txid==txid).one_or_none())
+
+    @sql
+    def add_verified_tx(self, txid, info):
+        self.DBSession.add(VerifiedTx(txid=txid, height=info.height, timestamp=info.timestamp, txpos=info.txpos, header_hash=info.header_hash))
+        self.DBSession.commit()
+
+    @sql
+    def get_verified_tx(self, txid):
+        r = self.DBSession.query(VerifiedTx).filter(VerifiedTx.txid==txid).one_or_none()
+        if not r:
+            return
+        return TxMinedInfo(height=r.height,
+                           conf=None,
+                           timestamp=r.timestamp,
+                           txpos=r.txpos,
+                           header_hash=r.header_hash)
+
+    @sql
+    def remove_verified_tx(self, txid):
+        v = self.DBSession.query(VerifiedTx).filter(VerifiedTx.txid==txid).one_or_none()
+        self.DBSession.delete(v)
+        self.DBSession.commit()
+
+    @sql
+    def get_spent_outpoints(self, prev_txid):
+        return [r.prev_n for r in self.DBSession.query(Spent).filter(Spent.prev_txid==prev_txid).all()]
+
+    @sql
+    def get_spent_outpoint(self, prev_txid, prev_n):
+        r = self.DBSession.query(Spent).filter(Spent.prev_txid==prev_txid, Spent.prev_n==prev_n).one_or_none()
+        return r.ser if r else None
+
+    @sql
+    def set_spent_outpoint(self, prev_txid, prev_n, ser):
+        r = self.DBSession.query(Spent).filter(Spent.prev_txid==prev_txid, Spent.prev_n==prev_n).one_or_none()
+        if r:
+            return
+        self.DBSession.add(Spent(prev_txid=prev_txid, prev_n=prev_n, ser=ser))
+        self.DBSession.commit()
+
+    @sql
+    def get_transaction(self, txid):
+        r = self.DBSession.query(Transaction).filter(Transaction.txid==txid).one_or_none()
+        return transaction.Transaction(r.raw) if r else None
+
+    @sql
+    def add_transaction(self, txid, tx):
+        self.DBSession.add(Transaction(txid=txid, raw=str(tx)))
+        self.DBSession.commit()
+
+    @sql
+    def update_tx_fees(self, tx_fees):
+        pass
+
+    @sql
+    def get_tx_fee(self, tx_hash):
+        pass
diff --git a/electrum/storage.py b/electrum/storage.py
index 9aef8a119..161068c39 100644
--- a/electrum/storage.py
+++ b/electrum/storage.py
@@ -33,6 +33,7 @@ from . import ecc
 from .util import PrintError, profiler, InvalidPassword, WalletFileException, bfh, standardize_path
 from .plugin import run_hook, plugin_loaders
+from .sqlite_db import SqliteDB
 from .json_db import JsonDB
 
 
 
@@ -48,16 +49,15 @@ STO_EV_PLAINTEXT, STO_EV_USER_PW, STO_EV_XPUB_PW = range(0, 3)
 
 class WalletStorage(PrintError):
 
-    def __init__(self, path, *, manual_upgrades=False):
+    def __init__(self, path, *, manual_upgrades=False, use_sql=False):
         self.lock = threading.RLock()
         self.path = standardize_path(path)
         self._file_exists = self.path and os.path.exists(self.path)
-
-        DB_Class = JsonDB
+        DB_Class = SqliteDB# if use_sql else JsonDB
         self.print_error("wallet path", self.path)
         self.pubkey = None
         if self.file_exists():
-            with open(self.path, "r", encoding='utf-8') as f:
+            with open(self.path, "rb") as f:
                 self.raw = f.read()
             self._encryption_version = self._init_encryption_version()
             if not self.is_encrypted():
@@ -66,7 +66,7 @@ class WalletStorage(PrintError):
         else:
             self._encryption_version = STO_EV_PLAINTEXT
             # avoid new wallets getting 'upgraded'
-            self.db = DB_Class('', manual_upgrades=False)
+            self.db = DB_Class(b'', manual_upgrades=False)
 
     def load_plugins(self):
@@ -91,10 +91,10 @@
             return
         if not self.db.modified():
             return
-        self.db.commit()
+
         s = self.encrypt_before_writing(self.db.dump())
         temp_path = "%s.tmp.%s" % (self.path, os.getpid())
-        with open(temp_path, "w", encoding='utf-8') as f:
+        with open(temp_path, "wb") as f:
             f.write(s)
             f.flush()
             os.fsync(f.fileno())
@@ -219,8 +219,8 @@
         self.db.set_modified(True)
 
     def requires_upgrade(self):
-        if not self.is_past_initial_decryption():
-            raise Exception("storage not yet decrypted!")
+        #if not self.is_past_initial_decryption():
+        #    raise Exception("storage not yet decrypted!")
         return self.db.requires_upgrade()
 
     def upgrade(self):
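
Note (not part of the patch): the least obvious piece of sqlite_db.py above is the request/future pattern behind the @sql decorator. Every decorated method is pushed onto self.requests and executed by the single sql_thread that owns the engine and session; the caller blocks on a concurrent.futures.Future until that thread stores the result. The sketch below is a minimal, standalone illustration of that same pattern using only the standard-library sqlite3 module; the names WorkerDB, _call and the kv table are invented for the example and do not exist in Electrum.

import queue
import sqlite3
import threading
import concurrent.futures


class WorkerDB:
    """Illustrative only: all SQLite access happens on one dedicated thread."""

    def __init__(self, path=':memory:'):
        self._path = path
        self.requests = queue.Queue()
        threading.Thread(target=self._thread, daemon=True).start()

    def _thread(self):
        # The connection is created and used only on this thread, which is
        # what the sqlite3 module expects by default.
        conn = sqlite3.connect(self._path)
        conn.execute('CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)')
        while True:
            future, func, args = self.requests.get()
            try:
                future.set_result(func(conn, *args))
            except BaseException as e:
                future.set_exception(e)

    def _call(self, func, *args):
        # Same marshalling idea as the @sql decorator: enqueue the work,
        # then block the caller until the worker thread sets the result.
        f = concurrent.futures.Future()
        self.requests.put((f, func, args))
        return f.result(timeout=3)

    def put(self, k, v):
        def work(conn, k, v):
            conn.execute('INSERT OR REPLACE INTO kv VALUES (?, ?)', (k, v))
            conn.commit()
        return self._call(work, k, v)

    def get(self, k):
        def work(conn, k):
            row = conn.execute('SELECT v FROM kv WHERE k = ?', (k,)).fetchone()
            return row[0] if row else None
        return self._call(work, k)


if __name__ == '__main__':
    db = WorkerDB()
    db.put('seed_version', '18')
    print(db.get('seed_version'))  # -> 18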