From b861e2e955c4a790d8e2b4ce262b894a67c3b470 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Tue, 5 Mar 2019 17:28:24 +0100 Subject: [PATCH] lnwatcher: save sweepstore in sqlite database --- electrum/gui/qt/watchtower_window.py | 6 +- electrum/lnwatcher.py | 161 ++++++++++++++++++++------- 2 files changed, 124 insertions(+), 43 deletions(-) diff --git a/electrum/gui/qt/watchtower_window.py b/electrum/gui/qt/watchtower_window.py index fe24c47bb..6967e6f91 100644 --- a/electrum/gui/qt/watchtower_window.py +++ b/electrum/gui/qt/watchtower_window.py @@ -52,9 +52,11 @@ class WatcherList(MyTreeView): def update(self): self.model().clear() self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')}) - for outpoint, sweep_dict in self.parent.lnwatcher.sweepstore.items(): + sweepstore = self.parent.lnwatcher.sweepstore + for outpoint in sweepstore.list_sweep_tx(): + n = sweepstore.num_sweep_tx(outpoint) status = self.parent.lnwatcher.get_channel_status(outpoint) - items = [QStandardItem(e) for e in [outpoint, "%d"%len(sweep_dict), status]] + items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]] self.model().insertRow(self.model().rowCount(), items) diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index 96f60ba13..968bba0eb 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -2,9 +2,11 @@ # Distributed under the MIT software license, see the accompanying # file LICENCE or http://www.opensource.org/licenses/mit-license.php -import threading from typing import NamedTuple, Iterable, TYPE_CHECKING import os +import queue +import threading +import concurrent from collections import defaultdict import asyncio from enum import IntEnum, auto @@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum): FREE = auto() +from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean +from sqlalchemy.pool import StaticPool +from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm.query import Query +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import not_, or_ +from sqlalchemy.orm import scoped_session + +Base = declarative_base() + +class SweepTx(Base): + __tablename__ = 'sweep_txs' + funding_outpoint = Column(String(34)) + prev_txid = Column(String(32)) + tx = Column(String()) + txid = Column(String(32), primary_key=True) # txid of tx + +class ChannelInfo(Base): + __tablename__ = 'channel_info' + address = Column(String(32), primary_key=True) + outpoint = Column(String(34)) + + +class SweepStore(PrintError): + + def __init__(self, path, network): + PrintError.__init__(self) + self.path = path + self.network = network + self.db_requests = queue.Queue() + threading.Thread(target=self.sql_thread).start() + + def sql_thread(self): + engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool) + DBSession = sessionmaker(bind=engine, autoflush=False) + self.DBSession = DBSession() + if not os.path.exists(self.path): + Base.metadata.create_all(engine) + while self.network.asyncio_loop.is_running(): + try: + future, func, args, kwargs = self.db_requests.get(timeout=0.1) + except queue.Empty: + continue + try: + result = func(self, *args, **kwargs) + except BaseException as e: + future.set_exception(e) + continue + future.set_result(result) + # write + self.DBSession.commit() + self.print_error("SQL thread terminated") + + def sql(func): + def wrapper(self, *args, **kwargs): + f = concurrent.futures.Future() + self.db_requests.put((f, func, args, kwargs)) + return f.result(timeout=10) + return wrapper + + @sql + def get_sweep_tx(self, funding_outpoint, prev_txid): + return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()] + + @sql + def list_sweep_tx(self): + return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all()) + + @sql + def add_sweep_tx(self, funding_outpoint, prev_txid, tx): + self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid())) + self.DBSession.commit() + + @sql + def num_sweep_tx(self, funding_outpoint): + return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() + + @sql + def remove_sweep_tx(self, funding_outpoint): + v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all() + self.DBSession.delete(v) + self.DBSession.commit() + + @sql + def add_channel_info(self, address, outpoint): + self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint)) + self.DBSession.commit() + + @sql + def remove_channel_info(self, address): + v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() + self.DBSession.delete(v) + self.DBSession.commit() + + @sql + def has_channel_info(self, address): + return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()) + + @sql + def get_channel_info(self, address): + r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() + return r.outpoint if r else None + + @sql + def list_channel_info(self): + return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()] + + class LNWatcher(AddressSynchronizer): verbosity_filter = 'W' def __init__(self, network: 'Network'): - path = os.path.join(network.config.path, "watcher_db") + path = os.path.join(network.config.path, "watchtower_wallet") storage = WalletStorage(path) AddressSynchronizer.__init__(self, storage) self.config = network.config self.start_network(network) self.lock = threading.RLock() - self.channel_info = storage.get('channel_info', {}) # access with 'lock' - # [funding_outpoint_str][prev_txid] -> set of Transaction - # prev_txid is the txid of a tx that is watched for confirmations - # access with 'lock' - self.sweepstore = defaultdict(lambda: defaultdict(set)) - for funding_outpoint, ctxs in storage.get('sweepstore', {}).items(): - for txid, set_of_txns in ctxs.items(): - for tx in set_of_txns: - tx2 = Transaction.from_dict(tx) - self.sweepstore[funding_outpoint][txid].add(tx2) - + self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network) self.network.register_callback(self.on_network_update, ['network_updated', 'blockchain_updated', 'verified', 'wallet_updated']) self.set_remote_watchtower() @@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer): await asyncio.sleep(5) await self.watchtower_queue.put((name, args, kwargs)) - def write_to_disk(self): - # FIXME: json => every update takes linear instead of constant disk write - with self.lock: - storage = self.storage - storage.put('channel_info', self.channel_info) - # self.sweepstore - sweepstore = {} - for funding_outpoint, ctxs in self.sweepstore.items(): - sweepstore[funding_outpoint] = {} - for prev_txid, set_of_txns in ctxs.items(): - sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns] - storage.put('sweepstore', sweepstore) - storage.write() @with_watchtower def watch_channel(self, address, outpoint): self.add_address(address) with self.lock: - if address not in self.channel_info: - self.channel_info[address] = outpoint - self.write_to_disk() + if not self.sweepstore.has_channel_info(address): + self.sweepstore.add_channel_info(address, outpoint) def unwatch_channel(self, address, funding_outpoint): self.print_error('unwatching', funding_outpoint) - with self.lock: - self.channel_info.pop(address) - self.sweepstore.pop(funding_outpoint) - self.write_to_disk() + self.sweepstore.remove_sweep_tx(funding_outpoint) + self.sweepstore.remove_channel_info(address) if funding_outpoint in self.tx_progress: self.tx_progress[funding_outpoint].all_done.set() @@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer): return if not self.synchronizer.is_up_to_date(): return - with self.lock: - channel_info_items = list(self.channel_info.items()) - for address, outpoint in channel_info_items: + for address, outpoint in self.sweepstore.list_channel_info(): await self.check_onchain_situation(address, outpoint) async def check_onchain_situation(self, address, funding_outpoint): @@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer): if spender is not None: continue prev_txid, prev_n = prevout.split(':') - with self.lock: - sweep_txns = self.sweepstore[funding_outpoint][prev_txid] + sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid) for tx in sweep_txns: if not await self.broadcast_or_log(funding_outpoint, tx): self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}') @@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer): @with_watchtower def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict): tx = Transaction.from_dict(tx_dict) - with self.lock: - self.sweepstore[funding_outpoint][prev_txid].add(tx) - self.write_to_disk() + self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx) def get_tx_mined_depth(self, txid: str): if not txid: