Browse Source

network: clean-up. make external API clear. rm interface_lock (mostly).

3.3.3.1
SomberNight 6 years ago
parent
commit
952e9b87e1
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/commands.py
  2. 7
      electrum/daemon.py
  3. 15
      electrum/gui/kivy/main_window.py
  4. 5
      electrum/gui/kivy/uix/dialogs/settings.py
  5. 2
      electrum/gui/kivy/uix/ui_screens/proxy.kv
  6. 2
      electrum/gui/kivy/uix/ui_screens/server.kv
  7. 3
      electrum/gui/qt/main_window.py
  8. 16
      electrum/gui/qt/network_dialog.py
  9. 3
      electrum/gui/stdio.py
  10. 6
      electrum/gui/text.py
  11. 44
      electrum/interface.py
  12. 421
      electrum/network.py
  13. 1
      electrum/plugin.py
  14. 11
      electrum/verifier.py

2
electrum/commands.py

@ -255,7 +255,7 @@ class Commands:
def broadcast(self, tx):
"""Broadcast a transaction to the network. """
tx = Transaction(tx)
return self.network.broadcast_transaction_from_non_network_thread(tx)
return self.network.run_from_another_thread(self.network.broadcast_transaction(tx))
@command('')
def createmultisig(self, num, pubkeys):

7
electrum/daemon.py

@ -28,11 +28,11 @@ import os
import time
import traceback
import sys
import threading
# from jsonrpc import JSONRPCResponseManager
import jsonrpclib
from .jsonrpc import VerifyingJSONRPCServer
from .jsonrpc import VerifyingJSONRPCServer
from .version import ELECTRUM_VERSION
from .network import Network
from .util import json_decode, DaemonThread
@ -129,7 +129,7 @@ class Daemon(DaemonThread):
self.network = Network(config)
self.fx = FxThread(config, self.network)
if self.network:
self.network.start(self.fx.run())
self.network.start([self.fx.run])
self.gui = None
self.wallets = {}
# Setup JSONRPC server
@ -308,6 +308,7 @@ class Daemon(DaemonThread):
gui_name = 'qt'
gui = __import__('electrum.gui.' + gui_name, fromlist=['electrum'])
self.gui = gui.ElectrumGui(config, self, plugins)
threading.current_thread().setName('GUI')
try:
self.gui.main()
except BaseException as e:

15
electrum/gui/kivy/main_window.py

@ -16,6 +16,7 @@ from electrum.plugin import run_hook
from electrum.util import format_satoshis, format_satoshis_plain
from electrum.paymentrequest import PR_UNPAID, PR_PAID, PR_UNKNOWN, PR_EXPIRED
from electrum import blockchain
from electrum.network import Network
from .i18n import _
from kivy.app import App
@ -96,7 +97,7 @@ class ElectrumWindow(App):
def on_auto_connect(self, instance, x):
net_params = self.network.get_parameters()
net_params = net_params._replace(auto_connect=self.auto_connect)
self.network.set_parameters(net_params)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
def toggle_auto_connect(self, x):
self.auto_connect = not self.auto_connect
@ -116,9 +117,10 @@ class ElectrumWindow(App):
from .uix.dialogs.choice_dialog import ChoiceDialog
chains = self.network.get_blockchains()
def cb(name):
for index, b in blockchain.blockchains.items():
with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items())
for index, b in blockchain_items:
if name == b.get_name():
self.network.follow_chain(index)
self.network.run_from_another_thread(self.network.follow_chain(index))
names = [blockchain.blockchains[b].get_name() for b in chains]
if len(names) > 1:
cur_chain = self.network.blockchain().get_name()
@ -265,7 +267,7 @@ class ElectrumWindow(App):
title = _('Electrum App')
self.electrum_config = config = kwargs.get('config', None)
self.language = config.get('language', 'en')
self.network = network = kwargs.get('network', None)
self.network = network = kwargs.get('network', None) # type: Network
if self.network:
self.num_blocks = self.network.get_local_height()
self.num_nodes = len(self.network.get_interfaces())
@ -708,7 +710,7 @@ class ElectrumWindow(App):
status = _("Offline")
elif self.network.is_connected():
server_height = self.network.get_server_height()
server_lag = self.network.get_local_height() - server_height
server_lag = self.num_blocks - server_height
if not self.wallet.up_to_date or server_height == 0:
status = _("Synchronizing...")
elif server_lag > 1:
@ -885,7 +887,8 @@ class ElectrumWindow(App):
Clock.schedule_once(lambda dt: on_success(tx))
def _broadcast_thread(self, tx, on_complete):
ok, txid = self.network.broadcast_transaction_from_non_network_thread(tx)
ok, txid = self.network.run_from_another_thread(
self.network.broadcast_transaction(tx))
Clock.schedule_once(lambda dt: on_complete(ok, txid))
def broadcast(self, tx, pr=None):

5
electrum/gui/kivy/uix/dialogs/settings.py

@ -159,8 +159,9 @@ class SettingsDialog(Factory.Popup):
return proxy.get('host') +':' + proxy.get('port') if proxy else _('None')
def proxy_dialog(self, item, dt):
network = self.app.network
if self._proxy_dialog is None:
net_params = self.app.network.get_parameters()
net_params = network.get_parameters()
proxy = net_params.proxy
def callback(popup):
nonlocal net_params
@ -175,7 +176,7 @@ class SettingsDialog(Factory.Popup):
else:
proxy = None
net_params = net_params._replace(proxy=proxy)
self.app.network.set_parameters(net_params)
network.run_from_another_thread(network.set_parameters(net_params))
item.status = self.proxy_status()
popup = Builder.load_file('electrum/gui/kivy/uix/ui_screens/proxy.kv')
popup.ids.mode.text = proxy.get('mode') if proxy else 'None'

2
electrum/gui/kivy/uix/ui_screens/proxy.kv

@ -72,6 +72,6 @@ Popup:
proxy['password']=str(root.ids.password.text)
if proxy['mode']=='none': proxy = None
net_params = net_params._replace(proxy=proxy)
app.network.set_parameters(net_params)
app.network.run_from_another_thread(app.network.set_parameters(net_params))
app.proxy_config = proxy if proxy else {}
nd.dismiss()

2
electrum/gui/kivy/uix/ui_screens/server.kv

@ -58,5 +58,5 @@ Popup:
on_release:
net_params = app.network.get_parameters()
net_params = net_params._replace(host=str(root.ids.host.text), port=str(root.ids.port.text))
app.network.set_parameters(net_params)
app.network.run_from_another_thread(app.network.set_parameters(net_params))
nd.dismiss()

3
electrum/gui/qt/main_window.py

@ -1635,7 +1635,8 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, PrintError):
if pr and pr.has_expired():
self.payment_request = None
return False, _("Payment request has expired")
status, msg = self.network.broadcast_transaction_from_non_network_thread(tx)
status, msg = self.network.run_from_another_thread(
self.network.broadcast_transaction(tx))
if pr and status is True:
self.invoices.set_paid(pr, tx.txid())
self.invoices.save()

16
electrum/gui/qt/network_dialog.py

@ -34,6 +34,7 @@ from electrum.i18n import _
from electrum import constants, blockchain
from electrum.util import print_error
from electrum.interface import serialize_server, deserialize_server
from electrum.network import Network
from .util import *
@ -97,7 +98,7 @@ class NodesListWidget(QTreeWidget):
pt.setX(50)
self.customContextMenuRequested.emit(pt)
def update(self, network):
def update(self, network: Network):
self.clear()
self.addChild = self.addTopLevelItem
chains = network.get_blockchains()
@ -187,7 +188,7 @@ class ServerListWidget(QTreeWidget):
class NetworkChoiceLayout(object):
def __init__(self, network, config, wizard=False):
def __init__(self, network: Network, config, wizard=False):
self.network = network
self.config = config
self.protocol = None
@ -361,7 +362,7 @@ class NetworkChoiceLayout(object):
status = _("Connected to {0} nodes.").format(n) if n else _("Not connected")
self.status_label.setText(status)
chains = self.network.get_blockchains()
if len(chains)>1:
if len(chains) > 1:
chain = self.network.blockchain()
forkpoint = chain.get_forkpoint()
name = chain.get_name()
@ -410,15 +411,14 @@ class NetworkChoiceLayout(object):
self.set_server()
def follow_branch(self, index):
self.network.follow_chain(index)
self.network.run_from_another_thread(self.network.follow_chain(index))
self.update()
def follow_server(self, server):
self.network.switch_to_interface(server)
net_params = self.network.get_parameters()
host, port, protocol = deserialize_server(server)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
self.network.set_parameters(net_params)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
self.update()
def server_changed(self, x):
@ -451,7 +451,7 @@ class NetworkChoiceLayout(object):
net_params = net_params._replace(host=str(self.server_host.text()),
port=str(self.server_port.text()),
auto_connect=self.autoconnect_cb.isChecked())
self.network.set_parameters(net_params)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
def set_proxy(self):
net_params = self.network.get_parameters()
@ -465,7 +465,7 @@ class NetworkChoiceLayout(object):
proxy = None
self.tor_cb.setChecked(False)
net_params = net_params._replace(proxy=proxy)
self.network.set_parameters(net_params)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
def suggest_proxy(self, found_proxy):
self.tor_proxy = found_proxy

3
electrum/gui/stdio.py

@ -200,7 +200,8 @@ class ElectrumGui:
self.wallet.labels[tx.txid()] = self.str_description
print(_("Please wait..."))
status, msg = self.network.broadcast_transaction_from_non_network_thread(tx)
status, msg = self.network.run_from_another_thread(
self.network.broadcast_transaction(tx))
if status:
print(_('Payment sent.'))

6
electrum/gui/text.py

@ -365,7 +365,8 @@ class ElectrumGui:
self.wallet.labels[tx.txid()] = self.str_description
self.show_message(_("Please wait..."), getchar=False)
status, msg = self.network.broadcast_transaction_from_non_network_thread(tx)
status, msg = self.network.run_from_another_thread(
self.network.broadcast_transaction(tx))
if status:
self.show_message(_('Payment sent.'))
@ -410,7 +411,8 @@ class ElectrumGui:
return False
if out.get('server') or out.get('proxy'):
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
self.network.set_parameters(NetworkParameters(host, port, protocol, proxy, auto_connect))
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
def settings_dialog(self):
fee = str(Decimal(self.config.fee_per_kb()) / COIN)

44
electrum/interface.py

@ -107,11 +107,7 @@ class NotificationSession(ClientSession):
class GracefulDisconnect(Exception): pass
class ErrorParsingSSLCert(Exception): pass
class ErrorGettingSSLCertFromServer(Exception): pass
@ -150,8 +146,11 @@ class Interface(PrintError):
self.tip_header = None
self.tip = 0
# TODO combine?
self.fut = asyncio.get_event_loop().create_task(self.run())
# note that an interface dying MUST NOT kill the whole network,
# hence exceptions raised by "run" need to be caught not to kill
# main_taskgroup! the aiosafe decorator does this.
asyncio.run_coroutine_threadsafe(
self.network.main_taskgroup.spawn(self.run()), self.network.asyncio_loop)
self.group = SilentTaskGroup()
def diagnostic_name(self):
@ -239,31 +238,29 @@ class Interface(PrintError):
sslc.check_hostname = 0
return sslc
def handle_graceful_disconnect(func):
def handle_disconnect(func):
async def wrapper_func(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except GracefulDisconnect as e:
self.print_error("disconnecting gracefully. {}".format(e))
self.exception = e
finally:
await self.network.connection_down(self.server)
return wrapper_func
@aiosafe
@handle_graceful_disconnect
@handle_disconnect
async def run(self):
try:
ssl_context = await self._get_ssl_context()
except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e:
self.exception = e
self.print_error('disconnecting due to: {} {}'.format(e, type(e)))
return
try:
await self.open_session(ssl_context, exit_early=False)
except (asyncio.CancelledError, OSError, aiorpcx.socks.SOCKSFailure) as e:
self.print_error('disconnecting due to: {} {}'.format(e, type(e)))
self.exception = e
return
# should never get here (can only exit via exception)
assert False
def mark_ready(self):
if self.ready.cancelled():
@ -352,9 +349,9 @@ class Interface(PrintError):
self.print_error("connection established. version: {}".format(ver))
async with self.group as group:
await group.spawn(self.ping())
await group.spawn(self.run_fetch_blocks())
await group.spawn(self.monitor_connection())
await group.spawn(self.ping)
await group.spawn(self.run_fetch_blocks)
await group.spawn(self.monitor_connection)
# NOTE: group.__aexit__ will be called here; this is needed to notice exceptions in the group!
async def monitor_connection(self):
@ -368,11 +365,8 @@ class Interface(PrintError):
await asyncio.sleep(300)
await self.session.send_request('server.ping')
def close(self):
async def job():
self.fut.cancel()
await self.group.cancel_remaining()
asyncio.run_coroutine_threadsafe(job(), self.network.asyncio_loop)
async def close(self):
await self.group.cancel_remaining()
async def run_fetch_blocks(self):
header_queue = asyncio.Queue()
@ -389,7 +383,7 @@ class Interface(PrintError):
self.mark_ready()
await self._process_header_at_tip()
self.network.trigger_callback('network_updated')
self.network.switch_lagging_interface()
await self.network.switch_lagging_interface()
async def _process_header_at_tip(self):
height, header = self.tip, self.tip_header
@ -517,7 +511,7 @@ class Interface(PrintError):
return 'fork_conflict', height
self.print_error('forkpoint conflicts with existing fork', branch.path())
self._raise_if_fork_conflicts_with_default_server(branch)
self._disconnect_from_interfaces_on_conflicting_blockchain(branch)
await self._disconnect_from_interfaces_on_conflicting_blockchain(branch)
branch.write(b'', 0)
branch.save_header(bad_header)
self.blockchain = branch
@ -543,8 +537,8 @@ class Interface(PrintError):
if chain_to_delete == chain_of_default_server:
raise GracefulDisconnect('refusing to overwrite blockchain of default server')
def _disconnect_from_interfaces_on_conflicting_blockchain(self, chain: Blockchain) -> None:
ifaces = self.network.disconnect_from_interfaces_on_given_blockchain(chain)
async def _disconnect_from_interfaces_on_conflicting_blockchain(self, chain: Blockchain) -> None:
ifaces = await self.network.disconnect_from_interfaces_on_given_blockchain(chain)
if not ifaces: return
servers = [interface.server for interface in ifaces]
self.print_error("forcing disconnect of other interfaces: {}".format(servers))

421
electrum/network.py

@ -32,18 +32,19 @@ import json
import sys
import ipaddress
import asyncio
from typing import NamedTuple, Optional, Sequence
from typing import NamedTuple, Optional, Sequence, List
import traceback
import dns
import dns.resolver
from aiorpcx import TaskGroup
from . import util
from .util import PrintError, print_error, aiosafe, bfh
from .util import PrintError, print_error, aiosafe, bfh, SilentTaskGroup
from .bitcoin import COIN
from . import constants
from . import blockchain
from .blockchain import Blockchain
from .blockchain import Blockchain, HEADER_SIZE
from .interface import Interface, serialize_server, deserialize_server
from .version import PROTOCOL_VERSION
from .simple_config import SimpleConfig
@ -160,14 +161,6 @@ INSTANCE = None
class Network(PrintError):
"""The Network class manages a set of connections to remote electrum
servers, each connected socket is handled by an Interface() object.
Connections are initiated by a Connection() thread which stops once
the connection succeeds or fails.
Our external API:
- Member functions get_header(), get_interfaces(), get_local_height(),
get_parameters(), get_server_height(), get_status_value(),
is_connected(), set_parameters(), stop()
"""
verbosity_filter = 'n'
@ -195,14 +188,18 @@ class Network(PrintError):
if not self.default_server:
self.default_server = pick_random_server()
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
self.main_taskgroup = None
self._jobs = []
# locks
self.restart_lock = asyncio.Lock()
self.bhi_lock = asyncio.Lock()
self.interface_lock = threading.RLock() # <- re-entrant
self.callback_lock = threading.Lock()
self.recent_servers_lock = threading.RLock() # <- re-entrant
self.interfaces_lock = threading.Lock() # for mutating/iterating self.interfaces
self.server_peers = {} # returned by interface (servers that the main interface knows about)
self.recent_servers = self.read_recent_servers() # note: needs self.recent_servers_lock
self.recent_servers = self._read_recent_servers() # note: needs self.recent_servers_lock
self.banner = ''
self.donation_address = ''
@ -219,26 +216,30 @@ class Network(PrintError):
# kick off the network. interface is the main server we are currently
# communicating with. interfaces is the set of servers we are connecting
# to or have an ongoing connection with
self.interface = None # note: needs self.interface_lock
self.interfaces = {} # note: needs self.interface_lock
self.interface = None
self.interfaces = {}
self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set()
self.server_queue = None
self.server_queue_group = None
self.proxy = None
self.asyncio_loop = asyncio.get_event_loop()
self.start_network(deserialize_server(self.default_server)[2],
deserialize_proxy(self.config.get('proxy')))
#self.asyncio_loop.set_debug(1)
self._run_forever = asyncio.Future()
self._thread = threading.Thread(target=self.asyncio_loop.run_until_complete,
args=(self._run_forever,),
name='Network')
self._thread.start()
def run_from_another_thread(self, coro):
assert self._thread != threading.current_thread(), 'must not be called from network thread'
fut = asyncio.run_coroutine_threadsafe(coro, self.asyncio_loop)
return fut.result()
@staticmethod
def get_instance():
return INSTANCE
def with_interface_lock(func):
def func_wrapper(self, *args, **kwargs):
with self.interface_lock:
return func(self, *args, **kwargs)
return func_wrapper
def with_recent_servers_lock(func):
def func_wrapper(self, *args, **kwargs):
with self.recent_servers_lock:
@ -266,7 +267,7 @@ class Network(PrintError):
else:
self.asyncio_loop.call_soon_threadsafe(callback, event, *args)
def read_recent_servers(self):
def _read_recent_servers(self):
if not self.config.path:
return []
path = os.path.join(self.config.path, "recent_servers")
@ -278,7 +279,7 @@ class Network(PrintError):
return []
@with_recent_servers_lock
def save_recent_servers(self):
def _save_recent_servers(self):
if not self.config.path:
return
path = os.path.join(self.config.path, "recent_servers")
@ -289,11 +290,11 @@ class Network(PrintError):
except:
pass
@with_interface_lock
def get_server_height(self):
return self.interface.tip if self.interface else 0
interface = self.interface
return interface.tip if interface else 0
def server_is_lagging(self):
async def _server_is_lagging(self):
sh = self.get_server_height()
if not sh:
self.print_error('no height for main interface')
@ -304,7 +305,7 @@ class Network(PrintError):
self.print_error('%s is lagging (%d vs %d)' % (self.default_server, sh, lh))
return result
def set_status(self, status):
def _set_status(self, status):
self.connection_status = status
self.notify('status')
@ -315,7 +316,7 @@ class Network(PrintError):
def is_connecting(self):
return self.connection_status == 'connecting'
async def request_server_info(self, interface):
async def _request_server_info(self, interface):
await interface.ready
session = interface.session
@ -340,9 +341,9 @@ class Network(PrintError):
await group.spawn(get_donation_address)
await group.spawn(get_server_peers)
await group.spawn(get_relay_fee)
await group.spawn(self.request_fee_estimates(interface))
await group.spawn(self._request_fee_estimates(interface))
async def request_fee_estimates(self, interface):
async def _request_fee_estimates(self, interface):
session = interface.session
from .simple_config import FEE_ETA_TARGETS
self.config.requested_fee_estimates()
@ -389,10 +390,10 @@ class Network(PrintError):
if self.is_connected():
return self.donation_address
@with_interface_lock
def get_interfaces(self):
'''The interfaces that are in connected state'''
return list(self.interfaces.keys())
def get_interfaces(self) -> List[str]:
"""The list of servers for the connected interfaces."""
with self.interfaces_lock:
return list(self.interfaces)
@with_recent_servers_lock
def get_servers(self):
@ -407,31 +408,31 @@ class Network(PrintError):
if host not in out:
out[host] = {protocol: port}
# add servers received from main interface
if self.server_peers:
out.update(filter_version(self.server_peers.copy()))
server_peers = self.server_peers
if server_peers:
out.update(filter_version(server_peers.copy()))
# potentially filter out some
if self.config.get('noonion'):
out = filter_noonion(out)
return out
@with_interface_lock
def start_interface(self, server):
def _start_interface(self, server):
if server not in self.interfaces and server not in self.connecting:
if server == self.default_server:
self.print_error("connecting to %s as new interface" % server)
self.set_status('connecting')
self._set_status('connecting')
self.connecting.add(server)
self.server_queue.put(server)
def start_random_interface(self):
with self.interface_lock:
def _start_random_interface(self):
with self.interfaces_lock:
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
if server:
self.start_interface(server)
self._start_interface(server)
return server
def set_proxy(self, proxy: Optional[dict]):
def _set_proxy(self, proxy: Optional[dict]):
self.proxy = proxy
# Store these somewhere so we can un-monkey-patch
if not hasattr(socket, "_getaddrinfo"):
@ -467,10 +468,10 @@ class Network(PrintError):
addr = str(answers[0])
else:
addr = host
except dns.exception.DNSException:
except dns.exception.DNSException as e:
# dns failed for some reason, e.g. dns.resolver.NXDOMAIN
# this is normal. Simply report back failure:
raise socket.gaierror(11001, 'getaddrinfo failed')
raise socket.gaierror(11001, 'getaddrinfo failed') from e
except BaseException as e:
# Possibly internal error in dnspython :( see #4483
# Fall back to original socket.getaddrinfo to resolve dns.
@ -478,48 +479,8 @@ class Network(PrintError):
addr = host
return socket._getaddrinfo(addr, *args, **kwargs)
@with_interface_lock
def start_network(self, protocol: str, proxy: Optional[dict]):
assert not self.interface and not self.interfaces
assert not self.connecting and not self.server_queue
assert not self.server_queue_group
self.print_error('starting network')
self.disconnected_servers = set([]) # note: needs self.interface_lock
self.protocol = protocol
self._init_server_queue()
self.set_proxy(proxy)
self.start_interface(self.default_server)
self.trigger_callback('network_updated')
def _init_server_queue(self):
self.server_queue = queue.Queue()
self.server_queue_group = server_queue_group = TaskGroup()
async def job():
forever = asyncio.Event()
async with server_queue_group as group:
await group.spawn(forever.wait())
asyncio.run_coroutine_threadsafe(job(), self.asyncio_loop)
@with_interface_lock
def stop_network(self):
self.print_error("stopping network")
for interface in list(self.interfaces.values()):
self.close_interface(interface)
if self.interface:
self.close_interface(self.interface)
assert self.interface is None
assert not self.interfaces
self.connecting.clear()
self._stop_server_queue()
self.trigger_callback('network_updated')
def _stop_server_queue(self):
# Get a new queue - no old pending connections thanks!
self.server_queue = None
asyncio.run_coroutine_threadsafe(self.server_queue_group.cancel_remaining(), self.asyncio_loop)
self.server_queue_group = None
def set_parameters(self, net_params: NetworkParameters):
@aiosafe
async def set_parameters(self, net_params: NetworkParameters):
proxy = net_params.proxy
proxy_str = serialize_proxy(proxy)
host, port, protocol = net_params.host, net_params.port, net_params.protocol
@ -538,30 +499,30 @@ class Network(PrintError):
# abort if changes were not allowed by config
if self.config.get('server') != server_str or self.config.get('proxy') != proxy_str:
return
self.auto_connect = net_params.auto_connect
if self.proxy != proxy or self.protocol != protocol:
# Restart the network defaulting to the given server
with self.interface_lock:
self.stop_network()
async with self.restart_lock:
self.auto_connect = net_params.auto_connect
if self.proxy != proxy or self.protocol != protocol:
# Restart the network defaulting to the given server
await self._stop()
self.default_server = server_str
self.start_network(protocol, proxy)
elif self.default_server != server_str:
self.switch_to_interface(server_str)
else:
self.switch_lagging_interface()
await self._start()
elif self.default_server != server_str:
await self.switch_to_interface(server_str)
else:
await self.switch_lagging_interface()
def switch_to_random_interface(self):
async def _switch_to_random_interface(self):
'''Switch to a random connected server other than the current one'''
servers = self.get_interfaces() # Those in connected state
if self.default_server in servers:
servers.remove(self.default_server)
if servers:
self.switch_to_interface(random.choice(servers))
await self.switch_to_interface(random.choice(servers))
@with_interface_lock
def switch_lagging_interface(self):
async def switch_lagging_interface(self):
'''If auto_connect and lagging, switch interface'''
if self.server_is_lagging() and self.auto_connect:
if await self._server_is_lagging() and self.auto_connect:
# switch to one that has the correct header (not height)
header = self.blockchain().read_header(self.get_local_height())
def filt(x):
@ -569,111 +530,105 @@ class Network(PrintError):
b = header
assert type(a) is type(b)
return a == b
filtered = list(map(lambda x: x[0], filter(filt, self.interfaces.items())))
with self.interfaces_lock: interfaces_items = list(self.interfaces.items())
filtered = list(map(lambda x: x[0], filter(filt, interfaces_items)))
if filtered:
choice = random.choice(filtered)
self.switch_to_interface(choice)
@with_interface_lock
def switch_to_interface(self, server):
'''Switch to server as our interface. If no connection exists nor
being opened, start a thread to connect. The actual switch will
happen on receipt of the connection notification. Do nothing
if server already is our interface.'''
await self.switch_to_interface(choice)
async def switch_to_interface(self, server: str):
"""Switch to server as our main interface. If no connection exists,
queue interface to be started. The actual switch will
happen when the interface becomes ready.
"""
self.default_server = server
old_interface = self.interface
old_server = old_interface.server if old_interface else None
# Stop any current interface in order to terminate subscriptions,
# and to cancel tasks in interface.group.
# However, for headers sub, give preference to this interface
# over unknown ones, i.e. start it again right away.
if old_server and old_server != server:
await self._close_interface(old_interface)
if len(self.interfaces) <= self.num_server:
self._start_interface(old_server)
if server not in self.interfaces:
self.interface = None
self.start_interface(server)
self._start_interface(server)
return
i = self.interfaces[server]
if self.interface != i:
if old_interface != i:
self.print_error("switching to", server)
blockchain_updated = False
if self.interface is not None:
blockchain_updated = i.blockchain != self.interface.blockchain
# Stop any current interface in order to terminate subscriptions,
# and to cancel tasks in interface.group.
# However, for headers sub, give preference to this interface
# over unknown ones, i.e. start it again right away.
old_server = self.interface.server
self.close_interface(self.interface)
if old_server != server and len(self.interfaces) <= self.num_server:
self.start_interface(old_server)
blockchain_updated = i.blockchain != self.blockchain()
self.interface = i
asyncio.run_coroutine_threadsafe(
i.group.spawn(self.request_server_info(i)), self.asyncio_loop)
await i.group.spawn(self._request_server_info(i))
self.trigger_callback('default_server_changed')
self.set_status('connected')
self._set_status('connected')
self.trigger_callback('network_updated')
if blockchain_updated: self.trigger_callback('blockchain_updated')
@with_interface_lock
def close_interface(self, interface):
async def _close_interface(self, interface):
if interface:
if interface.server in self.interfaces:
self.interfaces.pop(interface.server)
with self.interfaces_lock:
if self.interfaces.get(interface.server) == interface:
self.interfaces.pop(interface.server)
if interface.server == self.default_server:
self.interface = None
interface.close()
await interface.close()
@with_recent_servers_lock
def add_recent_server(self, server):
def _add_recent_server(self, server):
# list is ordered
if server in self.recent_servers:
self.recent_servers.remove(server)
self.recent_servers.insert(0, server)
self.recent_servers = self.recent_servers[0:20]
self.save_recent_servers()
self._save_recent_servers()
@with_interface_lock
def connection_down(self, server):
async def connection_down(self, server):
'''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.'''
self.disconnected_servers.add(server)
if server == self.default_server:
self.set_status('disconnected')
if server in self.interfaces:
self.close_interface(self.interfaces[server])
self._set_status('disconnected')
interface = self.interfaces.get(server, None)
if interface:
await self._close_interface(interface)
self.trigger_callback('network_updated')
@aiosafe
async def new_interface(self, server):
async def _run_new_interface(self, server):
interface = Interface(self, server, self.config.path, self.proxy)
timeout = 10 if not self.proxy else 20
try:
await asyncio.wait_for(interface.ready, timeout)
except BaseException as e:
#import traceback
#traceback.print_exc()
self.print_error(server, "couldn't launch because", str(e), str(type(e)))
# note: connection_down will not call interface.close() as
# interface is not yet in self.interfaces. OTOH, calling
# interface.close() here will sometimes raise deep inside the
# asyncio internal select.select... instead, interface will close
# itself when it detects the cancellation of interface.ready;
# however this might take several seconds...
self.connection_down(server)
await interface.close()
return
else:
with self.interface_lock:
with self.interfaces_lock:
assert server not in self.interfaces
self.interfaces[server] = interface
finally:
with self.interface_lock:
try: self.connecting.remove(server)
except KeyError: pass
try: self.connecting.remove(server)
except KeyError: pass
if server == self.default_server:
self.switch_to_interface(server)
await self.switch_to_interface(server)
self.add_recent_server(server)
self._add_recent_server(server)
self.trigger_callback('network_updated')
def init_headers_file(self):
async def _init_headers_file(self):
b = blockchain.blockchains[0]
filename = b.path()
length = 80 * len(constants.net.CHECKPOINTS) * 2016
length = HEADER_SIZE * len(constants.net.CHECKPOINTS) * 2016
if not os.path.exists(filename) or os.path.getsize(filename) < length:
with open(filename, 'wb') as f:
if length > 0:
@ -686,11 +641,6 @@ class Network(PrintError):
async def get_merkle_for_transaction(self, tx_hash, tx_height):
return await self.interface.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
def broadcast_transaction_from_non_network_thread(self, tx, timeout=10):
# note: calling this from the network thread will deadlock it
fut = asyncio.run_coroutine_threadsafe(self.broadcast_transaction(tx, timeout=timeout), self.asyncio_loop)
return fut.result()
async def broadcast_transaction(self, tx, timeout=10):
try:
out = await self.interface.session.send_request('blockchain.transaction.broadcast', [str(tx)], timeout=timeout)
@ -706,101 +656,124 @@ class Network(PrintError):
async def request_chunk(self, height, tip=None, *, can_return_early=False):
return await self.interface.request_chunk(height, tip=tip, can_return_early=can_return_early)
@with_interface_lock
def blockchain(self):
if self.interface and self.interface.blockchain is not None:
self.blockchain_index = self.interface.blockchain.forkpoint
interface = self.interface
if interface and interface.blockchain is not None:
self.blockchain_index = interface.blockchain.forkpoint
return blockchain.blockchains[self.blockchain_index]
@with_interface_lock
def get_blockchains(self):
out = {} # blockchain_id -> list(interfaces)
with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items())
with self.interfaces_lock: interfaces_values = list(self.interfaces.values())
for chain_id, bc in blockchain_items:
r = list(filter(lambda i: i.blockchain==bc, list(self.interfaces.values())))
r = list(filter(lambda i: i.blockchain==bc, interfaces_values))
if r:
out[chain_id] = r
return out
@with_interface_lock
def disconnect_from_interfaces_on_given_blockchain(self, chain: Blockchain) -> Sequence[Interface]:
async def disconnect_from_interfaces_on_given_blockchain(self, chain: Blockchain) -> Sequence[Interface]:
chain_id = chain.forkpoint
ifaces = self.get_blockchains().get(chain_id) or []
for interface in ifaces:
self.connection_down(interface.server)
await self.connection_down(interface.server)
return ifaces
def follow_chain(self, index):
bc = blockchain.blockchains.get(index)
async def follow_chain(self, chain_id):
bc = blockchain.blockchains.get(chain_id)
if bc:
self.blockchain_index = index
self.config.set_key('blockchain_index', index)
with self.interface_lock:
interfaces = list(self.interfaces.values())
for i in interfaces:
if i.blockchain == bc:
self.switch_to_interface(i.server)
self.blockchain_index = chain_id
self.config.set_key('blockchain_index', chain_id)
with self.interfaces_lock: interfaces_values = list(self.interfaces.values())
for iface in interfaces_values:
if iface.blockchain == bc:
await self.switch_to_interface(iface.server)
break
else:
raise Exception('blockchain not found', index)
raise Exception('blockchain not found', chain_id)
with self.interface_lock:
if self.interface:
net_params = self.get_parameters()
host, port, protocol = deserialize_server(self.interface.server)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
self.set_parameters(net_params)
if self.interface:
net_params = self.get_parameters()
host, port, protocol = deserialize_server(self.interface.server)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
await self.set_parameters(net_params)
def get_local_height(self):
return self.blockchain().height()
def export_checkpoints(self, path):
# run manually from the console to generate checkpoints
"""Run manually to generate blockchain checkpoints.
Kept for console use only.
"""
cp = self.blockchain().get_checkpoints()
with open(path, 'w', encoding='utf-8') as f:
f.write(json.dumps(cp, indent=4))
def start(self, fx=None):
self.main_taskgroup = TaskGroup()
async def _start(self, jobs=None):
if jobs is None: jobs = self._jobs
self._jobs = jobs
assert not self.main_taskgroup
self.main_taskgroup = SilentTaskGroup()
async def main():
self.init_headers_file()
async with self.main_taskgroup as group:
await group.spawn(self.maintain_sessions())
if fx: await group.spawn(fx)
self._wrapper_thread = threading.Thread(target=self.asyncio_loop.run_until_complete, args=(main(),))
self._wrapper_thread.start()
try:
await self._init_headers_file()
async with self.main_taskgroup as group:
await group.spawn(self._maintain_sessions())
[await group.spawn(job) for job in jobs]
except Exception as e:
traceback.print_exc(file=sys.stderr)
raise e
asyncio.run_coroutine_threadsafe(main(), self.asyncio_loop)
assert not self.interface and not self.interfaces
assert not self.connecting and not self.server_queue
self.print_error('starting network')
self.disconnected_servers = set([])
self.protocol = deserialize_server(self.default_server)[2]
self.server_queue = queue.Queue()
self._set_proxy(deserialize_proxy(self.config.get('proxy')))
self._start_interface(self.default_server)
self.trigger_callback('network_updated')
def start(self, jobs=None):
asyncio.run_coroutine_threadsafe(self._start(jobs=jobs), self.asyncio_loop)
async def _stop(self, full_shutdown=False):
self.print_error("stopping network")
try:
asyncio.wait_for(await self.main_taskgroup.cancel_remaining(), timeout=2)
except asyncio.TimeoutError: pass
self.main_taskgroup = None
assert self.interface is None
assert not self.interfaces
self.connecting.clear()
self.server_queue = None
self.trigger_callback('network_updated')
if full_shutdown:
self._run_forever.set_result(1)
def stop(self):
asyncio.run_coroutine_threadsafe(self.main_taskgroup.cancel_remaining(), self.asyncio_loop)
assert self._thread != threading.current_thread(), 'must not be called from network thread'
fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop)
fut.result()
def join(self):
self._wrapper_thread.join(1)
self._thread.join(1)
async def maintain_sessions(self):
async def _maintain_sessions(self):
while True:
# launch already queued up new interfaces
while self.server_queue.qsize() > 0:
server = self.server_queue.get()
await self.server_queue_group.spawn(self.new_interface(server))
remove = []
for k, i in self.interfaces.items():
if i.fut.done() and not i.exception:
assert False, "interface future should not finish without exception"
if i.exception:
if not i.fut.done():
try: i.fut.cancel()
except Exception as e: self.print_error('exception while cancelling fut', e)
try:
raise i.exception
except BaseException as e:
self.print_error(i.server, "errored because:", str(e), str(type(e)))
remove.append(k)
for k in remove:
self.connection_down(k)
# nodes
await self.main_taskgroup.spawn(self._run_new_interface(server))
# maybe queue new interfaces to be launched later
now = time.time()
for i in range(self.num_server - len(self.interfaces) - len(self.connecting)):
self.start_random_interface()
self._start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections')
self.disconnected_servers = set([])
@ -810,16 +783,16 @@ class Network(PrintError):
if not self.is_connected():
if self.auto_connect:
if not self.is_connecting():
self.switch_to_random_interface()
await self._switch_to_random_interface()
else:
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
self.disconnected_servers.remove(self.default_server)
self.server_retry_time = now
else:
self.switch_to_interface(self.default_server)
await self.switch_to_interface(self.default_server)
else:
if self.config.is_fee_estimates_update_required():
await self.interface.group.spawn(self.request_fee_estimates, self.interface)
await self.interface.group.spawn(self._request_fee_estimates, self.interface)
await asyncio.sleep(0.1)

1
electrum/plugin.py

@ -47,6 +47,7 @@ class Plugins(DaemonThread):
@profiler
def __init__(self, config, is_local, gui_name):
DaemonThread.__init__(self)
self.setName('Plugins')
self.pkgpath = os.path.dirname(plugins.__file__)
self.config = config
self.hw_wallets = {}

11
electrum/verifier.py

@ -47,7 +47,6 @@ class SPV(PrintError):
def __init__(self, network, wallet):
self.wallet = wallet
self.network = network
self.blockchain = network.blockchain()
self.merkle_roots = {} # txid -> merkle root (once it has been verified)
self.requested_merkle = set() # txid set of pending requests
@ -55,18 +54,14 @@ class SPV(PrintError):
return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name())
async def main(self, group: TaskGroup):
self.blockchain = self.network.blockchain()
while True:
await self._maybe_undo_verifications()
await self._request_proofs(group)
await asyncio.sleep(0.1)
async def _request_proofs(self, group: TaskGroup):
blockchain = self.network.blockchain()
if not blockchain:
self.print_error("no blockchain")
return
local_height = self.network.get_local_height()
local_height = self.blockchain.height()
unverified = self.wallet.get_unverified_txs()
for tx_hash, tx_height in unverified.items():
@ -77,7 +72,7 @@ class SPV(PrintError):
if tx_height <= 0 or tx_height > local_height:
continue
# if it's in the checkpoint region, we still might not have the header
header = blockchain.read_header(tx_height)
header = self.blockchain.read_header(tx_height)
if header is None:
if tx_height < constants.net.max_checkpoint():
await group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True))

Loading…
Cancel
Save