Browse Source

lnwatcher: save sweepstore in sqlite database

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
b861e2e955
  1. 6
      electrum/gui/qt/watchtower_window.py
  2. 161
      electrum/lnwatcher.py

6
electrum/gui/qt/watchtower_window.py

@ -52,9 +52,11 @@ class WatcherList(MyTreeView):
def update(self): def update(self):
self.model().clear() self.model().clear()
self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')}) self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
for outpoint, sweep_dict in self.parent.lnwatcher.sweepstore.items(): sweepstore = self.parent.lnwatcher.sweepstore
for outpoint in sweepstore.list_sweep_tx():
n = sweepstore.num_sweep_tx(outpoint)
status = self.parent.lnwatcher.get_channel_status(outpoint) status = self.parent.lnwatcher.get_channel_status(outpoint)
items = [QStandardItem(e) for e in [outpoint, "%d"%len(sweep_dict), status]] items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
self.model().insertRow(self.model().rowCount(), items) self.model().insertRow(self.model().rowCount(), items)

161
electrum/lnwatcher.py

@ -2,9 +2,11 @@
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php # file LICENCE or http://www.opensource.org/licenses/mit-license.php
import threading
from typing import NamedTuple, Iterable, TYPE_CHECKING from typing import NamedTuple, Iterable, TYPE_CHECKING
import os import os
import queue
import threading
import concurrent
from collections import defaultdict from collections import defaultdict
import asyncio import asyncio
from enum import IntEnum, auto from enum import IntEnum, auto
@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum):
FREE = auto() FREE = auto()
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
Base = declarative_base()
class SweepTx(Base):
__tablename__ = 'sweep_txs'
funding_outpoint = Column(String(34))
prev_txid = Column(String(32))
tx = Column(String())
txid = Column(String(32), primary_key=True) # txid of tx
class ChannelInfo(Base):
__tablename__ = 'channel_info'
address = Column(String(32), primary_key=True)
outpoint = Column(String(34))
class SweepStore(PrintError):
def __init__(self, path, network):
PrintError.__init__(self)
self.path = path
self.network = network
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
@sql
def get_sweep_tx(self, funding_outpoint, prev_txid):
return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]
@sql
def list_sweep_tx(self):
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
@sql
def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid()))
self.DBSession.commit()
@sql
def num_sweep_tx(self, funding_outpoint):
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
@sql
def remove_sweep_tx(self, funding_outpoint):
v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
self.DBSession.delete(v)
self.DBSession.commit()
@sql
def add_channel_info(self, address, outpoint):
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
self.DBSession.commit()
@sql
def remove_channel_info(self, address):
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
self.DBSession.delete(v)
self.DBSession.commit()
@sql
def has_channel_info(self, address):
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())
@sql
def get_channel_info(self, address):
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
return r.outpoint if r else None
@sql
def list_channel_info(self):
return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
class LNWatcher(AddressSynchronizer): class LNWatcher(AddressSynchronizer):
verbosity_filter = 'W' verbosity_filter = 'W'
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
path = os.path.join(network.config.path, "watcher_db") path = os.path.join(network.config.path, "watchtower_wallet")
storage = WalletStorage(path) storage = WalletStorage(path)
AddressSynchronizer.__init__(self, storage) AddressSynchronizer.__init__(self, storage)
self.config = network.config self.config = network.config
self.start_network(network) self.start_network(network)
self.lock = threading.RLock() self.lock = threading.RLock()
self.channel_info = storage.get('channel_info', {}) # access with 'lock' self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network)
# [funding_outpoint_str][prev_txid] -> set of Transaction
# prev_txid is the txid of a tx that is watched for confirmations
# access with 'lock'
self.sweepstore = defaultdict(lambda: defaultdict(set))
for funding_outpoint, ctxs in storage.get('sweepstore', {}).items():
for txid, set_of_txns in ctxs.items():
for tx in set_of_txns:
tx2 = Transaction.from_dict(tx)
self.sweepstore[funding_outpoint][txid].add(tx2)
self.network.register_callback(self.on_network_update, self.network.register_callback(self.on_network_update,
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated']) ['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
self.set_remote_watchtower() self.set_remote_watchtower()
@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer):
await asyncio.sleep(5) await asyncio.sleep(5)
await self.watchtower_queue.put((name, args, kwargs)) await self.watchtower_queue.put((name, args, kwargs))
def write_to_disk(self):
# FIXME: json => every update takes linear instead of constant disk write
with self.lock:
storage = self.storage
storage.put('channel_info', self.channel_info)
# self.sweepstore
sweepstore = {}
for funding_outpoint, ctxs in self.sweepstore.items():
sweepstore[funding_outpoint] = {}
for prev_txid, set_of_txns in ctxs.items():
sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns]
storage.put('sweepstore', sweepstore)
storage.write()
@with_watchtower @with_watchtower
def watch_channel(self, address, outpoint): def watch_channel(self, address, outpoint):
self.add_address(address) self.add_address(address)
with self.lock: with self.lock:
if address not in self.channel_info: if not self.sweepstore.has_channel_info(address):
self.channel_info[address] = outpoint self.sweepstore.add_channel_info(address, outpoint)
self.write_to_disk()
def unwatch_channel(self, address, funding_outpoint): def unwatch_channel(self, address, funding_outpoint):
self.print_error('unwatching', funding_outpoint) self.print_error('unwatching', funding_outpoint)
with self.lock: self.sweepstore.remove_sweep_tx(funding_outpoint)
self.channel_info.pop(address) self.sweepstore.remove_channel_info(address)
self.sweepstore.pop(funding_outpoint)
self.write_to_disk()
if funding_outpoint in self.tx_progress: if funding_outpoint in self.tx_progress:
self.tx_progress[funding_outpoint].all_done.set() self.tx_progress[funding_outpoint].all_done.set()
@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer):
return return
if not self.synchronizer.is_up_to_date(): if not self.synchronizer.is_up_to_date():
return return
with self.lock: for address, outpoint in self.sweepstore.list_channel_info():
channel_info_items = list(self.channel_info.items())
for address, outpoint in channel_info_items:
await self.check_onchain_situation(address, outpoint) await self.check_onchain_situation(address, outpoint)
async def check_onchain_situation(self, address, funding_outpoint): async def check_onchain_situation(self, address, funding_outpoint):
@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer):
if spender is not None: if spender is not None:
continue continue
prev_txid, prev_n = prevout.split(':') prev_txid, prev_n = prevout.split(':')
with self.lock: sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid)
sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
for tx in sweep_txns: for tx in sweep_txns:
if not await self.broadcast_or_log(funding_outpoint, tx): if not await self.broadcast_or_log(funding_outpoint, tx):
self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}') self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer):
@with_watchtower @with_watchtower
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict): def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
tx = Transaction.from_dict(tx_dict) tx = Transaction.from_dict(tx_dict)
with self.lock: self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
self.sweepstore[funding_outpoint][prev_txid].add(tx)
self.write_to_disk()
def get_tx_mined_depth(self, txid: str): def get_tx_mined_depth(self, txid: str):
if not txid: if not txid:

Loading…
Cancel
Save