|
|
@ -2,9 +2,11 @@ |
|
|
|
# Distributed under the MIT software license, see the accompanying |
|
|
|
# file LICENCE or http://www.opensource.org/licenses/mit-license.php |
|
|
|
|
|
|
|
import threading |
|
|
|
from typing import NamedTuple, Iterable, TYPE_CHECKING |
|
|
|
import os |
|
|
|
import queue |
|
|
|
import threading |
|
|
|
import concurrent |
|
|
|
from collections import defaultdict |
|
|
|
import asyncio |
|
|
|
from enum import IntEnum, auto |
|
|
@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum): |
|
|
|
FREE = auto() |
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean |
|
|
|
from sqlalchemy.pool import StaticPool |
|
|
|
from sqlalchemy.orm import sessionmaker |
|
|
|
from sqlalchemy.orm.query import Query |
|
|
|
from sqlalchemy.ext.declarative import declarative_base |
|
|
|
from sqlalchemy.sql import not_, or_ |
|
|
|
from sqlalchemy.orm import scoped_session |
|
|
|
|
|
|
|
Base = declarative_base() |
|
|
|
|
|
|
|
class SweepTx(Base): |
|
|
|
__tablename__ = 'sweep_txs' |
|
|
|
funding_outpoint = Column(String(34)) |
|
|
|
prev_txid = Column(String(32)) |
|
|
|
tx = Column(String()) |
|
|
|
txid = Column(String(32), primary_key=True) # txid of tx |
|
|
|
|
|
|
|
class ChannelInfo(Base): |
|
|
|
__tablename__ = 'channel_info' |
|
|
|
address = Column(String(32), primary_key=True) |
|
|
|
outpoint = Column(String(34)) |
|
|
|
|
|
|
|
|
|
|
|
class SweepStore(PrintError): |
|
|
|
|
|
|
|
def __init__(self, path, network): |
|
|
|
PrintError.__init__(self) |
|
|
|
self.path = path |
|
|
|
self.network = network |
|
|
|
self.db_requests = queue.Queue() |
|
|
|
threading.Thread(target=self.sql_thread).start() |
|
|
|
|
|
|
|
def sql_thread(self): |
|
|
|
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool) |
|
|
|
DBSession = sessionmaker(bind=engine, autoflush=False) |
|
|
|
self.DBSession = DBSession() |
|
|
|
if not os.path.exists(self.path): |
|
|
|
Base.metadata.create_all(engine) |
|
|
|
while self.network.asyncio_loop.is_running(): |
|
|
|
try: |
|
|
|
future, func, args, kwargs = self.db_requests.get(timeout=0.1) |
|
|
|
except queue.Empty: |
|
|
|
continue |
|
|
|
try: |
|
|
|
result = func(self, *args, **kwargs) |
|
|
|
except BaseException as e: |
|
|
|
future.set_exception(e) |
|
|
|
continue |
|
|
|
future.set_result(result) |
|
|
|
# write |
|
|
|
self.DBSession.commit() |
|
|
|
self.print_error("SQL thread terminated") |
|
|
|
|
|
|
|
def sql(func): |
|
|
|
def wrapper(self, *args, **kwargs): |
|
|
|
f = concurrent.futures.Future() |
|
|
|
self.db_requests.put((f, func, args, kwargs)) |
|
|
|
return f.result(timeout=10) |
|
|
|
return wrapper |
|
|
|
|
|
|
|
@sql |
|
|
|
def get_sweep_tx(self, funding_outpoint, prev_txid): |
|
|
|
return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()] |
|
|
|
|
|
|
|
@sql |
|
|
|
def list_sweep_tx(self): |
|
|
|
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all()) |
|
|
|
|
|
|
|
@sql |
|
|
|
def add_sweep_tx(self, funding_outpoint, prev_txid, tx): |
|
|
|
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid())) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def num_sweep_tx(self, funding_outpoint): |
|
|
|
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() |
|
|
|
|
|
|
|
@sql |
|
|
|
def remove_sweep_tx(self, funding_outpoint): |
|
|
|
v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all() |
|
|
|
self.DBSession.delete(v) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def add_channel_info(self, address, outpoint): |
|
|
|
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint)) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def remove_channel_info(self, address): |
|
|
|
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() |
|
|
|
self.DBSession.delete(v) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def has_channel_info(self, address): |
|
|
|
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()) |
|
|
|
|
|
|
|
@sql |
|
|
|
def get_channel_info(self, address): |
|
|
|
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() |
|
|
|
return r.outpoint if r else None |
|
|
|
|
|
|
|
@sql |
|
|
|
def list_channel_info(self): |
|
|
|
return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()] |
|
|
|
|
|
|
|
|
|
|
|
class LNWatcher(AddressSynchronizer): |
|
|
|
verbosity_filter = 'W' |
|
|
|
|
|
|
|
def __init__(self, network: 'Network'): |
|
|
|
path = os.path.join(network.config.path, "watcher_db") |
|
|
|
path = os.path.join(network.config.path, "watchtower_wallet") |
|
|
|
storage = WalletStorage(path) |
|
|
|
AddressSynchronizer.__init__(self, storage) |
|
|
|
self.config = network.config |
|
|
|
self.start_network(network) |
|
|
|
self.lock = threading.RLock() |
|
|
|
self.channel_info = storage.get('channel_info', {}) # access with 'lock' |
|
|
|
# [funding_outpoint_str][prev_txid] -> set of Transaction |
|
|
|
# prev_txid is the txid of a tx that is watched for confirmations |
|
|
|
# access with 'lock' |
|
|
|
self.sweepstore = defaultdict(lambda: defaultdict(set)) |
|
|
|
for funding_outpoint, ctxs in storage.get('sweepstore', {}).items(): |
|
|
|
for txid, set_of_txns in ctxs.items(): |
|
|
|
for tx in set_of_txns: |
|
|
|
tx2 = Transaction.from_dict(tx) |
|
|
|
self.sweepstore[funding_outpoint][txid].add(tx2) |
|
|
|
|
|
|
|
self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network) |
|
|
|
self.network.register_callback(self.on_network_update, |
|
|
|
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated']) |
|
|
|
self.set_remote_watchtower() |
|
|
@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer): |
|
|
|
await asyncio.sleep(5) |
|
|
|
await self.watchtower_queue.put((name, args, kwargs)) |
|
|
|
|
|
|
|
def write_to_disk(self): |
|
|
|
# FIXME: json => every update takes linear instead of constant disk write |
|
|
|
with self.lock: |
|
|
|
storage = self.storage |
|
|
|
storage.put('channel_info', self.channel_info) |
|
|
|
# self.sweepstore |
|
|
|
sweepstore = {} |
|
|
|
for funding_outpoint, ctxs in self.sweepstore.items(): |
|
|
|
sweepstore[funding_outpoint] = {} |
|
|
|
for prev_txid, set_of_txns in ctxs.items(): |
|
|
|
sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns] |
|
|
|
storage.put('sweepstore', sweepstore) |
|
|
|
storage.write() |
|
|
|
|
|
|
|
@with_watchtower |
|
|
|
def watch_channel(self, address, outpoint): |
|
|
|
self.add_address(address) |
|
|
|
with self.lock: |
|
|
|
if address not in self.channel_info: |
|
|
|
self.channel_info[address] = outpoint |
|
|
|
self.write_to_disk() |
|
|
|
if not self.sweepstore.has_channel_info(address): |
|
|
|
self.sweepstore.add_channel_info(address, outpoint) |
|
|
|
|
|
|
|
def unwatch_channel(self, address, funding_outpoint): |
|
|
|
self.print_error('unwatching', funding_outpoint) |
|
|
|
with self.lock: |
|
|
|
self.channel_info.pop(address) |
|
|
|
self.sweepstore.pop(funding_outpoint) |
|
|
|
self.write_to_disk() |
|
|
|
self.sweepstore.remove_sweep_tx(funding_outpoint) |
|
|
|
self.sweepstore.remove_channel_info(address) |
|
|
|
if funding_outpoint in self.tx_progress: |
|
|
|
self.tx_progress[funding_outpoint].all_done.set() |
|
|
|
|
|
|
@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer): |
|
|
|
return |
|
|
|
if not self.synchronizer.is_up_to_date(): |
|
|
|
return |
|
|
|
with self.lock: |
|
|
|
channel_info_items = list(self.channel_info.items()) |
|
|
|
for address, outpoint in channel_info_items: |
|
|
|
for address, outpoint in self.sweepstore.list_channel_info(): |
|
|
|
await self.check_onchain_situation(address, outpoint) |
|
|
|
|
|
|
|
async def check_onchain_situation(self, address, funding_outpoint): |
|
|
@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer): |
|
|
|
if spender is not None: |
|
|
|
continue |
|
|
|
prev_txid, prev_n = prevout.split(':') |
|
|
|
with self.lock: |
|
|
|
sweep_txns = self.sweepstore[funding_outpoint][prev_txid] |
|
|
|
sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid) |
|
|
|
for tx in sweep_txns: |
|
|
|
if not await self.broadcast_or_log(funding_outpoint, tx): |
|
|
|
self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}') |
|
|
@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer): |
|
|
|
@with_watchtower |
|
|
|
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict): |
|
|
|
tx = Transaction.from_dict(tx_dict) |
|
|
|
with self.lock: |
|
|
|
self.sweepstore[funding_outpoint][prev_txid].add(tx) |
|
|
|
self.write_to_disk() |
|
|
|
self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx) |
|
|
|
|
|
|
|
def get_tx_mined_depth(self, txid: str): |
|
|
|
if not txid: |
|
|
|