|
|
@ -46,15 +46,15 @@ Base = declarative_base() |
|
|
|
|
|
|
|
class SweepTx(Base): |
|
|
|
__tablename__ = 'sweep_txs' |
|
|
|
funding_outpoint = Column(String(34)) |
|
|
|
funding_outpoint = Column(String(34), primary_key=True) |
|
|
|
index = Column(Integer(), primary_key=True) |
|
|
|
prev_txid = Column(String(32)) |
|
|
|
tx = Column(String()) |
|
|
|
txid = Column(String(32), primary_key=True) # txid of tx |
|
|
|
|
|
|
|
class ChannelInfo(Base): |
|
|
|
__tablename__ = 'channel_info' |
|
|
|
address = Column(String(32), primary_key=True) |
|
|
|
outpoint = Column(String(34)) |
|
|
|
outpoint = Column(String(34), primary_key=True) |
|
|
|
address = Column(String(32)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -67,17 +67,23 @@ class SweepStore(SqlDB): |
|
|
|
def get_sweep_tx(self, funding_outpoint, prev_txid): |
|
|
|
return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()] |
|
|
|
|
|
|
|
@sql |
|
|
|
def get_tx_by_index(self, funding_outpoint, index): |
|
|
|
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none() |
|
|
|
return r.prev_txid, bh2u(r.tx) |
|
|
|
|
|
|
|
@sql |
|
|
|
def list_sweep_tx(self): |
|
|
|
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all()) |
|
|
|
|
|
|
|
@sql |
|
|
|
def add_sweep_tx(self, funding_outpoint, prev_txid, tx): |
|
|
|
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=bfh(str(tx)), txid=tx.txid())) |
|
|
|
n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() |
|
|
|
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prev_txid=prev_txid, tx=bfh(tx))) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def num_sweep_tx(self, funding_outpoint): |
|
|
|
def get_num_tx(self, funding_outpoint): |
|
|
|
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() |
|
|
|
|
|
|
|
@sql |
|
|
@ -87,24 +93,24 @@ class SweepStore(SqlDB): |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def add_channel_info(self, address, outpoint): |
|
|
|
def add_channel(self, outpoint, address): |
|
|
|
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint)) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def remove_channel_info(self, address): |
|
|
|
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() |
|
|
|
def remove_channel(self, outpoint): |
|
|
|
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() |
|
|
|
self.DBSession.delete(v) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def has_channel_info(self, address): |
|
|
|
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()) |
|
|
|
def has_channel(self, outpoint): |
|
|
|
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()) |
|
|
|
|
|
|
|
@sql |
|
|
|
def get_channel_info(self, address): |
|
|
|
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() |
|
|
|
return r.outpoint if r else None |
|
|
|
def get_address(self, outpoint): |
|
|
|
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() |
|
|
|
return r.address if r else None |
|
|
|
|
|
|
|
@sql |
|
|
|
def list_channel_info(self): |
|
|
@ -139,42 +145,46 @@ class LNWatcher(AddressSynchronizer): |
|
|
|
self.watchtower = jsonrpclib.Server(watchtower_url) if watchtower_url else None |
|
|
|
self.watchtower_queue = asyncio.Queue() |
|
|
|
|
|
|
|
def with_watchtower(func): |
|
|
|
def wrapper(self, *args, **kwargs): |
|
|
|
if self.watchtower: |
|
|
|
self.watchtower_queue.put_nowait((func.__name__, args, kwargs)) |
|
|
|
return func(self, *args, **kwargs) |
|
|
|
return wrapper |
|
|
|
def get_num_tx(self, outpoint): |
|
|
|
return self.sweepstore.get_num_tx(outpoint) |
|
|
|
|
|
|
|
@ignore_exceptions |
|
|
|
@log_exceptions |
|
|
|
async def watchtower_task(self): |
|
|
|
self.print_error('watchtower task started') |
|
|
|
# initial check |
|
|
|
for address, outpoint in self.sweepstore.list_channel_info(): |
|
|
|
await self.watchtower_queue.put(outpoint) |
|
|
|
while True: |
|
|
|
name, args, kwargs = await self.watchtower_queue.get() |
|
|
|
outpoint = await self.watchtower_queue.get() |
|
|
|
if self.watchtower is None: |
|
|
|
continue |
|
|
|
func = getattr(self.watchtower, name) |
|
|
|
# synchronize with remote |
|
|
|
try: |
|
|
|
r = func(*args, **kwargs) |
|
|
|
self.print_error("watchtower answer", r) |
|
|
|
except: |
|
|
|
self.print_error('could not reach watchtower, will retry in 5s', name, args) |
|
|
|
local_n = self.sweepstore.get_num_tx(outpoint) |
|
|
|
n = self.watchtower.get_num_tx(outpoint) |
|
|
|
if n == 0: |
|
|
|
address = self.sweepstore.get_address(outpoint) |
|
|
|
self.watchtower.add_channel(outpoint, address) |
|
|
|
self.print_error("sending %d transactions to watchtower"%(local_n - n)) |
|
|
|
for index in range(n, local_n): |
|
|
|
prev_txid, tx = self.sweepstore.get_tx_by_index(outpoint, index) |
|
|
|
self.watchtower.add_sweep_tx(outpoint, prev_txid, tx) |
|
|
|
except ConnectionRefusedError: |
|
|
|
self.print_error('could not reach watchtower, will retry in 5s') |
|
|
|
await asyncio.sleep(5) |
|
|
|
await self.watchtower_queue.put((name, args, kwargs)) |
|
|
|
|
|
|
|
await self.watchtower_queue.put(outpoint) |
|
|
|
|
|
|
|
@with_watchtower |
|
|
|
def watch_channel(self, address, outpoint): |
|
|
|
def add_channel(self, outpoint, address): |
|
|
|
self.add_address(address) |
|
|
|
with self.lock: |
|
|
|
if not self.sweepstore.has_channel_info(address): |
|
|
|
self.sweepstore.add_channel_info(address, outpoint) |
|
|
|
if not self.sweepstore.has_channel(outpoint): |
|
|
|
self.sweepstore.add_channel(outpoint, address) |
|
|
|
|
|
|
|
def unwatch_channel(self, address, funding_outpoint): |
|
|
|
self.print_error('unwatching', funding_outpoint) |
|
|
|
self.sweepstore.remove_sweep_tx(funding_outpoint) |
|
|
|
self.sweepstore.remove_channel_info(address) |
|
|
|
self.sweepstore.remove_channel_info(funding_outpoint) |
|
|
|
if funding_outpoint in self.tx_progress: |
|
|
|
self.tx_progress[funding_outpoint].all_done.set() |
|
|
|
|
|
|
@ -259,10 +269,10 @@ class LNWatcher(AddressSynchronizer): |
|
|
|
await self.tx_progress[funding_outpoint].tx_queue.put(tx) |
|
|
|
return txid |
|
|
|
|
|
|
|
@with_watchtower |
|
|
|
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict): |
|
|
|
tx = Transaction.from_dict(tx_dict) |
|
|
|
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx: str): |
|
|
|
self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx) |
|
|
|
if self.watchtower: |
|
|
|
self.watchtower_queue.put_nowait(funding_outpoint) |
|
|
|
|
|
|
|
def get_tx_mined_depth(self, txid: str): |
|
|
|
if not txid: |
|
|
|