Browse Source

requirements: bump min aiorpcx to 0.22.0

aiorpcx 0.20 changed the behaviour/API of TaskGroups.
When used as a context manager, TaskGroups no longer propagate
exceptions raised by their tasks. Instead, the calling code has
to explicitly check the results of tasks and decide whether to re-raise
any exceptions.
This is a significant change, and so this commit introduces "OldTaskGroup",
which should behave as the TaskGroup class of old aiorpcx. All existing
usages of TaskGroup are replaced with OldTaskGroup.

closes https://github.com/spesmilo/electrum/issues/7446
patch-4
SomberNight 3 years ago
parent
commit
c9c094cfab
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 6
      contrib/deterministic-build/requirements.txt
  2. 2
      contrib/requirements/requirements.txt
  3. 6
      electrum/address_synchronizer.py
  4. 7
      electrum/bip39_recovery.py
  5. 10
      electrum/daemon.py
  6. 6
      electrum/exchange_rate.py
  7. 7
      electrum/interface.py
  8. 10
      electrum/lnpeer.py
  9. 12
      electrum/lnworker.py
  10. 18
      electrum/network.py
  11. 6
      electrum/synchronizer.py
  12. 38
      electrum/tests/test_lnpeer.py
  13. 9
      electrum/tests/test_lntransport.py
  14. 46
      electrum/util.py
  15. 8
      electrum/wallet.py
  16. 4
      run_electrum

6
contrib/deterministic-build/requirements.txt

@ -74,9 +74,9 @@ aiohttp==3.8.1 \
aiohttp-socks==0.7.1 \
--hash=sha256:2215cac4891ef3fa14b7d600ed343ed0f0a670c23b10e4142aa862b3db20341a \
--hash=sha256:94bcff5ef73611c6c6231c2ffc1be4af1599abec90dbd2fdbbd63233ec2fb0ff
aiorpcX==0.18.7 \
--hash=sha256:7fa48423e1c06cd0ffb7b60f2cca7e819b6cbbf57d4bc8a82944994ef5038f05 \
--hash=sha256:808a9ec9172df11677a0f7b459b69d1a6cf8b19c19da55541fa31fb1afce5ce7
aiorpcX==0.22.1 \
--hash=sha256:6026f7bed3432e206589c94dcf599be8cd85b5736b118c7275845c1bd922a553 \
--hash=sha256:e74f9fbed3fd21598e71fe05066618fc2c06feec504fe29490ddda05fdbdde62
aiosignal==1.2.0 \
--hash=sha256:26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a \
--hash=sha256:78ed67db6c7b7ced4f98e495e572106d5c432a93e1ddd1bf475e1dc05f5b7df2

2
contrib/requirements/requirements.txt

@ -1,7 +1,7 @@
qrcode
protobuf>=3.12
qdarkstyle>=2.7
aiorpcx>=0.18.7,<0.19
aiorpcx>=0.22.0,<0.23
aiohttp>=3.3.0,<4.0.0
aiohttp_socks>=0.3
certifi

6
electrum/address_synchronizer.py

@ -28,11 +28,9 @@ import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List
from aiorpcx import TaskGroup
from . import bitcoin, util
from .bitcoin import COINBASE_MATURITY
from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException, with_lock
from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException, with_lock, OldTaskGroup
from .transaction import Transaction, TxOutput, TxInput, PartialTxInput, TxOutpoint, PartialTransaction
from .synchronizer import Synchronizer
from .verifier import SPV
@ -193,7 +191,7 @@ class AddressSynchronizer(Logger):
async def stop(self):
if self.network:
try:
async with TaskGroup() as group:
async with OldTaskGroup() as group:
if self.synchronizer:
await group.spawn(self.synchronizer.stop())
if self.verifier:

7
electrum/bip39_recovery.py

@ -4,20 +4,19 @@
from typing import TYPE_CHECKING
from aiorpcx import TaskGroup
from . import bitcoin
from .constants import BIP39_WALLET_FORMATS
from .bip32 import BIP32_PRIME, BIP32Node
from .bip32 import convert_bip32_path_to_list_of_uint32 as bip32_str_to_ints
from .bip32 import convert_bip32_intpath_to_strpath as bip32_ints_to_str
from .util import OldTaskGroup
if TYPE_CHECKING:
from .network import Network
async def account_discovery(network: 'Network', get_account_xpub):
async with TaskGroup() as group:
async with OldTaskGroup() as group:
account_scan_tasks = []
for wallet_format in BIP39_WALLET_FORMATS:
account_scan = scan_for_active_accounts(network, get_account_xpub, wallet_format)
@ -46,7 +45,7 @@ async def scan_for_active_accounts(network: 'Network', get_account_xpub, wallet_
async def account_has_history(network: 'Network', account_node: BIP32Node, script_type: str) -> bool:
gap_limit = 20
async with TaskGroup() as group:
async with OldTaskGroup() as group:
get_history_tasks = []
for address_index in range(gap_limit):
address_node = account_node.subkey_at_public_derivation("0/" + str(address_index))

10
electrum/daemon.py

@ -36,13 +36,13 @@ import json
import aiohttp
from aiohttp import web, client_exceptions
from aiorpcx import TaskGroup, timeout_after, TaskTimeout, ignore_after
from aiorpcx import timeout_after, TaskTimeout, ignore_after
from . import util
from .network import Network
from .util import (json_decode, to_bytes, to_string, profiler, standardize_path, constant_time_compare)
from .invoices import PR_PAID, PR_EXPIRED
from .util import log_exceptions, ignore_exceptions, randrange
from .util import log_exceptions, ignore_exceptions, randrange, OldTaskGroup
from .wallet import Wallet, Abstract_Wallet
from .storage import WalletStorage
from .wallet_db import WalletDB
@ -493,7 +493,7 @@ class Daemon(Logger):
self._stop_entered = False
self._stopping_soon_or_errored = threading.Event()
self._stopped_event = threading.Event()
self.taskgroup = TaskGroup()
self.taskgroup = OldTaskGroup()
asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs), self.asyncio_loop)
@log_exceptions
@ -591,12 +591,12 @@ class Daemon(Logger):
if self.gui_object:
self.gui_object.stop()
self.logger.info("stopping all wallets")
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for k, wallet in self._wallets.items():
await group.spawn(wallet.stop())
self.logger.info("stopping network and taskgroup")
async with ignore_after(2):
async with TaskGroup() as group:
async with OldTaskGroup() as group:
if self.network:
await group.spawn(self.network.stop(full_shutdown=True))
await group.spawn(self.taskgroup.cancel_remaining())

6
electrum/exchange_rate.py

@ -10,13 +10,13 @@ import decimal
from decimal import Decimal
from typing import Sequence, Optional
from aiorpcx.curio import timeout_after, TaskTimeout, TaskGroup
from aiorpcx.curio import timeout_after, TaskTimeout
import aiohttp
from . import util
from .bitcoin import COIN
from .i18n import _
from .util import (ThreadJob, make_dir, log_exceptions,
from .util import (ThreadJob, make_dir, log_exceptions, OldTaskGroup,
make_aiohttp_session, resource_path)
from .network import Network
from .simple_config import SimpleConfig
@ -449,7 +449,7 @@ def get_exchanges_and_currencies():
async def query_all_exchanges_for_their_ccys_over_network():
async with timeout_after(10):
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for name, klass in exchanges.items():
exchange = klass(None, None)
await group.spawn(get_currencies_safe(name, exchange))

7
electrum/interface.py

@ -38,7 +38,6 @@ import hashlib
import functools
import aiorpcx
from aiorpcx import TaskGroup
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
from aiorpcx.curio import timeout_after, TaskTimeout
from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
@ -47,7 +46,7 @@ import certifi
from .util import (ignore_exceptions, log_exceptions, bfh, MySocksProxy,
is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
is_int_or_float, is_non_negative_int_or_float)
is_int_or_float, is_non_negative_int_or_float, OldTaskGroup)
from . import util
from . import x509
from . import pem
@ -376,7 +375,7 @@ class Interface(Logger):
# Dump network messages (only for this interface). Set at runtime from the console.
self.debug = False
self.taskgroup = TaskGroup()
self.taskgroup = OldTaskGroup()
async def spawn_task():
task = await self.network.taskgroup.spawn(self.run())
@ -675,7 +674,7 @@ class Interface(Logger):
async def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS
while True:
async with TaskGroup() as group:
async with OldTaskGroup() as group:
fee_tasks = []
for i in FEE_ETA_TARGETS:
fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))

10
electrum/lnpeer.py

@ -14,14 +14,14 @@ from datetime import datetime
import functools
import aiorpcx
from aiorpcx import TaskGroup, ignore_after
from aiorpcx import ignore_after
from .crypto import sha256, sha256d
from . import bitcoin, util
from . import ecc
from .ecc import sig_string_from_r_and_s, der_sig_from_sig_string
from . import constants
from .util import (bh2u, bfh, log_exceptions, ignore_exceptions, chunks, TaskGroup,
from .util import (bh2u, bfh, log_exceptions, ignore_exceptions, chunks, OldTaskGroup,
UnrelatedTransactionException)
from . import transaction
from .bitcoin import make_op_return
@ -105,7 +105,7 @@ class Peer(Logger):
self.announcement_signatures = defaultdict(asyncio.Queue)
self.orphan_channel_updates = OrderedDict() # type: OrderedDict[ShortChannelID, dict]
Logger.__init__(self)
self.taskgroup = TaskGroup()
self.taskgroup = OldTaskGroup()
# HTLCs offered by REMOTE, that we started removing but are still active:
self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]]
self.received_htlc_removed_event = asyncio.Event()
@ -1859,7 +1859,7 @@ class Peer(Logger):
# we can get triggered for events that happen on the downstream peer.
# TODO: trampoline forwarding relies on the polling
async with ignore_after(0.1):
async with TaskGroup(wait=any) as group:
async with OldTaskGroup(wait=any) as group:
await group.spawn(self._received_revack_event.wait())
await group.spawn(self.downstream_htlc_resolved_event.wait())
self._htlc_switch_iterstart_event.set()
@ -1943,7 +1943,7 @@ class Peer(Logger):
await self._htlc_switch_iterstart_event.wait()
await self._htlc_switch_iterdone_event.wait()
async with TaskGroup(wait=any) as group:
async with OldTaskGroup(wait=any) as group:
await group.spawn(htlc_switch_iteration())
await group.spawn(self.got_disconnected.wait())

12
electrum/lnworker.py

@ -22,11 +22,11 @@ import urllib.parse
import dns.resolver
import dns.exception
from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after
from aiorpcx import run_in_thread, NetAddress, ignore_after
from . import constants, util
from . import keystore
from .util import profiler, chunks
from .util import profiler, chunks, OldTaskGroup
from .invoices import PR_TYPE_LN, PR_UNPAID, PR_EXPIRED, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, LNInvoice, LN_EXPIRY_NEVER
from .util import NetworkRetryManager, JsonRPCClient
from .lnutil import LN_MAX_FUNDING_SAT
@ -200,7 +200,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
self.backup_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.BACKUP_CIPHER).privkey
self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
self.taskgroup = TaskGroup()
self.taskgroup = OldTaskGroup()
self.listen_server = None # type: Optional[asyncio.AbstractServer]
self.features = features
self.network = None # type: Optional[Network]
@ -767,13 +767,13 @@ class LNWallet(LNWorker):
# to wait a bit for it to become irrevocably removed.
# Note: we don't wait for *all htlcs* to get removed, only for those
# that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in self.peers.values():
await group.spawn(peer.wait_one_htlc_switch_iteration())
while True:
if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
break
async with TaskGroup(wait=any) as group:
async with OldTaskGroup(wait=any) as group:
for peer in self.peers.values():
await group.spawn(peer.received_htlc_removed_event.wait())
@ -2269,7 +2269,7 @@ class LNWallet(LNWorker):
transport = LNTransport(privkey, peer_addr, proxy=self.network.proxy)
peer = Peer(self, node_id, transport, is_channel_backup=True)
try:
async with TaskGroup(wait=any) as group:
async with OldTaskGroup(wait=any) as group:
await group.spawn(peer._message_loop())
await group.spawn(peer.trigger_force_close(channel_id))
return

18
electrum/network.py

@ -40,11 +40,11 @@ import copy
import functools
import aiorpcx
from aiorpcx import TaskGroup, ignore_after
from aiorpcx import ignore_after
from aiohttp import ClientResponse
from . import util
from .util import (log_exceptions, ignore_exceptions,
from .util import (log_exceptions, ignore_exceptions, OldTaskGroup,
bfh, make_aiohttp_session, send_exception_to_crash_reporter,
is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager,
nullcontext)
@ -246,7 +246,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
LOGGING_SHORTCUT = 'n'
taskgroup: Optional[TaskGroup]
taskgroup: Optional[OldTaskGroup]
interface: Optional[Interface]
interfaces: Dict[ServerAddr, Interface]
_connecting_ifaces: Set[ServerAddr]
@ -462,7 +462,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
async def get_relay_fee():
self.relay_fee = await interface.get_relay_fee()
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(get_banner)
await group.spawn(get_donation_address)
await group.spawn(get_server_peers)
@ -839,7 +839,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
assert iface.ready.done(), "interface not ready yet"
# try actual request
try:
async with TaskGroup(wait=any) as group:
async with OldTaskGroup(wait=any) as group:
task = await group.spawn(func(self, *args, **kwargs))
await group.spawn(iface.got_disconnected.wait())
except RequestTimedOut:
@ -1184,7 +1184,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
async def _start(self):
assert not self.taskgroup
self.taskgroup = taskgroup = TaskGroup()
self.taskgroup = taskgroup = OldTaskGroup()
assert not self.interface and not self.interfaces
assert not self._connecting_ifaces
assert not self._closing_ifaces
@ -1225,7 +1225,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
# timeout: if full_shutdown, it is up to the caller to time us out,
# otherwise if e.g. restarting due to proxy changes, we time out fast
async with (nullcontext() if full_shutdown else ignore_after(1)):
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(self.taskgroup.cancel_remaining())
if full_shutdown:
await group.spawn(self.stop_gossip(full_shutdown=full_shutdown))
@ -1278,7 +1278,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
except asyncio.CancelledError:
# suppress spurious cancellations
group = self.taskgroup
if not group or group.closed():
if not group or group.joined:
raise
await asyncio.sleep(0.1)
@ -1352,7 +1352,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
except Exception as e:
res = e
responses[interface.server] = res
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for server in servers:
await group.spawn(get_response(server))
return responses

6
electrum/synchronizer.py

@ -28,11 +28,11 @@ from typing import Dict, List, TYPE_CHECKING, Tuple, Set
from collections import defaultdict
import logging
from aiorpcx import TaskGroup, run_in_thread, RPCError
from aiorpcx import run_in_thread, RPCError
from . import util
from .transaction import Transaction, PartialTransaction
from .util import bh2u, make_aiohttp_session, NetworkJobOnDefaultServer, random_shuffled_copy
from .util import bh2u, make_aiohttp_session, NetworkJobOnDefaultServer, random_shuffled_copy, OldTaskGroup
from .bitcoin import address_to_scripthash, is_address
from .logging import Logger
from .interface import GracefulDisconnect, NetworkTimeout
@ -218,7 +218,7 @@ class Synchronizer(SynchronizerBase):
self.requested_tx[tx_hash] = tx_height
if not transaction_hashes: return
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for tx_hash in transaction_hashes:
await group.spawn(self._get_transaction(tx_hash, allow_server_not_finding_tx=allow_server_not_finding_tx))

38
electrum/tests/test_lnpeer.py

@ -10,7 +10,7 @@ from concurrent import futures
import unittest
from typing import Iterable, NamedTuple, Tuple, List, Dict
from aiorpcx import TaskGroup, timeout_after, TaskTimeout
from aiorpcx import timeout_after, TaskTimeout
import electrum
import electrum.trampoline
@ -21,7 +21,7 @@ from electrum.ecc import ECPrivkey
from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh
from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh, OldTaskGroup
from electrum.lnpeer import Peer, UpfrontShutdownScriptViolation
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
@ -125,7 +125,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
self.taskgroup = TaskGroup()
self.taskgroup = OldTaskGroup()
self.lnwatcher = None
self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans}
@ -365,7 +365,7 @@ class TestPeer(TestCaseForTestnet):
def tearDown(self):
async def cleanup_lnworkers():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for lnworker in self._lnworkers_created:
await group.spawn(lnworker.stop())
self._lnworkers_created.clear()
@ -569,7 +569,7 @@ class TestPeer(TestCaseForTestnet):
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
raise PaymentDone()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
await group.spawn(p2._message_loop())
@ -643,7 +643,7 @@ class TestPeer(TestCaseForTestnet):
raise PaymentDone()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
await group.spawn(p2._message_loop())
@ -667,10 +667,10 @@ class TestPeer(TestCaseForTestnet):
async with max_htlcs_in_flight:
await w1.pay_invoice(pay_req)
async def many_payments():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat))
for i in range(num_payments)]
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for pay_req_task in pay_reqs_tasks:
lnaddr, pay_req = pay_req_task.result()
await group.spawn(single_payment(pay_req))
@ -696,7 +696,7 @@ class TestPeer(TestCaseForTestnet):
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
raise PaymentDone()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -740,7 +740,7 @@ class TestPeer(TestCaseForTestnet):
[edge.short_channel_id for edge in log[0].route])
raise PaymentDone()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -764,7 +764,7 @@ class TestPeer(TestCaseForTestnet):
self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
raise PaymentDone()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -799,7 +799,7 @@ class TestPeer(TestCaseForTestnet):
self.assertEqual(500100000000, graph.channels[('dave', 'bob')].balance(LOCAL))
raise PaymentDone()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -862,7 +862,7 @@ class TestPeer(TestCaseForTestnet):
raise PaymentDone()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -909,7 +909,7 @@ class TestPeer(TestCaseForTestnet):
raise NoPathFound()
async def f(kwargs):
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -948,7 +948,7 @@ class TestPeer(TestCaseForTestnet):
async def f():
await turn_on_trampoline_alice()
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -1026,7 +1026,7 @@ class TestPeer(TestCaseForTestnet):
raise SuccessfulTest()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -1196,7 +1196,7 @@ class TestPeer(TestCaseForTestnet):
raise SuccessfulTest()
async def f():
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for peer in [p1, p2]:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
@ -1223,7 +1223,7 @@ class TestPeer(TestCaseForTestnet):
failing_task = None
async def f():
nonlocal failing_task
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
failing_task = await group.spawn(p2._message_loop())
@ -1252,7 +1252,7 @@ class TestPeer(TestCaseForTestnet):
failing_task = None
async def f():
nonlocal failing_task
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
failing_task = await group.spawn(p2._message_loop())

9
electrum/tests/test_lntransport.py

@ -3,8 +3,7 @@ import asyncio
from electrum.ecc import ECPrivkey
from electrum.lnutil import LNPeerAddr
from electrum.lntransport import LNResponderTransport, LNTransport
from aiorpcx import TaskGroup
from electrum.util import OldTaskGroup
from . import ElectrumTestCase
from .test_bitcoin import needs_test_with_all_chacha20_implementations
@ -73,7 +72,7 @@ class TestLNTransport(ElectrumTestCase):
async def cb(reader, writer):
t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer)
self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes())
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(read_messages(t, messages_sent_by_client))
await group.spawn(write_messages(t, messages_sent_by_server))
responder_shaked.set()
@ -81,7 +80,7 @@ class TestLNTransport(ElectrumTestCase):
peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes())
t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None)
await t.handshake()
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(read_messages(t, messages_sent_by_server))
await group.spawn(write_messages(t, messages_sent_by_client))
server_shaked.set()
@ -89,7 +88,7 @@ class TestLNTransport(ElectrumTestCase):
async def f():
server = await asyncio.start_server(cb, '127.0.0.1', 42898)
try:
async with TaskGroup() as group:
async with OldTaskGroup() as group:
await group.spawn(connect())
await group.spawn(responder_shaked.wait())
await group.spawn(server_shaked.wait())

46
electrum/util.py

@ -52,7 +52,6 @@ import attr
import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType
import aiorpcx
from aiorpcx import TaskGroup
import certifi
import dns.resolver
@ -1226,6 +1225,47 @@ def make_aiohttp_session(proxy: Optional[dict], headers=None, timeout=None):
return aiohttp.ClientSession(headers=headers, timeout=timeout, connector=connector)
class OldTaskGroup(aiorpcx.TaskGroup):
"""Automatically raises exceptions on join; as in aiorpcx prior to version 0.20.
That is, when using TaskGroup as a context manager, if any task encounters an exception,
we would like that exception to be re-raised (propagated out). For the wait=all case,
the OldTaskGroup class is emulating the following code-snippet:
```
async with TaskGroup() as group:
await group.spawn(task1())
await group.spawn(task2())
async for task in group:
if not task.cancelled():
task.result()
```
So instead of the above, one can just write:
```
async with OldTaskGroup() as group:
await group.spawn(task1())
await group.spawn(task2())
```
"""
async def join(self):
if self._wait is all:
exc = False
try:
async for task in self:
if not task.cancelled():
task.result()
except BaseException: # including asyncio.CancelledError
exc = True
raise
finally:
if exc:
await self.cancel_remaining()
await super().join()
else:
await super().join()
if self.completed:
self.completed.result()
class NetworkJobOnDefaultServer(Logger, ABC):
"""An abstract base class for a job that runs on the main network
interface. Every time the main interface changes, the job is
@ -1251,14 +1291,14 @@ class NetworkJobOnDefaultServer(Logger, ABC):
"""Initialise fields. Called every time the underlying
server connection changes.
"""
self.taskgroup = TaskGroup()
self.taskgroup = OldTaskGroup()
async def _start(self, interface: 'Interface'):
self.interface = interface
await interface.taskgroup.spawn(self._run_tasks(taskgroup=self.taskgroup))
@abstractmethod
async def _run_tasks(self, *, taskgroup: TaskGroup) -> None:
async def _run_tasks(self, *, taskgroup: OldTaskGroup) -> None:
"""Start tasks in taskgroup. Called every time the underlying
server connection changes.
"""

8
electrum/wallet.py

@ -46,13 +46,13 @@ import itertools
import threading
import enum
from aiorpcx import TaskGroup, timeout_after, TaskTimeout, ignore_after
from aiorpcx import timeout_after, TaskTimeout, ignore_after
from .i18n import _
from .bip32 import BIP32Node, convert_bip32_intpath_to_strpath, convert_bip32_path_to_list_of_uint32
from .crypto import sha256
from . import util
from .util import (NotEnoughFunds, UserCancelled, profiler,
from .util import (NotEnoughFunds, UserCancelled, profiler, OldTaskGroup,
format_satoshis, format_fee_satoshis, NoDynamicFeeEstimates,
WalletFileException, BitcoinException,
InvalidPassword, format_time, timestamp_to_datetime, Satoshis,
@ -134,7 +134,7 @@ async def _append_utxos_to_inputs(*, inputs: List[PartialTxInput], network: 'Net
inputs.append(txin)
u = await network.listunspent_for_scripthash(scripthash)
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for item in u:
if len(inputs) >= imax:
break
@ -155,7 +155,7 @@ async def sweep_preparations(privkeys, network: 'Network', imax=100):
inputs = [] # type: List[PartialTxInput]
keypairs = {}
async with TaskGroup() as group:
async with OldTaskGroup() as group:
for sec in privkeys:
txin_type, privkey, compressed = bitcoin.deserialize_privkey(sec)
await group.spawn(find_utxos_for_privkey(txin_type, privkey, compressed))

4
run_electrum

@ -63,8 +63,8 @@ def check_imports():
import aiorpcx
except ImportError as e:
sys.exit(f"Error: {str(e)}. Try 'sudo python3 -m pip install <module-name>'")
if not ((0, 18, 7) <= aiorpcx._version < (0, 19)):
raise RuntimeError(f'aiorpcX version {aiorpcx._version} does not match required: 0.18.7<=ver<0.19')
if not ((0, 22, 0) <= aiorpcx._version < (0, 23)):
raise RuntimeError(f'aiorpcX version {aiorpcx._version} does not match required: 0.22.0<=ver<0.23')
# the following imports are for pyinstaller
from google.protobuf import descriptor
from google.protobuf import message

Loading…
Cancel
Save