Browse Source

lnwatcher: save sweepstore in sqlite database

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
b861e2e955
  1. 6
      electrum/gui/qt/watchtower_window.py
  2. 161
      electrum/lnwatcher.py

6
electrum/gui/qt/watchtower_window.py

@ -52,9 +52,11 @@ class WatcherList(MyTreeView):
def update(self):
self.model().clear()
self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
for outpoint, sweep_dict in self.parent.lnwatcher.sweepstore.items():
sweepstore = self.parent.lnwatcher.sweepstore
for outpoint in sweepstore.list_sweep_tx():
n = sweepstore.num_sweep_tx(outpoint)
status = self.parent.lnwatcher.get_channel_status(outpoint)
items = [QStandardItem(e) for e in [outpoint, "%d"%len(sweep_dict), status]]
items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
self.model().insertRow(self.model().rowCount(), items)

161
electrum/lnwatcher.py

@ -2,9 +2,11 @@
# Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
import threading
from typing import NamedTuple, Iterable, TYPE_CHECKING
import os
import queue
import threading
import concurrent
from collections import defaultdict
import asyncio
from enum import IntEnum, auto
@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum):
FREE = auto()
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
Base = declarative_base()
class SweepTx(Base):
__tablename__ = 'sweep_txs'
funding_outpoint = Column(String(34))
prev_txid = Column(String(32))
tx = Column(String())
txid = Column(String(32), primary_key=True) # txid of tx
class ChannelInfo(Base):
__tablename__ = 'channel_info'
address = Column(String(32), primary_key=True)
outpoint = Column(String(34))
class SweepStore(PrintError):
def __init__(self, path, network):
PrintError.__init__(self)
self.path = path
self.network = network
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
@sql
def get_sweep_tx(self, funding_outpoint, prev_txid):
return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]
@sql
def list_sweep_tx(self):
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
@sql
def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid()))
self.DBSession.commit()
@sql
def num_sweep_tx(self, funding_outpoint):
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
@sql
def remove_sweep_tx(self, funding_outpoint):
v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
self.DBSession.delete(v)
self.DBSession.commit()
@sql
def add_channel_info(self, address, outpoint):
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
self.DBSession.commit()
@sql
def remove_channel_info(self, address):
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
self.DBSession.delete(v)
self.DBSession.commit()
@sql
def has_channel_info(self, address):
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())
@sql
def get_channel_info(self, address):
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
return r.outpoint if r else None
@sql
def list_channel_info(self):
return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
class LNWatcher(AddressSynchronizer):
verbosity_filter = 'W'
def __init__(self, network: 'Network'):
path = os.path.join(network.config.path, "watcher_db")
path = os.path.join(network.config.path, "watchtower_wallet")
storage = WalletStorage(path)
AddressSynchronizer.__init__(self, storage)
self.config = network.config
self.start_network(network)
self.lock = threading.RLock()
self.channel_info = storage.get('channel_info', {}) # access with 'lock'
# [funding_outpoint_str][prev_txid] -> set of Transaction
# prev_txid is the txid of a tx that is watched for confirmations
# access with 'lock'
self.sweepstore = defaultdict(lambda: defaultdict(set))
for funding_outpoint, ctxs in storage.get('sweepstore', {}).items():
for txid, set_of_txns in ctxs.items():
for tx in set_of_txns:
tx2 = Transaction.from_dict(tx)
self.sweepstore[funding_outpoint][txid].add(tx2)
self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network)
self.network.register_callback(self.on_network_update,
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
self.set_remote_watchtower()
@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer):
await asyncio.sleep(5)
await self.watchtower_queue.put((name, args, kwargs))
def write_to_disk(self):
# FIXME: json => every update takes linear instead of constant disk write
with self.lock:
storage = self.storage
storage.put('channel_info', self.channel_info)
# self.sweepstore
sweepstore = {}
for funding_outpoint, ctxs in self.sweepstore.items():
sweepstore[funding_outpoint] = {}
for prev_txid, set_of_txns in ctxs.items():
sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns]
storage.put('sweepstore', sweepstore)
storage.write()
@with_watchtower
def watch_channel(self, address, outpoint):
self.add_address(address)
with self.lock:
if address not in self.channel_info:
self.channel_info[address] = outpoint
self.write_to_disk()
if not self.sweepstore.has_channel_info(address):
self.sweepstore.add_channel_info(address, outpoint)
def unwatch_channel(self, address, funding_outpoint):
self.print_error('unwatching', funding_outpoint)
with self.lock:
self.channel_info.pop(address)
self.sweepstore.pop(funding_outpoint)
self.write_to_disk()
self.sweepstore.remove_sweep_tx(funding_outpoint)
self.sweepstore.remove_channel_info(address)
if funding_outpoint in self.tx_progress:
self.tx_progress[funding_outpoint].all_done.set()
@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer):
return
if not self.synchronizer.is_up_to_date():
return
with self.lock:
channel_info_items = list(self.channel_info.items())
for address, outpoint in channel_info_items:
for address, outpoint in self.sweepstore.list_channel_info():
await self.check_onchain_situation(address, outpoint)
async def check_onchain_situation(self, address, funding_outpoint):
@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer):
if spender is not None:
continue
prev_txid, prev_n = prevout.split(':')
with self.lock:
sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid)
for tx in sweep_txns:
if not await self.broadcast_or_log(funding_outpoint, tx):
self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer):
@with_watchtower
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
tx = Transaction.from_dict(tx_dict)
with self.lock:
self.sweepstore[funding_outpoint][prev_txid].add(tx)
self.write_to_disk()
self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
def get_tx_mined_depth(self, txid: str):
if not txid:

Loading…
Cancel
Save