
LNWatcher refactoring:

- do not store non-breach transactions
- send 'channel_open' and 'channel_closed' events
 - force-closed channels are handled by LNWorker
ThomasV 6 years ago
commit 729ddb8ec3
  1. electrum/lnchan.py (22 lines changed)
  2. electrum/lnwatcher.py (141 lines changed)
  3. electrum/lnworker.py (76 lines changed)

electrum/lnchan.py (22 lines changed)

@@ -43,8 +43,7 @@ from .lnutil import HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT
 from .lnutil import funding_output_script, LOCAL, REMOTE, HTLCOwner, make_closing_tx, make_commitment_outputs
 from .lnutil import ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script
 from .transaction import Transaction
-from .lnsweep import (create_sweeptxs_for_our_latest_ctx, create_sweeptxs_for_their_latest_ctx,
-                      create_sweeptxs_for_their_just_revoked_ctx)
+from .lnsweep import create_sweeptxs_for_their_just_revoked_ctx

 class ChannelJsonEncoder(json.JSONEncoder):
@@ -204,10 +203,10 @@ class Channel(PrintError):
         for sub in (LOCAL, REMOTE):
             self.log[sub].locked_in.update(self.log[sub].adds.keys())

+        # used in lnworker.on_channel_closed
+        self.local_commitment = self.current_commitment(LOCAL)
+        self.remote_commitment = self.current_commitment(REMOTE)

     def set_state(self, state: str):
         if self._state == 'FORCE_CLOSING':
             assert state == 'FORCE_CLOSING', 'new state was not FORCE_CLOSING: ' + state
@@ -325,7 +324,7 @@ class Channel(PrintError):
         htlcsigs.sort()
         htlcsigs = [x[1] for x in htlcsigs]

-        self.process_new_offchain_ctx(pending_remote_commitment, ours=False)
+        self.remote_commitment = self.pending_commitment(REMOTE)

         # we can't know if this message arrives.
         # since we shouldn't actually throw away
@@ -390,7 +389,7 @@ class Channel(PrintError):
         if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]:
             self.pending_fee[FUNDER_SIGNED] = True

-        self.process_new_offchain_ctx(pending_local_commitment, ours=True)
+        self.local_commitment = self.pending_commitment(LOCAL)

     def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool) -> int:
         _, this_point, _ = self.points()
@@ -454,19 +453,6 @@ class Channel(PrintError):
         next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big'))
         return last_secret, this_point, next_point

-    # TODO don't presign txns for non-breach close
-    def process_new_offchain_ctx(self, ctx: 'Transaction', ours: bool):
-        if not self.lnwatcher:
-            return
-        outpoint = self.funding_outpoint.to_str()
-        if ours:
-            encumbered_sweeptxs = create_sweeptxs_for_our_latest_ctx(self, ctx, self.sweep_address)
-        else:
-            encumbered_sweeptxs = create_sweeptxs_for_their_latest_ctx(self, ctx, self.sweep_address)
-        for prev_txid, encumbered_tx in encumbered_sweeptxs:
-            if encumbered_tx is not None:
-                self.lnwatcher.add_sweep_tx(outpoint, prev_txid, encumbered_tx.to_json())

     def process_new_revocation_secret(self, per_commitment_secret: bytes):
         if not self.lnwatcher:
             return
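With presigning removed from Channel, the commitments cached in __init__ (local_commitment / remote_commitment) are what later lets LNWorker decide whose commitment transaction hit the chain. A minimal sketch of that classification, mirroring the txid comparison done in LNWorker.on_channel_closed below (the helper name classify_close is hypothetical, not part of this commit):

    def classify_close(chan, closing_txid: str) -> str:
        # chan.local_commitment / chan.remote_commitment are the cached ctxs
        if closing_txid == chan.local_commitment.txid():
            return 'we force closed'          # our latest commitment hit the chain
        if closing_txid == chan.remote_commitment.txid():
            return 'they force closed'        # their latest commitment hit the chain
        return 'cooperative close or breach'  # spending tx matches neither latest ctx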

electrum/lnwatcher.py (141 lines changed)

@@ -112,6 +112,15 @@ class LNWatcher(AddressSynchronizer):
             self.channel_info[address] = outpoint
         self.write_to_disk()

+    def unwatch_channel(self, address, funding_outpoint):
+        self.print_error('unwatching', funding_outpoint)
+        with self.lock:
+            self.channel_info.pop(address)
+            self.sweepstore.pop(funding_outpoint)
+        self.write_to_disk()
+        if funding_outpoint in self.tx_progress:
+            self.tx_progress[funding_outpoint].all_done.set()
+
     @log_exceptions
     async def on_network_update(self, event, *args):
         if event in ('verified', 'wallet_updated'):
@@ -125,90 +134,54 @@ class LNWatcher(AddressSynchronizer):
         with self.lock:
             channel_info_items = list(self.channel_info.items())
         for address, outpoint in channel_info_items:
-            await self.check_onchain_situation(outpoint)
+            await self.check_onchain_situation(address, outpoint)

-    async def check_onchain_situation(self, funding_outpoint):
-        txid, index = funding_outpoint.split(':')
-        ctx_candidate_txid = self.spent_outpoints[txid].get(int(index))
-        is_spent = ctx_candidate_txid is not None
-        self.network.trigger_callback('channel_txo', funding_outpoint, is_spent)
-        if not is_spent:
-            return
-        ctx_candidate = self.transactions.get(ctx_candidate_txid)
-        if ctx_candidate is None:
-            return
-        #self.print_error("funding outpoint {} is spent by {}"
-        #                 .format(funding_outpoint, ctx_candidate_txid))
-        conf = self.get_tx_height(ctx_candidate_txid).conf
-        # only care about confirmed and verified ctxs. TODO is this necessary?
-        if conf == 0:
-            return
-        keep_watching_this = await self.inspect_tx_candidate(funding_outpoint, ctx_candidate)
-        if not keep_watching_this:
-            self.stop_and_delete(funding_outpoint)
-
-    def stop_and_delete(self, funding_outpoint):
-        if funding_outpoint in self.tx_progress:
-            self.tx_progress[funding_outpoint].all_done.set()
-        # TODO delete channel from watcher_db
-
-    async def inspect_tx_candidate(self, funding_outpoint, prev_tx):
-        """Returns True iff found any not-deeply-spent outputs that we could
-        potentially sweep at some point."""
-        # make sure we are subscribed to all outputs of tx
-        not_yet_watching = False
-        for o in prev_tx.outputs():
+    async def check_onchain_situation(self, address, funding_outpoint):
+        keep_watching, spenders = self.inspect_tx_candidate(funding_outpoint, 0)
+        txid = spenders.get(funding_outpoint)
+        if txid is None:
+            self.network.trigger_callback('channel_open', funding_outpoint)
+        else:
+            self.network.trigger_callback('channel_closed', funding_outpoint, txid, spenders)
+            await self.do_breach_remedy(funding_outpoint, spenders)
+        if not keep_watching:
+            self.unwatch_channel(address, funding_outpoint)
+        else:
+            self.print_error('we will keep_watching', funding_outpoint)
+
+    def inspect_tx_candidate(self, outpoint, n):
+        # FIXME: instead of stopping recursion at n == 2,
+        # we should detect which outputs are HTLCs
+        prev_txid, index = outpoint.split(':')
+        txid = self.spent_outpoints[prev_txid].get(int(index))
+        result = {outpoint:txid}
+        if txid is None:
+            self.print_error('keep watching because outpoint is unspent')
+            return True, result
+        keep_watching = (self.get_tx_mined_depth(txid) != TxMinedDepth.DEEP)
+        if keep_watching:
+            self.print_error('keep watching because spending tx is not deep')
+        tx = self.transactions[txid]
+        for i, o in enumerate(tx.outputs()):
             if o.address not in self.get_addresses():
                 self.add_address(o.address)
-                not_yet_watching = True
-        if not_yet_watching:
-            self.print_error('prev_tx', prev_tx, 'not yet watching')
-            return True
-        # get all possible responses we have
-        prev_txid = prev_tx.txid()
-        with self.lock:
-            encumbered_sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
-        if len(encumbered_sweep_txns) == 0:
-            if self.get_tx_mined_depth(prev_txid) == TxMinedDepth.DEEP:
-                self.print_error('have no follow-up transactions and prevtx', prev_txid, 'mined deep, returning')
-                return False
-            return True
-        # check if any response applies
-        keep_watching_this = False
-        local_height = self.network.get_local_height()
-        self.print_error(funding_outpoint, 'iterating over encumbered txs')
-        for e_tx in list(encumbered_sweep_txns):
-            conflicts = self.get_conflicting_transactions(e_tx.tx.txid(), e_tx.tx, include_self=True)
-            conflict_mined_depth = self.get_deepest_tx_mined_depth_for_txids(conflicts)
-            if conflict_mined_depth != TxMinedDepth.DEEP:
-                keep_watching_this = True
-            if conflict_mined_depth == TxMinedDepth.FREE:
-                tx_height = self.get_tx_height(prev_txid).height
-                if tx_height == TX_HEIGHT_LOCAL:
-                    continue
-                num_conf = local_height - tx_height + 1
-                broadcast = True
-                if e_tx.cltv_expiry:
-                    if local_height > e_tx.cltv_expiry:
-                        self.print_error(e_tx.name, 'CLTV ({} > {}) fulfilled'.format(local_height, e_tx.cltv_expiry))
-                    else:
-                        self.print_error(e_tx.name, 'waiting for {}: CLTV ({} > {}), funding outpoint {} and tx {}'
-                                         .format(e_tx.name, local_height, e_tx.cltv_expiry, funding_outpoint[:8], prev_tx.txid()[:8]))
-                        broadcast = False
-                if e_tx.csv_delay:
-                    if num_conf < e_tx.csv_delay:
-                        self.print_error(e_tx.name, 'waiting for {}: CSV ({} >= {}), funding outpoint {} and tx {}'
-                                         .format(e_tx.name, num_conf, e_tx.csv_delay, funding_outpoint[:8], prev_tx.txid()[:8]))
-                        broadcast = False
-                if broadcast:
-                    if not await self.broadcast_or_log(funding_outpoint, e_tx):
-                        self.print_error(e_tx.name, f'could not publish encumbered tx: {str(e_tx)}, prev_txid: {prev_txid}, prev_tx height:', tx_height, 'local_height', local_height)
-            else:
-                self.print_error(e_tx.name, 'status', conflict_mined_depth, 'recursing...')
-                # mined or in mempool
-                keep_watching_this |= await self.inspect_tx_candidate(funding_outpoint, e_tx.tx)
-        return keep_watching_this
+                keep_watching = True
+            elif n < 2:
+                k, r = self.inspect_tx_candidate(txid+':%d'%i, n+1)
+                keep_watching |= k
+                result.update(r)
+        return keep_watching, result
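The new inspect_tx_candidate replaces the old encumbered-tx recursion with a plain walk of the spend graph, at most two levels below the funding outpoint, and returns a flat map from each visited outpoint to the txid that spends it (or None). A hypothetical result for a freshly force-closed channel, with all txids invented for illustration:

    keep_watching, spenders = lnwatcher.inspect_tx_candidate('funding_txid:0', 0)
    # spenders could now be:
    # {
    #     'funding_txid:0': 'ctx_txid',  # funding output, spent by a commitment tx
    #     'ctx_txid:0': 'sweep_txid',    # commitment output 0, already swept
    #     'ctx_txid:1': None,            # commitment output 1, still unspent
    # }
    # keep_watching stays True while any output is unspent or any spending tx
    # is not yet DEEP.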
+    async def do_breach_remedy(self, funding_outpoint, spenders):
+        for prevout, spender in spenders.items():
+            if spender is not None:
+                continue
+            prev_txid, prev_n = prevout.split(':')
+            with self.lock:
+                encumbered_sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
+            for prev_txid, e_tx in encumbered_sweep_txns:
+                if not await self.broadcast_or_log(funding_outpoint, e_tx):
+                    self.print_error(e_tx.name, f'could not publish encumbered tx: {str(e_tx)}, prev_txid: {prev_txid}')

     async def broadcast_or_log(self, funding_outpoint, e_tx):
         height = self.get_tx_height(e_tx.tx.txid()).height
@@ -249,9 +222,3 @@ class LNWatcher(AddressSynchronizer):
             return TxMinedDepth.MEMPOOL
         else:
             raise NotImplementedError()

-    def get_deepest_tx_mined_depth_for_txids(self, set_of_txids: Iterable[str]):
-        if not set_of_txids:
-            return TxMinedDepth.FREE
-        # note: using "min" as lower status values are deeper
-        return min(map(self.get_tx_mined_depth, set_of_txids))
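The deleted helper depended on TxMinedDepth being an ordered enum in which lower values mean deeper confirmation, so min over a set of statuses yields the deepest one. A self-contained sketch of that trick; the numeric values are assumed for illustration, not taken from the codebase:

    from enum import IntEnum

    class TxMinedDepth(IntEnum):
        # assumed ordering: lower value = deeper in the chain
        DEEP = 0
        SHALLOW = 1
        MEMPOOL = 2
        FREE = 3

    def deepest_depth(depths):
        # min() picks DEEP if any tx is deeply mined; FREE only if depths is
        # empty or every tx is unconfirmed-and-unseen
        return min(depths, default=TxMinedDepth.FREE)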

electrum/lnworker.py (76 lines changed)

@@ -38,6 +38,7 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
 from .i18n import _
 from .lnrouter import RouteEdge, is_route_sane_to_use
 from .address_synchronizer import TX_HEIGHT_LOCAL
+from .lnsweep import create_sweeptxs_for_our_latest_ctx, create_sweeptxs_for_their_latest_ctx

 if TYPE_CHECKING:
     from .network import Network
@@ -88,7 +89,8 @@ class LNWorker(PrintError):
         self._add_peers_from_config()
         # wait until we see confirmations
         self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee'])  # thread safe
-        self.network.register_callback(self.on_channel_txo, ['channel_txo'])
+        self.network.register_callback(self.on_channel_open, ['channel_open'])
+        self.network.register_callback(self.on_channel_closed, ['channel_closed'])
         asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
         self.first_timestamp_requested = None
@@ -282,22 +284,76 @@ class LNWorker(PrintError):
             return True, conf
         return False, conf

-    def on_channel_txo(self, event, txo, is_spent: bool):
+    def channel_by_txo(self, txo):
         with self.lock:
             channels = list(self.channels.values())
         for chan in channels:
             if chan.funding_outpoint.to_str() == txo:
-                break
-        else:
-            return
+                return chan
+
+    def on_channel_open(self, event, funding_outpoint):
+        chan = self.channel_by_txo(funding_outpoint)
+        if not chan:
+            return
-        chan.set_funding_txo_spentness(is_spent)
-        if is_spent:
-            if chan.get_state() != 'FORCE_CLOSING':
-                chan.set_state("CLOSED")
-            self.on_channels_updated()
-            self.channel_db.remove_channel(chan.short_channel_id)
+        self.print_error('on_channel_open', funding_outpoint)
+        chan.set_funding_txo_spentness(False)
+        # send event to GUI
         self.network.trigger_callback('channel', chan)
+
+    @log_exceptions
+    async def on_channel_closed(self, event, funding_outpoint, txid, spenders):
+        chan = self.channel_by_txo(funding_outpoint)
+        if not chan:
+            return
+        self.print_error('on_channel_closed', funding_outpoint)
+        chan.set_funding_txo_spentness(True)
+        if chan.get_state() != 'FORCE_CLOSING':
+            chan.set_state("CLOSED")
+        self.on_channels_updated()
+        self.network.trigger_callback('channel', chan)
+        # remove from channel_db
+        self.channel_db.remove_channel(chan.short_channel_id)
+        # sweep
+        our_ctx = chan.local_commitment
+        their_ctx = chan.remote_commitment
+        if txid == our_ctx.txid():
+            self.print_error('we force closed', funding_outpoint)
+            # we force closed
+            encumbered_sweeptxs = create_sweeptxs_for_our_latest_ctx(chan, our_ctx, chan.sweep_address)
+        elif txid == their_ctx.txid():
+            self.print_error('they force closed', funding_outpoint)
+            # they force closed
+            encumbered_sweeptxs = create_sweeptxs_for_their_latest_ctx(chan, their_ctx, chan.sweep_address)
+        else:
+            # cooperative close or breach
+            self.print_error('not sure who closed', funding_outpoint)
+            encumbered_sweeptxs = []
+        local_height = self.network.get_local_height()
+        for prev_txid, e_tx in encumbered_sweeptxs:
+            spender = spenders.get(prev_txid + ':0')  # we assume output index is 0
+            if spender is not None:
+                self.print_error('prev_tx already spent', prev_txid)
+                continue
+            num_conf = self.network.lnwatcher.get_tx_height(prev_txid).conf
+            broadcast = True
+            if e_tx.cltv_expiry:
+                if local_height > e_tx.cltv_expiry:
+                    self.print_error(e_tx.name, 'CLTV ({} > {}) fulfilled'.format(local_height, e_tx.cltv_expiry))
+                else:
+                    self.print_error(e_tx.name, 'waiting for {}: CLTV ({} > {}), funding outpoint {} and tx {}'
+                                     .format(e_tx.name, local_height, e_tx.cltv_expiry, funding_outpoint[:8], prev_txid[:8]))
+                    broadcast = False
+            if e_tx.csv_delay:
+                if num_conf < e_tx.csv_delay:
+                    self.print_error(e_tx.name, 'waiting for {}: CSV ({} >= {}), funding outpoint {} and tx {}'
+                                     .format(e_tx.name, num_conf, e_tx.csv_delay, funding_outpoint[:8], prev_txid[:8]))
+                    broadcast = False
+            if broadcast:
+                if not await self.network.lnwatcher.broadcast_or_log(funding_outpoint, e_tx):
+                    self.print_error(e_tx.name, f'could not publish encumbered tx: {str(e_tx)}, prev_txid: {prev_txid}, local_height', local_height)
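The broadcast gating in the loop above amounts to two independent locktime checks; a condensed restatement (the attribute names cltv_expiry and csv_delay come from the diff, the function itself is only a sketch):

    def can_broadcast(e_tx, local_height: int, num_conf: int) -> bool:
        if e_tx.cltv_expiry and local_height <= e_tx.cltv_expiry:
            return False  # absolute locktime (CLTV) not yet reached
        if e_tx.csv_delay and num_conf < e_tx.csv_delay:
            return False  # relative locktime (CSV): parent not confirmed deeply enough
        return True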
     @log_exceptions
     async def on_network_update(self, event, *args):
         # TODO
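Taken together, the three files implement the behaviour announced in the commit message. A hand-written trace of the new event flow (not program output):

    # wiring, from the lnworker.py hunk above:
    #   network.register_callback(lnworker.on_channel_open, ['channel_open'])
    #   network.register_callback(lnworker.on_channel_closed, ['channel_closed'])
    #
    # LNWatcher.check_onchain_situation(address, funding_outpoint)
    #   funding output unspent -> trigger_callback('channel_open', funding_outpoint)
    #   funding output spent   -> trigger_callback('channel_closed', funding_outpoint, txid, spenders)
    #                             await do_breach_remedy(funding_outpoint, spenders)
    #   nothing left to watch  -> unwatch_channel(address, funding_outpoint)
    #
    # LNWorker.on_channel_closed(...)
    #   classifies the close via chan.local_commitment / chan.remote_commitment,
    #   rebuilds sweep txs on demand (lnsweep) and broadcasts once CLTV/CSV allow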
