From 7d54e180c2825170c47d65182510b2ccfae8cdee Mon Sep 17 00:00:00 2001
From: Janus <ysangkok@gmail.com>
Date: Mon, 10 Sep 2018 18:01:55 +0200
Subject: [PATCH] lightning: post aiorpcx rebase fixup

---
 electrum/lnbase.py            | 19 +--------
 electrum/lnchanannverifier.py | 40 +++++++-----------
 electrum/lnrouter.py          |  4 +-
 electrum/lnwatcher.py         | 79 ++++++++++++++++-------------------
 electrum/lnworker.py          | 62 +++++++++++++--------------
 5 files changed, 88 insertions(+), 116 deletions(-)

diff --git a/electrum/lnbase.py b/electrum/lnbase.py
index b05c42651..6c0378653 100644
--- a/electrum/lnbase.py
+++ b/electrum/lnbase.py
@@ -29,7 +29,7 @@ from . import crypto
 from .crypto import sha256
 from . import constants
 from . import transaction
-from .util import PrintError, bh2u, print_error, bfh
+from .util import PrintError, bh2u, print_error, bfh, aiosafe
 from .transaction import opcodes, Transaction, TxOutput
 from .lnonion import new_onion_packet, OnionHopsDataSingle, OnionPerHop, decode_onion_error, ONION_FAILURE_CODE_MAP
 from .lnaddr import lndecode
@@ -266,21 +266,6 @@ def create_ephemeral_key() -> (bytes, bytes):
     return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
 
 
-def aiosafe(f):
-    # save exception in object.
-    # f must be a method of a PrintError instance.
-    # aiosafe calls should not be nested
-    async def f2(*args, **kwargs):
-        self = args[0]
-        try:
-            return await f(*args, **kwargs)
-        except BaseException as e:
-            self.print_error("Exception in", f.__name__, ":", e.__class__.__name__, str(e))
-            self.exception = e
-    return f2
-
-
-
 class Peer(PrintError):
 
     def __init__(self, lnworker, host, port, pubkey, request_initial_sync=False):
@@ -612,7 +597,7 @@ class Peer(PrintError):
         remote_sig = payload['signature']
         m.receive_new_commitment(remote_sig, [])
         # broadcast funding tx
-        success, _txid = self.network.broadcast_transaction(funding_tx)
+        success, _txid = await self.network.broadcast_transaction(funding_tx)
         assert success, success
         m.remote_state = m.remote_state._replace(ctn=0)
         m.local_state = m.local_state._replace(ctn=0, current_commitment_signature=remote_sig)
diff --git a/electrum/lnchanannverifier.py b/electrum/lnchanannverifier.py
index fabdf9b47..fb3a1b7d1 100644
--- a/electrum/lnchanannverifier.py
+++ b/electrum/lnchanannverifier.py
@@ -23,7 +23,9 @@
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import asyncio
 import threading
+from aiorpcx import TaskGroup
 
 from . import lnbase
 from . import bitcoin
@@ -60,7 +62,13 @@ class LNChanAnnVerifier(ThreadJob):
     def get_pending_channel_info(self, short_channel_id):
         return self.unverified_channel_info.get(short_channel_id, None)
 
-    def run(self):
+    async def main(self):
+        while True:
+            async with TaskGroup() as tg:
+                await self.iteration(tg)
+            await asyncio.sleep(0.1)
+
+    async def iteration(self, tg):
         interface = self.network.interface
         if not interface:
             return
@@ -81,40 +89,24 @@ class LNChanAnnVerifier(ThreadJob):
             if header is None:
                 index = block_height // 2016
                 if index < len(blockchain.checkpoints):
-                    self.network.request_chunk(interface, index)
+                    await tg.spawn(self.network.request_chunk(interface, index))
                 continue
-            callback = lambda resp, short_channel_id=short_channel_id: self.on_txid_and_merkle(resp, short_channel_id)
-            self.network.get_txid_from_txpos(block_height, tx_pos, True,
-                                             callback=callback)
+            await tg.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
             #self.print_error('requested short_channel_id', bh2u(short_channel_id))
-            with self.lock:
-                self.started_verifying_channel.add(short_channel_id)
 
-    def on_txid_and_merkle(self, response, short_channel_id):
-        if response.get('error'):
-            self.print_error('received an error:', response)
-            return
-        result = response['result']
+    async def verify_channel(self, block_height, tx_pos, short_channel_id):
+        with self.lock:
+            self.started_verifying_channel.add(short_channel_id)
+        result = await self.network.get_txid_from_txpos(block_height, tx_pos, True)
         tx_hash = result['tx_hash']
         merkle_branch = result['merkle']
-        block_height, tx_pos, output_idx = invert_short_channel_id(short_channel_id)
         header = self.network.blockchain().read_header(block_height)
         try:
             verify_tx_is_in_block(tx_hash, merkle_branch, tx_pos, header, block_height)
         except MerkleVerificationFailure as e:
             self.print_error(str(e))
             return
-        callback = lambda resp, short_channel_id=short_channel_id: self.on_tx_response(resp, short_channel_id)
-        self.network.get_transaction(tx_hash, callback=callback)
-
-    def on_tx_response(self, response, short_channel_id):
-        if response.get('error'):
-            self.print_error('received an error:', response)
-            return
-        params = response['params']
-        result = response['result']
-        tx_hash = params[0]
-        tx = Transaction(result)
+        tx = Transaction(await self.network.get_transaction(tx_hash))
         try:
             tx.deserialize()
         except Exception:
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
index 84d3275a5..e8dc3b8b7 100644
--- a/electrum/lnrouter.py
+++ b/electrum/lnrouter.py
@@ -31,6 +31,7 @@ from collections import namedtuple, defaultdict
 from typing import Sequence, Union, Tuple, Optional
 import binascii
 import base64
+import asyncio
 
 from . import constants
 from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
@@ -277,7 +278,8 @@ class ChannelDB(JsonDB):
         self._last_good_address = {}  # node_id -> LNPeerAddr
 
         self.ca_verifier = LNChanAnnVerifier(network, self)
-        self.network.add_jobs([self.ca_verifier])
+        # FIXME if the channel verifier raises, it kills network.main_taskgroup
+        asyncio.run_coroutine_threadsafe(self.network.add_job(self.ca_verifier.main()), network.asyncio_loop)
 
         self.load_data()
 
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
index ecb87e5b1..ff0158872 100644
--- a/electrum/lnwatcher.py
+++ b/electrum/lnwatcher.py
@@ -1,12 +1,13 @@
 import threading
+import asyncio
 
-from .util import PrintError, bh2u, bfh, NoDynamicFeeEstimates
+from .util import PrintError, bh2u, bfh, NoDynamicFeeEstimates, aiosafe
 from .lnutil import (extract_ctn_from_tx, derive_privkey,
                      get_per_commitment_secret_from_seed, derive_pubkey,
                      make_commitment_output_to_remote_address,
                      RevocationStore, UnableToDeriveSecret)
 from . import lnutil
-from .bitcoin import redeem_script_to_address, TYPE_ADDRESS
+from .bitcoin import redeem_script_to_address, TYPE_ADDRESS, address_to_scripthash
 from . import transaction
 from .transaction import Transaction, TxOutput
 from . import ecc
@@ -22,33 +23,27 @@ class LNWatcher(PrintError):
         self.watched_channels = {}
         self.address_status = {}  # addr -> status
 
-    def parse_response(self, response):
-        if response.get('error'):
-            self.print_error("response error:", response)
-            return None, None
-        return response['params'], response['result']
+    @aiosafe
+    async def handle_addresses(self, funding_address):
+        queue = asyncio.Queue()
+        params = [address_to_scripthash(funding_address)]
+        await self.network.interface.session.subscribe('blockchain.scripthash.subscribe', params, queue)
+        await queue.get()
+        while True:
+            result = await queue.get()
+            await self.on_address_status(funding_address, result)
 
     def watch_channel(self, chan, callback):
         funding_address = chan.get_funding_address()
         self.watched_channels[funding_address] = chan, callback
-        self.network.subscribe_to_addresses([funding_address], self.on_address_status)
+        asyncio.get_event_loop().create_task(self.handle_addresses(funding_address))
 
-    def on_address_status(self, response):
-        params, result = self.parse_response(response)
-        if not params:
-            return
-        addr = params[0]
+    async def on_address_status(self, addr, result):
         if self.address_status.get(addr) != result:
             self.address_status[addr] = result
-            self.network.request_address_utxos(addr, self.on_utxos)
-
-    def on_utxos(self, response):
-        params, result = self.parse_response(response)
-        if not params:
-            return
-        addr = params[0]
-        chan, callback = self.watched_channels[addr]
-        callback(chan, result)
+            result = await self.network.interface.session.send_request('blockchain.scripthash.listunspent', [address_to_scripthash(addr)])
+            chan, callback = self.watched_channels[addr]
+            await callback(chan, result)
 
 
 
@@ -69,9 +64,9 @@ class LNChanCloseHandler(PrintError):
         network.register_callback(self.on_network_update, ['updated'])
         self.watch_address(self.funding_address)
 
-    def on_network_update(self, event, *args):
+    async def on_network_update(self, event, *args):
         if self.wallet.synchronizer.is_up_to_date():
-            self.check_onchain_situation()
+            await self.check_onchain_situation()
 
     def stop_and_delete(self):
         self.network.unregister_callback(self.on_network_update)
@@ -82,7 +77,7 @@ class LNChanCloseHandler(PrintError):
             self.watched_addresses.add(addr)
             self.wallet.synchronizer.add(addr)
 
-    def check_onchain_situation(self):
+    async def check_onchain_situation(self):
         funding_outpoint = self.chan.funding_outpoint
         ctx_candidate_txid = self.wallet.spent_outpoints[funding_outpoint.txid].get(funding_outpoint.output_index)
         if ctx_candidate_txid is None:
@@ -104,13 +99,13 @@ class LNChanCloseHandler(PrintError):
         conf = self.wallet.get_tx_height(ctx_candidate_txid).conf
         if conf == 0:
             return
-        keep_watching_this = self.inspect_ctx_candidate(ctx_candidate, i)
+        keep_watching_this = await self.inspect_ctx_candidate(ctx_candidate, i)
         if not keep_watching_this:
             self.stop_and_delete()
 
     # TODO batch sweeps
     # TODO sweep HTLC outputs
-    def inspect_ctx_candidate(self, ctx, txin_idx: int):
+    async def inspect_ctx_candidate(self, ctx, txin_idx: int):
         """Returns True iff found any not-deeply-spent outputs that we could
         potentially sweep at some point."""
         keep_watching_this = False
@@ -127,7 +122,7 @@ class LNChanCloseHandler(PrintError):
             # note that we might also get here if this is our ctx and the ctn just happens to match
             their_cur_pcp = chan.remote_state.current_per_commitment_point
             if their_cur_pcp is not None:
-                keep_watching_this |= self.find_and_sweep_their_ctx_to_remote(ctx, their_cur_pcp)
+                keep_watching_this |= await self.find_and_sweep_their_ctx_to_remote(ctx, their_cur_pcp)
         # see if we have a revoked secret for this ctn ("breach")
         try:
             per_commitment_secret = chan.remote_state.revocation_store.retrieve_secret(
@@ -138,13 +133,13 @@ class LNChanCloseHandler(PrintError):
             # note that we might also get here if this is our ctx and we just happen to have
             # the secret for the symmetric ctn
             their_pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True)
-            keep_watching_this |= self.find_and_sweep_their_ctx_to_remote(ctx, their_pcp)
-            keep_watching_this |= self.find_and_sweep_their_ctx_to_local(ctx, per_commitment_secret)
+            keep_watching_this |= await self.find_and_sweep_their_ctx_to_remote(ctx, their_pcp)
+            keep_watching_this |= await self.find_and_sweep_their_ctx_to_local(ctx, per_commitment_secret)
         # see if it's our ctx
         our_per_commitment_secret = get_per_commitment_secret_from_seed(
             chan.local_state.per_commitment_secret_seed, RevocationStore.START_INDEX - ctn)
         our_per_commitment_point = ecc.ECPrivkey(our_per_commitment_secret).get_public_key_bytes(compressed=True)
-        keep_watching_this |= self.find_and_sweep_our_ctx_to_local(ctx, our_per_commitment_point)
+        keep_watching_this |= await self.find_and_sweep_our_ctx_to_local(ctx, our_per_commitment_point)
         return keep_watching_this
 
     def get_tx_mined_status(self, txid):
@@ -166,7 +161,7 @@ class LNChanCloseHandler(PrintError):
         else:
             raise NotImplementedError()
 
-    def find_and_sweep_their_ctx_to_remote(self, ctx, their_pcp: bytes):
+    async def find_and_sweep_their_ctx_to_remote(self, ctx, their_pcp: bytes):
         """Returns True iff found a not-deeply-spent output that we could
         potentially sweep at some point."""
         payment_bp_privkey = ecc.ECPrivkey(self.chan.local_config.payment_basepoint.privkey)
@@ -193,12 +188,12 @@ class LNChanCloseHandler(PrintError):
             return True
         sweep_tx = create_sweeptx_their_ctx_to_remote(self.network, self.sweep_address, ctx,
                                                       output_idx, our_payment_privkey)
-        self.network.broadcast_transaction(sweep_tx,
-                                           lambda res: self.print_tx_broadcast_result('sweep_their_ctx_to_remote', res))
+        res = await self.network.broadcast_transaction(sweep_tx)
+        self.print_tx_broadcast_result('sweep_their_ctx_to_remote', res)
         return True
 
 
-    def find_and_sweep_their_ctx_to_local(self, ctx, per_commitment_secret: bytes):
+    async def find_and_sweep_their_ctx_to_local(self, ctx, per_commitment_secret: bytes):
         """Returns True iff found a not-deeply-spent output that we could
         potentially sweep at some point."""
         per_commitment_point = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True)
@@ -230,11 +225,11 @@ class LNChanCloseHandler(PrintError):
             return True
         sweep_tx = create_sweeptx_ctx_to_local(self.network, self.sweep_address, ctx, output_idx,
                                                witness_script, revocation_privkey, True)
-        self.network.broadcast_transaction(sweep_tx,
-                                           lambda res: self.print_tx_broadcast_result('sweep_their_ctx_to_local', res))
+        res = await self.network.broadcast_transaction(sweep_tx)
+        self.print_tx_broadcast_result('sweep_their_ctx_to_local', res)
         return True
 
-    def find_and_sweep_our_ctx_to_local(self, ctx, our_pcp: bytes):
+    async def find_and_sweep_our_ctx_to_local(self, ctx, our_pcp: bytes):
         """Returns True iff found a not-deeply-spent output that we could
         potentially sweep at some point."""
         delayed_bp_privkey = ecc.ECPrivkey(self.chan.local_config.delayed_basepoint.privkey)
@@ -272,14 +267,14 @@ class LNChanCloseHandler(PrintError):
         sweep_tx = create_sweeptx_ctx_to_local(self.network, self.sweep_address, ctx, output_idx,
                                                witness_script, our_localdelayed_privkey.get_secret_bytes(),
                                                False, to_self_delay)
-        self.network.broadcast_transaction(sweep_tx,
-                                           lambda res: self.print_tx_broadcast_result('sweep_our_ctx_to_local', res))
+        res = await self.network.broadcast_transaction(sweep_tx)
+        self.print_tx_broadcast_result('sweep_our_ctx_to_local', res)
         return True
 
     def print_tx_broadcast_result(self, name, res):
-        error = res.get('error')
+        error, msg = res
         if error:
-            self.print_error('{} broadcast failed: {}'.format(name, error))
+            self.print_error('{} broadcast failed: {}'.format(name, msg))
         else:
             self.print_error('{} broadcast succeeded'.format(name))
 
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
index 7b88c4295..b70f2f83f 100644
--- a/electrum/lnworker.py
+++ b/electrum/lnworker.py
@@ -11,8 +11,8 @@ import dns.exception
 
 from . import constants
 from .bitcoin import sha256, COIN
-from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv
-from .lnbase import Peer, privkey_to_pubkey, aiosafe
+from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, aiosafe
+from .lnbase import Peer, privkey_to_pubkey
 from .lnaddr import lnencode, LnAddr, lndecode
 from .ecc import der_sig_from_sig_string
 from .lnhtlc import HTLCStateMachine
@@ -55,8 +55,7 @@ class LNWorker(PrintError):
         self._add_peers_from_config()
         # wait until we see confirmations
         self.network.register_callback(self.on_network_update, ['updated', 'verified', 'fee']) # thread safe
-        self.on_network_update('updated') # shortcut (don't block) if funding tx locked and verified
-        self.network.futures.append(asyncio.run_coroutine_threadsafe(self.main_loop(), asyncio.get_event_loop()))
+        asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
 
     def _add_peers_from_config(self):
         peer_list = self.config.get('lightning_peers', [])
@@ -84,7 +83,7 @@ class LNWorker(PrintError):
         self._last_tried_peer[peer_addr] = time.time()
         self.print_error("adding peer", peer_addr)
         peer = Peer(self, host, port, node_id, request_initial_sync=self.config.get("request_initial_sync", True))
-        self.network.futures.append(asyncio.run_coroutine_threadsafe(peer.main_loop(), asyncio.get_event_loop()))
+        asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(peer.main_loop()), self.network.asyncio_loop)
         self.peers[node_id] = peer
         self.network.trigger_callback('ln_status')
 
@@ -119,7 +118,7 @@ class LNWorker(PrintError):
             return True
         return False
 
-    def on_channel_utxos(self, chan, utxos):
+    async def on_channel_utxos(self, chan, utxos):
         outpoints = [Outpoint(x["tx_hash"], x["tx_pos"]) for x in utxos]
         if chan.funding_outpoint not in outpoints:
             chan.set_funding_txo_spentness(True)
@@ -131,33 +130,32 @@ class LNWorker(PrintError):
             chan.set_funding_txo_spentness(False)
         self.network.trigger_callback('channel', chan)
 
-    def on_network_update(self, event, *args):
-        """ called from network thread """
+    @aiosafe
+    async def on_network_update(self, event, *args):
+        # TODO
         # Race discovered in save_channel (assertion failing):
         # since short_channel_id could be changed while saving.
-        # Mitigated by posting to loop:
-        async def network_jobs():
-            with self.lock:
-                channels = list(self.channels.values())
-            for chan in channels:
-                if chan.get_state() == "OPENING":
-                    res = self.save_short_chan_id(chan)
-                    if not res:
-                        self.print_error("network update but funding tx is still not at sufficient depth")
-                        continue
-                    # this results in the channel being marked OPEN
-                    peer = self.peers[chan.node_id]
-                    peer.funding_locked(chan)
-                elif chan.get_state() == "OPEN":
-                    peer = self.peers.get(chan.node_id)
-                    if peer is None:
-                        self.print_error("peer not found for {}".format(bh2u(chan.node_id)))
-                        return
-                    if event == 'fee':
-                        peer.on_bitcoin_fee_update(chan)
-                    conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf
-                    peer.on_network_update(chan, conf)
-        asyncio.run_coroutine_threadsafe(network_jobs(), self.network.asyncio_loop).result()
+        with self.lock:
+            channels = list(self.channels.values())
+        for chan in channels:
+            print("update", chan.get_state())
+            if chan.get_state() == "OPENING":
+                res = self.save_short_chan_id(chan)
+                if not res:
+                    self.print_error("network update but funding tx is still not at sufficient depth")
+                    continue
+                # this results in the channel being marked OPEN
+                peer = self.peers[chan.node_id]
+                peer.funding_locked(chan)
+            elif chan.get_state() == "OPEN":
+                peer = self.peers.get(chan.node_id)
+                if peer is None:
+                    self.print_error("peer not found for {}".format(bh2u(chan.node_id)))
+                    return
+                if event == 'fee':
+                    peer.on_bitcoin_fee_update(chan)
+                conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf
+                peer.on_network_update(chan, conf)
 
     async def _open_channel_coroutine(self, node_id, local_amount_sat, push_sat, password):
         peer = self.peers[node_id]
@@ -345,8 +343,8 @@ class LNWorker(PrintError):
                 coro = peer.reestablish_channel(chan)
                 asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
 
-    @aiosafe
     async def main_loop(self):
+        await self.on_network_update('updated') # shortcut (don't block) if funding tx locked and verified
         while True:
             await asyncio.sleep(1)
             now = time.time()