Browse Source

storage_db: fix tests, add modified flag to db class

sqlite_db
ThomasV 6 years ago
parent
commit
d74f0c0947
  1. 25
      electrum/address_synchronizer.py
  2. 43
      electrum/json_db.py
  3. 18
      electrum/storage.py
  4. 2
      electrum/tests/test_wallet_vertical.py
  5. 9
      electrum/wallet.py

25
electrum/address_synchronizer.py

@ -282,27 +282,20 @@ class AddressSynchronizer(PrintError):
def remove_transaction(self, tx_hash):
def remove_from_spent_outpoints():
# undo spends in spent_outpoints
if tx is not None: # if we have the tx, this branch is faster
if tx is not None:
# if we have the tx, this branch is faster
for txin in tx.inputs():
if txin['type'] == 'coinbase':
continue
prevout_hash = txin['prevout_hash']
prevout_n = txin['prevout_n']
self.spent_outpoints[prevout_hash].pop(prevout_n, None) # FIXME
if not self.spent_outpoints[prevout_hash]:
self.spent_outpoints.pop(prevout_hash)
else: # expensive but always works
for prevout_hash, d in list(self.spent_outpoints.items()):
for prevout_n, spending_txid in d.items():
if spending_txid == tx_hash:
self.spent_outpoints[prevout_hash].pop(prevout_n, None)
if not self.spent_outpoints[prevout_hash]:
self.spent_outpoints.pop(prevout_hash)
# Remove this tx itself; if nothing spends from it.
# It is not so clear what to do if other txns spend from it, but it will be
# removed when those other txns are removed.
if not self.spent_outpoints[tx_hash]:
self.spent_outpoints.pop(tx_hash)
self.db.remove_spent_outpoint(prevout_hash, prevout_n)
else:
# expensive but always works
for prevout_hash, prevout_n in list(self.db.list_spent_outpoints()):
spending_txid = self.db.get_spent_outpoint(prevout_hash, prevout_n)
if spending_txid == tx_hash:
self.db.remove_spent_outpoint(prevout_hash, prevout_n)
with self.transaction_lock:
self.print_error("removing tx from history", tx_hash)

43
electrum/json_db.py

@ -26,6 +26,7 @@ import os
import ast
import json
import copy
import threading
from collections import defaultdict
from typing import Dict
@ -45,7 +46,9 @@ FINAL_SEED_VERSION = 18 # electrum >= 2.7 will set this to prevent
class JsonDB(PrintError):
def __init__(self, raw, *, manual_upgrades):
self.lock = threading.RLock()
self.data = {}
self._modified = False
self.manual_upgrades = manual_upgrades
if raw:
self.load_data(raw)
@ -53,6 +56,20 @@ class JsonDB(PrintError):
self.put('seed_version', FINAL_SEED_VERSION)
self.load_transactions()
def set_modified(self, b):
with self.lock:
self._modified = b
def modified(self):
return self._modified
def modifier(func):
def wrapper(self, *args, **kwargs):
with self.lock:
self._modified = True
return func(self, *args, **kwargs)
return wrapper
def get(self, key, default=None):
v = self.data.get(key)
if v is None:
@ -61,6 +78,7 @@ class JsonDB(PrintError):
v = copy.deepcopy(v)
return v
@modifier
def put(self, key, value):
try:
json.dumps(key, cls=util.MyEncoder)
@ -483,6 +501,7 @@ class JsonDB(PrintError):
def get_txo_addr(self, tx_hash, address):
return self.txo.get(tx_hash, {}).get(address, [])
@modifier
def add_txi_addr(self, tx_hash, addr, ser, v):
if tx_hash not in self.txi:
self.txi[tx_hash] = {}
@ -492,6 +511,7 @@ class JsonDB(PrintError):
d[addr] = set()
d[addr].add((ser, v))
@modifier
def add_txo_addr(self, tx_hash, addr, n, v, is_coinbase):
if tx_hash not in self.txo:
self.txo[tx_hash] = {}
@ -507,26 +527,43 @@ class JsonDB(PrintError):
def get_txo_keys(self):
return self.txo.keys()
@modifier
def remove_txi(self, tx_hash):
self.txi.pop(tx_hash, None)
@modifier
def remove_txo(self, tx_hash):
self.txo.pop(tx_hash, None)
def list_spent_outpoints(self):
return [(h, n)
for h in self.spent_outpoints.keys()
for n in self.get_spent_outpoints(h)
]
def get_spent_outpoints(self, prevout_hash):
return self.spent_outpoints.get(prevout_hash, {}).keys()
def get_spent_outpoint(self, prevout_hash, prevout_n):
return self.spent_outpoints.get(prevout_hash, {}).get(str(prevout_n))
@modifier
def remove_spent_outpoint(self, prevout_hash, prevout_n):
self.spent_outpoints[prevout_hash].pop(prevout_n, None) # FIXME
if not self.spent_outpoints[prevout_hash]:
self.spent_outpoints.pop(prevout_hash)
@modifier
def set_spent_outpoint(self, prevout_hash, prevout_n, tx_hash):
if prevout_hash not in self.spent_outpoints:
self.spent_outpoints[prevout_hash] = {}
self.spent_outpoints[prevout_hash][str(prevout_n)] = tx_hash
@modifier
def add_transaction(self, tx_hash, tx):
self.transactions[tx_hash] = str(tx)
@modifier
def remove_transaction(self, tx_hash):
self.transactions.pop(tx_hash, None)
@ -543,9 +580,11 @@ class JsonDB(PrintError):
def get_addr_history(self, addr):
return self.history.get(addr, [])
@modifier
def set_addr_history(self, addr, hist):
self.history[addr] = hist
@modifier
def remove_addr_history(self, addr):
self.history.pop(addr, None)
@ -562,18 +601,22 @@ class JsonDB(PrintError):
txpos=txpos,
header_hash=header_hash)
@modifier
def add_verified_tx(self, txid, info):
self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash)
@modifier
def remove_verified_tx(self, txid):
self.verified_tx.pop(txid, None)
@modifier
def update_tx_fees(self, d):
return self.tx_fees.update(d)
def get_tx_fee(self, txid):
return self.tx_fees.get(txid)
@modifier
def remove_tx_fee(self, txid):
self.tx_fees.pop(txid, None)

18
electrum/storage.py

@ -49,10 +49,9 @@ STO_EV_PLAINTEXT, STO_EV_USER_PW, STO_EV_XPUB_PW = range(0, 3)
class WalletStorage(PrintError):
def __init__(self, path, *, manual_upgrades=False):
self.db_lock = threading.RLock()
self.lock = threading.RLock()
self.path = standardize_path(path)
self._file_exists = self.path and os.path.exists(self.path)
self.modified = False
DB_Class = JsonDB
self.path = path
@ -70,23 +69,21 @@ class WalletStorage(PrintError):
self.db = DB_Class('', manual_upgrades=False)
def put(self, key,value):
with self.db_lock:
self.modified |= self.db.put(key, value)
self.db.put(key, value)
def get(self, key, default=None):
with self.db_lock:
return self.db.get(key, default)
return self.db.get(key, default)
@profiler
def write(self):
with self.db_lock:
with self.lock:
self._write()
def _write(self):
if threading.currentThread().isDaemon():
self.print_error('warning: daemon thread cannot write db')
return
if not self.modified:
if not self.db.modified():
return
self.db.commit()
s = self.encrypt_before_writing(self.db.dump())
@ -103,7 +100,7 @@ class WalletStorage(PrintError):
os.chmod(self.path, mode)
self._file_exists = True
self.print_error("saved", self.path)
self.modified = False
self.db.set_modified(False)
def file_exists(self):
return self._file_exists
@ -209,8 +206,7 @@ class WalletStorage(PrintError):
self.pubkey = None
self._encryption_version = STO_EV_PLAINTEXT
# make sure next storage.write() saves changes
with self.db_lock:
self.modified = True
self.db.set_modified(True)
def requires_upgrade(self):
return self.db.requires_upgrade()

2
electrum/tests/test_wallet_vertical.py

@ -1719,7 +1719,7 @@ class TestWalletHistory_EvilGapLimit(TestCaseForTestnet):
w.storage.put('stored_height', 1316917 + 100)
for txid in self.transactions:
tx = Transaction(self.transactions[txid])
w.transactions[tx.txid()] = tx
w.add_transaction(tx.txid(), tx)
# txn A is an external incoming txn paying to addr (3) and (15)
# txn B is an external incoming txn paying to addr (4) and (25)
# txn C is an internal transfer txn from addr (25) -- to -- (1) and (25)

9
electrum/wallet.py

@ -1201,7 +1201,6 @@ class Abstract_Wallet(AddressSynchronizer):
self._update_password_for_keystore(old_pw, new_pw)
encrypt_keystore = self.can_have_keystore_encryption()
self.storage.set_keystore_encryption(bool(new_pw) and encrypt_keystore)
self.storage.write()
def sign_message(self, address, message, password):
@ -1385,7 +1384,6 @@ class Imported_Wallet(Simple_Wallet):
self.addresses[address] = {}
self.add_address(address)
self.save_addresses()
self.save_transactions(write=write_to_disk)
return good_addr, bad_addr
def import_address(self, address: str) -> str:
@ -1398,7 +1396,6 @@ class Imported_Wallet(Simple_Wallet):
def delete_address(self, address):
if address not in self.addresses:
return
transactions_to_remove = set() # only referred to by this address
transactions_new = set() # txs that are not only referred to by address
with self.lock:
@ -1412,20 +1409,15 @@ class Imported_Wallet(Simple_Wallet):
transactions_new.add(tx_hash)
transactions_to_remove -= transactions_new
self.db.remove_history(address)
for tx_hash in transactions_to_remove:
self.remove_transaction(tx_hash)
self.db.remove_tx_fee(tx_hash)
self.db.remove_verified_tx(tx_hash)
self.unverified_tx.pop(tx_hash, None)
self.db.remove_transaction(tx_hash)
self.save_verified_tx()
self.save_transactions()
self.set_label(address, None)
self.remove_payment_request(address, {})
self.set_frozen_state([address], False)
pubkey = self.get_public_key(address)
self.addresses.pop(address)
if pubkey:
@ -1442,7 +1434,6 @@ class Imported_Wallet(Simple_Wallet):
self.keystore.delete_imported_key(pubkey)
self.save_keystore()
self.save_addresses()
self.storage.write()
def get_address_index(self, address):

Loading…
Cancel
Save