From 3eabd70df59dc43e6f627198a6530ffe90f26766 Mon Sep 17 00:00:00 2001 From: Janus Date: Mon, 10 Sep 2018 18:01:55 +0200 Subject: [PATCH] lightning: post aiorpcx rebase fixup --- electrum/lnbase.py | 19 +-------- electrum/lnchanannverifier.py | 40 +++++++----------- electrum/lnrouter.py | 4 +- electrum/lnwatcher.py | 79 ++++++++++++++++------------------- electrum/lnworker.py | 62 +++++++++++++-------------- 5 files changed, 88 insertions(+), 116 deletions(-) diff --git a/electrum/lnbase.py b/electrum/lnbase.py index b05c42651..6c0378653 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -29,7 +29,7 @@ from . import crypto from .crypto import sha256 from . import constants from . import transaction -from .util import PrintError, bh2u, print_error, bfh +from .util import PrintError, bh2u, print_error, bfh, aiosafe from .transaction import opcodes, Transaction, TxOutput from .lnonion import new_onion_packet, OnionHopsDataSingle, OnionPerHop, decode_onion_error, ONION_FAILURE_CODE_MAP from .lnaddr import lndecode @@ -266,21 +266,6 @@ def create_ephemeral_key() -> (bytes, bytes): return privkey.get_secret_bytes(), privkey.get_public_key_bytes() -def aiosafe(f): - # save exception in object. - # f must be a method of a PrintError instance. - # aiosafe calls should not be nested - async def f2(*args, **kwargs): - self = args[0] - try: - return await f(*args, **kwargs) - except BaseException as e: - self.print_error("Exception in", f.__name__, ":", e.__class__.__name__, str(e)) - self.exception = e - return f2 - - - class Peer(PrintError): def __init__(self, lnworker, host, port, pubkey, request_initial_sync=False): @@ -612,7 +597,7 @@ class Peer(PrintError): remote_sig = payload['signature'] m.receive_new_commitment(remote_sig, []) # broadcast funding tx - success, _txid = self.network.broadcast_transaction(funding_tx) + success, _txid = await self.network.broadcast_transaction(funding_tx) assert success, success m.remote_state = m.remote_state._replace(ctn=0) m.local_state = m.local_state._replace(ctn=0, current_commitment_signature=remote_sig) diff --git a/electrum/lnchanannverifier.py b/electrum/lnchanannverifier.py index fabdf9b47..fb3a1b7d1 100644 --- a/electrum/lnchanannverifier.py +++ b/electrum/lnchanannverifier.py @@ -23,7 +23,9 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import asyncio import threading +from aiorpcx import TaskGroup from . import lnbase from . import bitcoin @@ -60,7 +62,13 @@ class LNChanAnnVerifier(ThreadJob): def get_pending_channel_info(self, short_channel_id): return self.unverified_channel_info.get(short_channel_id, None) - def run(self): + async def main(self): + while True: + async with TaskGroup() as tg: + await self.iteration(tg) + await asyncio.sleep(0.1) + + async def iteration(self, tg): interface = self.network.interface if not interface: return @@ -81,40 +89,24 @@ class LNChanAnnVerifier(ThreadJob): if header is None: index = block_height // 2016 if index < len(blockchain.checkpoints): - self.network.request_chunk(interface, index) + await tg.spawn(self.network.request_chunk(interface, index)) continue - callback = lambda resp, short_channel_id=short_channel_id: self.on_txid_and_merkle(resp, short_channel_id) - self.network.get_txid_from_txpos(block_height, tx_pos, True, - callback=callback) + await tg.spawn(self.verify_channel(block_height, tx_pos, short_channel_id)) #self.print_error('requested short_channel_id', bh2u(short_channel_id)) - with self.lock: - self.started_verifying_channel.add(short_channel_id) - def on_txid_and_merkle(self, response, short_channel_id): - if response.get('error'): - self.print_error('received an error:', response) - return - result = response['result'] + async def verify_channel(self, block_height, tx_pos, short_channel_id): + with self.lock: + self.started_verifying_channel.add(short_channel_id) + result = await self.network.get_txid_from_txpos(block_height, tx_pos, True) tx_hash = result['tx_hash'] merkle_branch = result['merkle'] - block_height, tx_pos, output_idx = invert_short_channel_id(short_channel_id) header = self.network.blockchain().read_header(block_height) try: verify_tx_is_in_block(tx_hash, merkle_branch, tx_pos, header, block_height) except MerkleVerificationFailure as e: self.print_error(str(e)) return - callback = lambda resp, short_channel_id=short_channel_id: self.on_tx_response(resp, short_channel_id) - self.network.get_transaction(tx_hash, callback=callback) - - def on_tx_response(self, response, short_channel_id): - if response.get('error'): - self.print_error('received an error:', response) - return - params = response['params'] - result = response['result'] - tx_hash = params[0] - tx = Transaction(result) + tx = Transaction(await self.network.get_transaction(tx_hash)) try: tx.deserialize() except Exception: diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 84d3275a5..e8dc3b8b7 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -31,6 +31,7 @@ from collections import namedtuple, defaultdict from typing import Sequence, Union, Tuple, Optional import binascii import base64 +import asyncio from . import constants from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits @@ -277,7 +278,8 @@ class ChannelDB(JsonDB): self._last_good_address = {} # node_id -> LNPeerAddr self.ca_verifier = LNChanAnnVerifier(network, self) - self.network.add_jobs([self.ca_verifier]) + # FIXME if the channel verifier raises, it kills network.main_taskgroup + asyncio.run_coroutine_threadsafe(self.network.add_job(self.ca_verifier.main()), network.asyncio_loop) self.load_data() diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index ecb87e5b1..ff0158872 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -1,12 +1,13 @@ import threading +import asyncio -from .util import PrintError, bh2u, bfh, NoDynamicFeeEstimates +from .util import PrintError, bh2u, bfh, NoDynamicFeeEstimates, aiosafe from .lnutil import (extract_ctn_from_tx, derive_privkey, get_per_commitment_secret_from_seed, derive_pubkey, make_commitment_output_to_remote_address, RevocationStore, UnableToDeriveSecret) from . import lnutil -from .bitcoin import redeem_script_to_address, TYPE_ADDRESS +from .bitcoin import redeem_script_to_address, TYPE_ADDRESS, address_to_scripthash from . import transaction from .transaction import Transaction, TxOutput from . import ecc @@ -22,33 +23,27 @@ class LNWatcher(PrintError): self.watched_channels = {} self.address_status = {} # addr -> status - def parse_response(self, response): - if response.get('error'): - self.print_error("response error:", response) - return None, None - return response['params'], response['result'] + @aiosafe + async def handle_addresses(self, funding_address): + queue = asyncio.Queue() + params = [address_to_scripthash(funding_address)] + await self.network.interface.session.subscribe('blockchain.scripthash.subscribe', params, queue) + await queue.get() + while True: + result = await queue.get() + await self.on_address_status(funding_address, result) def watch_channel(self, chan, callback): funding_address = chan.get_funding_address() self.watched_channels[funding_address] = chan, callback - self.network.subscribe_to_addresses([funding_address], self.on_address_status) + asyncio.get_event_loop().create_task(self.handle_addresses(funding_address)) - def on_address_status(self, response): - params, result = self.parse_response(response) - if not params: - return - addr = params[0] + async def on_address_status(self, addr, result): if self.address_status.get(addr) != result: self.address_status[addr] = result - self.network.request_address_utxos(addr, self.on_utxos) - - def on_utxos(self, response): - params, result = self.parse_response(response) - if not params: - return - addr = params[0] - chan, callback = self.watched_channels[addr] - callback(chan, result) + result = await self.network.interface.session.send_request('blockchain.scripthash.listunspent', [address_to_scripthash(addr)]) + chan, callback = self.watched_channels[addr] + await callback(chan, result) @@ -69,9 +64,9 @@ class LNChanCloseHandler(PrintError): network.register_callback(self.on_network_update, ['updated']) self.watch_address(self.funding_address) - def on_network_update(self, event, *args): + async def on_network_update(self, event, *args): if self.wallet.synchronizer.is_up_to_date(): - self.check_onchain_situation() + await self.check_onchain_situation() def stop_and_delete(self): self.network.unregister_callback(self.on_network_update) @@ -82,7 +77,7 @@ class LNChanCloseHandler(PrintError): self.watched_addresses.add(addr) self.wallet.synchronizer.add(addr) - def check_onchain_situation(self): + async def check_onchain_situation(self): funding_outpoint = self.chan.funding_outpoint ctx_candidate_txid = self.wallet.spent_outpoints[funding_outpoint.txid].get(funding_outpoint.output_index) if ctx_candidate_txid is None: @@ -104,13 +99,13 @@ class LNChanCloseHandler(PrintError): conf = self.wallet.get_tx_height(ctx_candidate_txid).conf if conf == 0: return - keep_watching_this = self.inspect_ctx_candidate(ctx_candidate, i) + keep_watching_this = await self.inspect_ctx_candidate(ctx_candidate, i) if not keep_watching_this: self.stop_and_delete() # TODO batch sweeps # TODO sweep HTLC outputs - def inspect_ctx_candidate(self, ctx, txin_idx: int): + async def inspect_ctx_candidate(self, ctx, txin_idx: int): """Returns True iff found any not-deeply-spent outputs that we could potentially sweep at some point.""" keep_watching_this = False @@ -127,7 +122,7 @@ class LNChanCloseHandler(PrintError): # note that we might also get here if this is our ctx and the ctn just happens to match their_cur_pcp = chan.remote_state.current_per_commitment_point if their_cur_pcp is not None: - keep_watching_this |= self.find_and_sweep_their_ctx_to_remote(ctx, their_cur_pcp) + keep_watching_this |= await self.find_and_sweep_their_ctx_to_remote(ctx, their_cur_pcp) # see if we have a revoked secret for this ctn ("breach") try: per_commitment_secret = chan.remote_state.revocation_store.retrieve_secret( @@ -138,13 +133,13 @@ class LNChanCloseHandler(PrintError): # note that we might also get here if this is our ctx and we just happen to have # the secret for the symmetric ctn their_pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True) - keep_watching_this |= self.find_and_sweep_their_ctx_to_remote(ctx, their_pcp) - keep_watching_this |= self.find_and_sweep_their_ctx_to_local(ctx, per_commitment_secret) + keep_watching_this |= await self.find_and_sweep_their_ctx_to_remote(ctx, their_pcp) + keep_watching_this |= await self.find_and_sweep_their_ctx_to_local(ctx, per_commitment_secret) # see if it's our ctx our_per_commitment_secret = get_per_commitment_secret_from_seed( chan.local_state.per_commitment_secret_seed, RevocationStore.START_INDEX - ctn) our_per_commitment_point = ecc.ECPrivkey(our_per_commitment_secret).get_public_key_bytes(compressed=True) - keep_watching_this |= self.find_and_sweep_our_ctx_to_local(ctx, our_per_commitment_point) + keep_watching_this |= await self.find_and_sweep_our_ctx_to_local(ctx, our_per_commitment_point) return keep_watching_this def get_tx_mined_status(self, txid): @@ -166,7 +161,7 @@ class LNChanCloseHandler(PrintError): else: raise NotImplementedError() - def find_and_sweep_their_ctx_to_remote(self, ctx, their_pcp: bytes): + async def find_and_sweep_their_ctx_to_remote(self, ctx, their_pcp: bytes): """Returns True iff found a not-deeply-spent output that we could potentially sweep at some point.""" payment_bp_privkey = ecc.ECPrivkey(self.chan.local_config.payment_basepoint.privkey) @@ -193,12 +188,12 @@ class LNChanCloseHandler(PrintError): return True sweep_tx = create_sweeptx_their_ctx_to_remote(self.network, self.sweep_address, ctx, output_idx, our_payment_privkey) - self.network.broadcast_transaction(sweep_tx, - lambda res: self.print_tx_broadcast_result('sweep_their_ctx_to_remote', res)) + res = await self.network.broadcast_transaction(sweep_tx) + self.print_tx_broadcast_result('sweep_their_ctx_to_remote', res) return True - def find_and_sweep_their_ctx_to_local(self, ctx, per_commitment_secret: bytes): + async def find_and_sweep_their_ctx_to_local(self, ctx, per_commitment_secret: bytes): """Returns True iff found a not-deeply-spent output that we could potentially sweep at some point.""" per_commitment_point = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True) @@ -230,11 +225,11 @@ class LNChanCloseHandler(PrintError): return True sweep_tx = create_sweeptx_ctx_to_local(self.network, self.sweep_address, ctx, output_idx, witness_script, revocation_privkey, True) - self.network.broadcast_transaction(sweep_tx, - lambda res: self.print_tx_broadcast_result('sweep_their_ctx_to_local', res)) + res = await self.network.broadcast_transaction(sweep_tx) + self.print_tx_broadcast_result('sweep_their_ctx_to_local', res) return True - def find_and_sweep_our_ctx_to_local(self, ctx, our_pcp: bytes): + async def find_and_sweep_our_ctx_to_local(self, ctx, our_pcp: bytes): """Returns True iff found a not-deeply-spent output that we could potentially sweep at some point.""" delayed_bp_privkey = ecc.ECPrivkey(self.chan.local_config.delayed_basepoint.privkey) @@ -272,14 +267,14 @@ class LNChanCloseHandler(PrintError): sweep_tx = create_sweeptx_ctx_to_local(self.network, self.sweep_address, ctx, output_idx, witness_script, our_localdelayed_privkey.get_secret_bytes(), False, to_self_delay) - self.network.broadcast_transaction(sweep_tx, - lambda res: self.print_tx_broadcast_result('sweep_our_ctx_to_local', res)) + res = await self.network.broadcast_transaction(sweep_tx) + self.print_tx_broadcast_result('sweep_our_ctx_to_local', res) return True def print_tx_broadcast_result(self, name, res): - error = res.get('error') + error, msg = res if error: - self.print_error('{} broadcast failed: {}'.format(name, error)) + self.print_error('{} broadcast failed: {}'.format(name, msg)) else: self.print_error('{} broadcast succeeded'.format(name)) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 7b88c4295..b70f2f83f 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -11,8 +11,8 @@ import dns.exception from . import constants from .bitcoin import sha256, COIN -from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv -from .lnbase import Peer, privkey_to_pubkey, aiosafe +from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, aiosafe +from .lnbase import Peer, privkey_to_pubkey from .lnaddr import lnencode, LnAddr, lndecode from .ecc import der_sig_from_sig_string from .lnhtlc import HTLCStateMachine @@ -55,8 +55,7 @@ class LNWorker(PrintError): self._add_peers_from_config() # wait until we see confirmations self.network.register_callback(self.on_network_update, ['updated', 'verified', 'fee']) # thread safe - self.on_network_update('updated') # shortcut (don't block) if funding tx locked and verified - self.network.futures.append(asyncio.run_coroutine_threadsafe(self.main_loop(), asyncio.get_event_loop())) + asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop) def _add_peers_from_config(self): peer_list = self.config.get('lightning_peers', []) @@ -84,7 +83,7 @@ class LNWorker(PrintError): self._last_tried_peer[peer_addr] = time.time() self.print_error("adding peer", peer_addr) peer = Peer(self, host, port, node_id, request_initial_sync=self.config.get("request_initial_sync", True)) - self.network.futures.append(asyncio.run_coroutine_threadsafe(peer.main_loop(), asyncio.get_event_loop())) + asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(peer.main_loop()), self.network.asyncio_loop) self.peers[node_id] = peer self.network.trigger_callback('ln_status') @@ -119,7 +118,7 @@ class LNWorker(PrintError): return True return False - def on_channel_utxos(self, chan, utxos): + async def on_channel_utxos(self, chan, utxos): outpoints = [Outpoint(x["tx_hash"], x["tx_pos"]) for x in utxos] if chan.funding_outpoint not in outpoints: chan.set_funding_txo_spentness(True) @@ -131,33 +130,32 @@ class LNWorker(PrintError): chan.set_funding_txo_spentness(False) self.network.trigger_callback('channel', chan) - def on_network_update(self, event, *args): - """ called from network thread """ + @aiosafe + async def on_network_update(self, event, *args): + # TODO # Race discovered in save_channel (assertion failing): # since short_channel_id could be changed while saving. - # Mitigated by posting to loop: - async def network_jobs(): - with self.lock: - channels = list(self.channels.values()) - for chan in channels: - if chan.get_state() == "OPENING": - res = self.save_short_chan_id(chan) - if not res: - self.print_error("network update but funding tx is still not at sufficient depth") - continue - # this results in the channel being marked OPEN - peer = self.peers[chan.node_id] - peer.funding_locked(chan) - elif chan.get_state() == "OPEN": - peer = self.peers.get(chan.node_id) - if peer is None: - self.print_error("peer not found for {}".format(bh2u(chan.node_id))) - return - if event == 'fee': - peer.on_bitcoin_fee_update(chan) - conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf - peer.on_network_update(chan, conf) - asyncio.run_coroutine_threadsafe(network_jobs(), self.network.asyncio_loop).result() + with self.lock: + channels = list(self.channels.values()) + for chan in channels: + print("update", chan.get_state()) + if chan.get_state() == "OPENING": + res = self.save_short_chan_id(chan) + if not res: + self.print_error("network update but funding tx is still not at sufficient depth") + continue + # this results in the channel being marked OPEN + peer = self.peers[chan.node_id] + peer.funding_locked(chan) + elif chan.get_state() == "OPEN": + peer = self.peers.get(chan.node_id) + if peer is None: + self.print_error("peer not found for {}".format(bh2u(chan.node_id))) + return + if event == 'fee': + peer.on_bitcoin_fee_update(chan) + conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf + peer.on_network_update(chan, conf) async def _open_channel_coroutine(self, node_id, local_amount_sat, push_sat, password): peer = self.peers[node_id] @@ -345,8 +343,8 @@ class LNWorker(PrintError): coro = peer.reestablish_channel(chan) asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) - @aiosafe async def main_loop(self): + await self.on_network_update('updated') # shortcut (don't block) if funding tx locked and verified while True: await asyncio.sleep(1) now = time.time()