Browse Source

network: replace "server" strings with ServerAddr objects

patch-3
SomberNight 5 years ago
parent
commit
cf1f2ba4dc
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/daemon.py
  2. 2
      electrum/exchange_rate.py
  3. 40
      electrum/gui/qt/network_dialog.py
  4. 24
      electrum/gui/text.py
  5. 83
      electrum/interface.py
  6. 124
      electrum/network.py

2
electrum/daemon.py

@ -270,6 +270,8 @@ class AuthenticationCredentialsInvalid(AuthenticationError):
class Daemon(Logger): class Daemon(Logger):
network: Optional[Network]
@profiler @profiler
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True): def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
Logger.__init__(self) Logger.__init__(self)

2
electrum/exchange_rate.py

@ -453,7 +453,7 @@ def get_exchanges_by_ccy(history=True):
class FxThread(ThreadJob): class FxThread(ThreadJob):
def __init__(self, config: SimpleConfig, network: Network): def __init__(self, config: SimpleConfig, network: Optional[Network]):
ThreadJob.__init__(self) ThreadJob.__init__(self)
self.config = config self.config = config
self.network = network self.network = network

40
electrum/gui/qt/network_dialog.py

@ -36,7 +36,7 @@ from PyQt5.QtGui import QFontMetrics
from electrum.i18n import _ from electrum.i18n import _
from electrum import constants, blockchain, util from electrum import constants, blockchain, util
from electrum.interface import serialize_server, deserialize_server from electrum.interface import ServerAddr
from electrum.network import Network from electrum.network import Network
from electrum.logging import get_logger from electrum.logging import get_logger
@ -72,10 +72,13 @@ class NetworkDialog(QDialog):
class NodesListWidget(QTreeWidget): class NodesListWidget(QTreeWidget):
SERVER_ADDR_ROLE = Qt.UserRole + 100
CHAIN_ID_ROLE = Qt.UserRole + 101
IS_SERVER_ROLE = Qt.UserRole + 102
def __init__(self, parent): def __init__(self, parent):
QTreeWidget.__init__(self) QTreeWidget.__init__(self)
self.parent = parent self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Connected node'), _('Height')]) self.setHeaderLabels([_('Connected node'), _('Height')])
self.setContextMenuPolicy(Qt.CustomContextMenu) self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu) self.customContextMenuRequested.connect(self.create_menu)
@ -84,13 +87,13 @@ class NodesListWidget(QTreeWidget):
item = self.currentItem() item = self.currentItem()
if not item: if not item:
return return
is_server = not bool(item.data(0, Qt.UserRole)) is_server = bool(item.data(0, self.IS_SERVER_ROLE))
menu = QMenu() menu = QMenu()
if is_server: if is_server:
server = item.data(1, Qt.UserRole) server = item.data(0, self.SERVER_ADDR_ROLE) # type: ServerAddr
menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server)) menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server))
else: else:
chain_id = item.data(1, Qt.UserRole) chain_id = item.data(0, self.CHAIN_ID_ROLE)
menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id)) menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id))
menu.exec_(self.viewport().mapToGlobal(position)) menu.exec_(self.viewport().mapToGlobal(position))
@ -117,15 +120,15 @@ class NodesListWidget(QTreeWidget):
name = b.get_name() name = b.get_name()
if n_chains > 1: if n_chains > 1:
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()]) x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
x.setData(0, Qt.UserRole, 1) x.setData(0, self.IS_SERVER_ROLE, 0)
x.setData(1, Qt.UserRole, b.get_id()) x.setData(0, self.CHAIN_ID_ROLE, b.get_id())
else: else:
x = self x = self
for i in interfaces: for i in interfaces:
star = ' *' if i == network.interface else '' star = ' *' if i == network.interface else ''
item = QTreeWidgetItem([i.host + star, '%d'%i.tip]) item = QTreeWidgetItem([i.host + star, '%d'%i.tip])
item.setData(0, Qt.UserRole, 0) item.setData(0, self.IS_SERVER_ROLE, 1)
item.setData(1, Qt.UserRole, i.server) item.setData(0, self.SERVER_ADDR_ROLE, i.server)
x.addChild(item) x.addChild(item)
if n_chains > 1: if n_chains > 1:
self.addTopLevelItem(x) self.addTopLevelItem(x)
@ -144,11 +147,11 @@ class ServerListWidget(QTreeWidget):
HOST = 0 HOST = 0
PORT = 1 PORT = 1
SERVER_STR_ROLE = Qt.UserRole + 100 SERVER_ADDR_ROLE = Qt.UserRole + 100
def __init__(self, parent): def __init__(self, parent):
QTreeWidget.__init__(self) QTreeWidget.__init__(self)
self.parent = parent self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Host'), _('Port')]) self.setHeaderLabels([_('Host'), _('Port')])
self.setContextMenuPolicy(Qt.CustomContextMenu) self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu) self.customContextMenuRequested.connect(self.create_menu)
@ -158,14 +161,13 @@ class ServerListWidget(QTreeWidget):
if not item: if not item:
return return
menu = QMenu() menu = QMenu()
server = item.data(self.Columns.HOST, self.SERVER_STR_ROLE) server = item.data(self.Columns.HOST, self.SERVER_ADDR_ROLE)
menu.addAction(_("Use as server"), lambda: self.set_server(server)) menu.addAction(_("Use as server"), lambda: self.set_server(server))
menu.exec_(self.viewport().mapToGlobal(position)) menu.exec_(self.viewport().mapToGlobal(position))
def set_server(self, s): def set_server(self, server: ServerAddr):
host, port, protocol = deserialize_server(s) self.parent.server_host.setText(server.host)
self.parent.server_host.setText(host) self.parent.server_port.setText(str(server.port))
self.parent.server_port.setText(port)
self.parent.set_server() self.parent.set_server()
def keyPressEvent(self, event): def keyPressEvent(self, event):
@ -188,8 +190,8 @@ class ServerListWidget(QTreeWidget):
port = d.get(protocol) port = d.get(protocol)
if port: if port:
x = QTreeWidgetItem([_host, port]) x = QTreeWidgetItem([_host, port])
server = serialize_server(_host, port, protocol) server = ServerAddr(_host, port, protocol=protocol)
x.setData(self.Columns.HOST, self.SERVER_STR_ROLE, server) x.setData(self.Columns.HOST, self.SERVER_ADDR_ROLE, server)
self.addTopLevelItem(x) self.addTopLevelItem(x)
h = self.header() h = self.header()
@ -431,7 +433,7 @@ class NetworkChoiceLayout(object):
self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id)) self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id))
self.update() self.update()
def follow_server(self, server): def follow_server(self, server: ServerAddr):
self.network.run_from_another_thread(self.network.follow_chain_given_server(server)) self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
self.update() self.update()

24
electrum/gui/text.py

@ -6,6 +6,7 @@ import locale
from decimal import Decimal from decimal import Decimal
import getpass import getpass
import logging import logging
from typing import TYPE_CHECKING
import electrum import electrum
from electrum import util from electrum import util
@ -15,15 +16,21 @@ from electrum.transaction import PartialTxOutput
from electrum.wallet import Wallet from electrum.wallet import Wallet
from electrum.storage import WalletStorage from electrum.storage import WalletStorage
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
from electrum.interface import deserialize_server from electrum.interface import ServerAddr
from electrum.logging import console_stderr_handler from electrum.logging import console_stderr_handler
if TYPE_CHECKING:
from electrum.daemon import Daemon
from electrum.simple_config import SimpleConfig
from electrum.plugin import Plugins
_ = lambda x:x # i18n _ = lambda x:x # i18n
class ElectrumGui: class ElectrumGui:
def __init__(self, config, daemon, plugins): def __init__(self, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'):
self.config = config self.config = config
self.network = daemon.network self.network = daemon.network
@ -404,21 +411,24 @@ class ElectrumGui:
net_params = self.network.get_parameters() net_params = self.network.get_parameters()
host, port, protocol = net_params.host, net_params.port, net_params.protocol host, port, protocol = net_params.host, net_params.port, net_params.protocol
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
srv = 'auto-connect' if auto_connect else self.network.default_server srv = 'auto-connect' if auto_connect else str(self.network.default_server)
out = self.run_dialog('Network', [ out = self.run_dialog('Network', [
{'label':'server', 'type':'str', 'value':srv}, {'label':'server', 'type':'str', 'value':srv},
{'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')}, {'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')},
], buttons = 1) ], buttons = 1)
if out: if out:
if out.get('server'): if out.get('server'):
server = out.get('server') server_str = out.get('server')
auto_connect = server == 'auto-connect' auto_connect = server_str == 'auto-connect'
if not auto_connect: if not auto_connect:
try: try:
host, port, protocol = deserialize_server(server) server_addr = ServerAddr.from_str(server_str)
except Exception: except Exception:
self.show_message("Error:" + server + "\nIn doubt, type \"auto-connect\"") self.show_message("Error:" + server_str + "\nIn doubt, type \"auto-connect\"")
return False return False
host = server_addr.host
port = str(server_addr.port)
protocol = server_addr.protocol
if out.get('server') or out.get('proxy'): if out.get('server') or out.get('proxy'):
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect) net_params = NetworkParameters(host, port, protocol, proxy, auto_connect)

83
electrum/interface.py

@ -29,7 +29,7 @@ import sys
import traceback import traceback
import asyncio import asyncio
import socket import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
from collections import defaultdict from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
import itertools import itertools
@ -198,22 +198,57 @@ class _RSClient(RSClient):
raise ConnectError(e) from e raise ConnectError(e) from e
def deserialize_server(server_str: str) -> Tuple[str, str, str]: class ServerAddr:
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(server_str).rsplit(':', 2)
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
if protocol not in ('s', 't'):
raise ValueError('invalid network protocol: {}'.format(protocol))
net_addr = NetAddress(host, port) # this validates host and port
host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
return host, port, protocol
def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
assert isinstance(host, str), repr(host)
if protocol is None:
protocol = 's'
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
try:
net_addr = NetAddress(host, port) # this validates host and port
except Exception as e:
raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
if protocol not in ('s', 't'):
raise ValueError(f"invalid network protocol: {protocol}")
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
self.port = int(net_addr.port)
self.protocol = protocol
self._net_addr_str = str(net_addr)
@classmethod
def from_str(cls, s: str) -> 'ServerAddr':
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(s).rsplit(':', 2)
return ServerAddr(host=host, port=port, protocol=protocol)
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str: def __str__(self):
return str(':'.join([host, str(port), protocol])) return '{}:{}'.format(self.net_addr_str(), self.protocol)
def to_json(self) -> str:
return str(self)
def __repr__(self):
return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
def net_addr_str(self) -> str:
return self._net_addr_str
def __eq__(self, other):
if not isinstance(other, ServerAddr):
return False
return (self.host == other.host
and self.port == other.port
and self.protocol == other.protocol)
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.host, self.port, self.protocol))
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str: def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
@ -232,12 +267,10 @@ class Interface(Logger):
LOGGING_SHORTCUT = 'i' LOGGING_SHORTCUT = 'i'
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]): def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
self.ready = asyncio.Future() self.ready = asyncio.Future()
self.got_disconnected = asyncio.Future() self.got_disconnected = asyncio.Future()
self.server = server self.server = server
self.host, self.port, self.protocol = deserialize_server(self.server)
self.port = int(self.port)
Logger.__init__(self) Logger.__init__(self)
assert network.config.path assert network.config.path
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host) self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
@ -259,8 +292,20 @@ class Interface(Logger):
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop) self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
self.taskgroup = SilentTaskGroup() self.taskgroup = SilentTaskGroup()
@property
def host(self):
return self.server.host
@property
def port(self):
return self.server.port
@property
def protocol(self):
return self.server.protocol
def diagnostic_name(self): def diagnostic_name(self):
return str(NetAddress(self.host, self.port)) return self.server.net_addr_str()
def __str__(self): def __str__(self):
return f"<Interface {self.diagnostic_name()}>" return f"<Interface {self.diagnostic_name()}>"

124
electrum/network.py

@ -32,7 +32,7 @@ import socket
import json import json
import sys import sys
import asyncio import asyncio
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set
import traceback import traceback
import concurrent import concurrent
from concurrent import futures from concurrent import futures
@ -44,7 +44,7 @@ from aiohttp import ClientResponse
from . import util from . import util
from .util import (log_exceptions, ignore_exceptions, from .util import (log_exceptions, ignore_exceptions,
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter, bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
is_hash256_str, is_non_negative_integer) is_hash256_str, is_non_negative_integer, MyEncoder)
from .bitcoin import COIN from .bitcoin import COIN
from . import constants from . import constants
@ -53,9 +53,9 @@ from . import bitcoin
from . import dns_hacks from . import dns_hacks
from .transaction import Transaction from .transaction import Transaction
from .blockchain import Blockchain, HEADER_SIZE from .blockchain import Blockchain, HEADER_SIZE
from .interface import (Interface, serialize_server, deserialize_server, from .interface import (Interface,
RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS, RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS,
NetworkException, RequestCorrupted) NetworkException, RequestCorrupted, ServerAddr)
from .version import PROTOCOL_VERSION from .version import PROTOCOL_VERSION
from .simple_config import SimpleConfig from .simple_config import SimpleConfig
from .i18n import _ from .i18n import _
@ -117,18 +117,18 @@ def filter_noonion(servers):
return {k: v for k, v in servers.items() if not k.endswith('.onion')} return {k: v for k, v in servers.items() if not k.endswith('.onion')}
def filter_protocol(hostmap, protocol='s'): def filter_protocol(hostmap, protocol='s') -> Sequence[ServerAddr]:
'''Filters the hostmap for those implementing protocol. """Filters the hostmap for those implementing protocol."""
The result is a list in serialized form.'''
eligible = [] eligible = []
for host, portmap in hostmap.items(): for host, portmap in hostmap.items():
port = portmap.get(protocol) port = portmap.get(protocol)
if port: if port:
eligible.append(serialize_server(host, port, protocol)) eligible.append(ServerAddr(host, port, protocol=protocol))
return eligible return eligible
def pick_random_server(hostmap=None, protocol='s', exclude_set=None): def pick_random_server(hostmap=None, *, protocol='s',
exclude_set: Set[ServerAddr] = None) -> Optional[ServerAddr]:
if hostmap is None: if hostmap is None:
hostmap = constants.net.DEFAULT_SERVERS hostmap = constants.net.DEFAULT_SERVERS
if exclude_set is None: if exclude_set is None:
@ -240,6 +240,14 @@ class Network(Logger):
LOGGING_SHORTCUT = 'n' LOGGING_SHORTCUT = 'n'
taskgroup: Optional[TaskGroup]
interface: Optional[Interface]
interfaces: Dict[ServerAddr, Interface]
connecting: Set[ServerAddr]
server_queue: 'Optional[queue.Queue[ServerAddr]]'
disconnected_servers: Set[ServerAddr]
default_server: ServerAddr
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None): def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
global _INSTANCE global _INSTANCE
assert _INSTANCE is None, "Network is a singleton!" assert _INSTANCE is None, "Network is a singleton!"
@ -266,14 +274,15 @@ class Network(Logger):
# Sanitize default server # Sanitize default server
if self.default_server: if self.default_server:
try: try:
deserialize_server(self.default_server) self.default_server = ServerAddr.from_str(self.default_server)
except: except:
self.logger.warning('failed to parse server-string; falling back to localhost.') self.logger.warning('failed to parse server-string; falling back to localhost.')
self.default_server = "localhost:50002:s" self.default_server = ServerAddr.from_str("localhost:50002:s")
if not self.default_server: else:
self.default_server = pick_random_server() self.default_server = pick_random_server()
assert isinstance(self.default_server, ServerAddr), f"invalid type for default_server: {self.default_server!r}"
self.taskgroup = None # type: TaskGroup self.taskgroup = None
# locks # locks
self.restart_lock = asyncio.Lock() self.restart_lock = asyncio.Lock()
@ -295,10 +304,10 @@ class Network(Logger):
self.server_retry_time = time.time() self.server_retry_time = time.time()
self.nodes_retry_time = time.time() self.nodes_retry_time = time.time()
# the main server we are currently communicating with # the main server we are currently communicating with
self.interface = None # type: Optional[Interface] self.interface = None
self.default_server_changed_event = asyncio.Event() self.default_server_changed_event = asyncio.Event()
# set of servers we have an ongoing connection with # set of servers we have an ongoing connection with
self.interfaces = {} # type: Dict[str, Interface] self.interfaces = {}
self.auto_connect = self.config.get('auto_connect', True) self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set() self.connecting = set()
self.server_queue = None self.server_queue = None
@ -347,14 +356,15 @@ class Network(Logger):
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
return func_wrapper return func_wrapper
def _read_recent_servers(self): def _read_recent_servers(self) -> List[ServerAddr]:
if not self.config.path: if not self.config.path:
return [] return []
path = os.path.join(self.config.path, "recent_servers") path = os.path.join(self.config.path, "recent_servers")
try: try:
with open(path, "r", encoding='utf-8') as f: with open(path, "r", encoding='utf-8') as f:
data = f.read() data = f.read()
return json.loads(data) servers_list = json.loads(data)
return [ServerAddr.from_str(s) for s in servers_list]
except: except:
return [] return []
@ -363,7 +373,7 @@ class Network(Logger):
if not self.config.path: if not self.config.path:
return return
path = os.path.join(self.config.path, "recent_servers") path = os.path.join(self.config.path, "recent_servers")
s = json.dumps(self.recent_servers, indent=4, sort_keys=True) s = json.dumps(self.recent_servers, indent=4, sort_keys=True, cls=MyEncoder)
try: try:
with open(path, "w", encoding='utf-8') as f: with open(path, "w", encoding='utf-8') as f:
f.write(s) f.write(s)
@ -462,10 +472,10 @@ class Network(Logger):
util.trigger_callback(key, self.get_status_value(key)) util.trigger_callback(key, self.get_status_value(key))
def get_parameters(self) -> NetworkParameters: def get_parameters(self) -> NetworkParameters:
host, port, protocol = deserialize_server(self.default_server) server = self.default_server
return NetworkParameters(host=host, return NetworkParameters(host=server.host,
port=port, port=str(server.port),
protocol=protocol, protocol=server.protocol,
proxy=self.proxy, proxy=self.proxy,
auto_connect=self.auto_connect, auto_connect=self.auto_connect,
oneserver=self.oneserver) oneserver=self.oneserver)
@ -474,7 +484,7 @@ class Network(Logger):
if self.is_connected(): if self.is_connected():
return self.donation_address return self.donation_address
def get_interfaces(self) -> List[str]: def get_interfaces(self) -> List[ServerAddr]:
"""The list of servers for the connected interfaces.""" """The list of servers for the connected interfaces."""
with self.interfaces_lock: with self.interfaces_lock:
return list(self.interfaces) return list(self.interfaces)
@ -516,21 +526,18 @@ class Network(Logger):
# hardcoded servers # hardcoded servers
out.update(constants.net.DEFAULT_SERVERS) out.update(constants.net.DEFAULT_SERVERS)
# add recent servers # add recent servers
for s in self.recent_servers: for server in self.recent_servers:
try: port = str(server.port)
host, port, protocol = deserialize_server(s) if server.host in out:
except: out[server.host].update({server.protocol: port})
continue
if host in out:
out[host].update({protocol: port})
else: else:
out[host] = {protocol: port} out[server.host] = {server.protocol: port}
# potentially filter out some # potentially filter out some
if self.config.get('noonion'): if self.config.get('noonion'):
out = filter_noonion(out) out = filter_noonion(out)
return out return out
def _start_interface(self, server: str): def _start_interface(self, server: ServerAddr):
if server not in self.interfaces and server not in self.connecting: if server not in self.interfaces and server not in self.connecting:
if server == self.default_server: if server == self.default_server:
self.logger.info(f"connecting to {server} as new interface") self.logger.info(f"connecting to {server} as new interface")
@ -538,10 +545,10 @@ class Network(Logger):
self.connecting.add(server) self.connecting.add(server)
self.server_queue.put(server) self.server_queue.put(server)
def _start_random_interface(self): def _start_random_interface(self) -> Optional[ServerAddr]:
with self.interfaces_lock: with self.interfaces_lock:
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting
server = pick_random_server(self.get_servers(), self.protocol, exclude_set) server = pick_random_server(self.get_servers(), protocol=self.protocol, exclude_set=exclude_set)
if server: if server:
self._start_interface(server) self._start_interface(server)
return server return server
@ -557,10 +564,9 @@ class Network(Logger):
proxy = net_params.proxy proxy = net_params.proxy
proxy_str = serialize_proxy(proxy) proxy_str = serialize_proxy(proxy)
host, port, protocol = net_params.host, net_params.port, net_params.protocol host, port, protocol = net_params.host, net_params.port, net_params.protocol
server_str = serialize_server(host, port, protocol)
# sanitize parameters # sanitize parameters
try: try:
deserialize_server(serialize_server(host, port, protocol)) server = ServerAddr(host, port, protocol=protocol)
if proxy: if proxy:
proxy_modes.index(proxy['mode']) + 1 proxy_modes.index(proxy['mode']) + 1
int(proxy['port']) int(proxy['port'])
@ -569,9 +575,9 @@ class Network(Logger):
self.config.set_key('auto_connect', net_params.auto_connect, False) self.config.set_key('auto_connect', net_params.auto_connect, False)
self.config.set_key('oneserver', net_params.oneserver, False) self.config.set_key('oneserver', net_params.oneserver, False)
self.config.set_key('proxy', proxy_str, False) self.config.set_key('proxy', proxy_str, False)
self.config.set_key('server', server_str, True) self.config.set_key('server', str(server), True)
# abort if changes were not allowed by config # abort if changes were not allowed by config
if self.config.get('server') != server_str \ if self.config.get('server') != str(server) \
or self.config.get('proxy') != proxy_str \ or self.config.get('proxy') != proxy_str \
or self.config.get('oneserver') != net_params.oneserver: or self.config.get('oneserver') != net_params.oneserver:
return return
@ -581,10 +587,10 @@ class Network(Logger):
if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver: if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver:
# Restart the network defaulting to the given server # Restart the network defaulting to the given server
await self._stop() await self._stop()
self.default_server = server_str self.default_server = server
await self._start() await self._start()
elif self.default_server != server_str: elif self.default_server != server:
await self.switch_to_interface(server_str) await self.switch_to_interface(server)
else: else:
await self.switch_lagging_interface() await self.switch_lagging_interface()
@ -646,7 +652,7 @@ class Network(Logger):
# FIXME switch to best available? # FIXME switch to best available?
self.logger.info("tried to switch to best chain but no interfaces are on it") self.logger.info("tried to switch to best chain but no interfaces are on it")
async def switch_to_interface(self, server: str): async def switch_to_interface(self, server: ServerAddr):
"""Switch to server as our main interface. If no connection exists, """Switch to server as our main interface. If no connection exists,
queue interface to be started. The actual switch will queue interface to be started. The actual switch will
happen when the interface becomes ready. happen when the interface becomes ready.
@ -722,8 +728,8 @@ class Network(Logger):
@ignore_exceptions # do not kill main_taskgroup @ignore_exceptions # do not kill main_taskgroup
@log_exceptions @log_exceptions
async def _run_new_interface(self, server): async def _run_new_interface(self, server: ServerAddr):
interface = Interface(self, server, self.proxy) interface = Interface(network=self, server=server, proxy=self.proxy)
# note: using longer timeouts here as DNS can sometimes be slow! # note: using longer timeouts here as DNS can sometimes be slow!
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic) timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
try: try:
@ -1070,23 +1076,26 @@ class Network(Logger):
with self.interfaces_lock: interfaces = list(self.interfaces.values()) with self.interfaces_lock: interfaces = list(self.interfaces.values())
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces)) interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
if len(interfaces_on_selected_chain) == 0: return if len(interfaces_on_selected_chain) == 0: return
chosen_iface = random.choice(interfaces_on_selected_chain) chosen_iface = random.choice(interfaces_on_selected_chain) # type: Interface
# switch to server (and save to config) # switch to server (and save to config)
net_params = self.get_parameters() net_params = self.get_parameters()
host, port, protocol = deserialize_server(chosen_iface.server) server = chosen_iface.server
net_params = net_params._replace(host=host, port=port, protocol=protocol) net_params = net_params._replace(host=server.host,
port=str(server.port),
protocol=server.protocol)
await self.set_parameters(net_params) await self.set_parameters(net_params)
async def follow_chain_given_server(self, server_str: str) -> None: async def follow_chain_given_server(self, server: ServerAddr) -> None:
# note that server_str should correspond to a connected interface # note that server_str should correspond to a connected interface
iface = self.interfaces.get(server_str) iface = self.interfaces.get(server)
if iface is None: if iface is None:
return return
self._set_preferred_chain(iface.blockchain) self._set_preferred_chain(iface.blockchain)
# switch to server (and save to config) # switch to server (and save to config)
net_params = self.get_parameters() net_params = self.get_parameters()
host, port, protocol = deserialize_server(server_str) net_params = net_params._replace(host=server.host,
net_params = net_params._replace(host=host, port=port, protocol=protocol) port=str(server.port),
protocol=server.protocol)
await self.set_parameters(net_params) await self.set_parameters(net_params)
def get_local_height(self): def get_local_height(self):
@ -1107,7 +1116,7 @@ class Network(Logger):
assert not self.connecting and not self.server_queue assert not self.connecting and not self.server_queue
self.logger.info('starting network') self.logger.info('starting network')
self.disconnected_servers = set([]) self.disconnected_servers = set([])
self.protocol = deserialize_server(self.default_server)[2] self.protocol = self.default_server.protocol
self.server_queue = queue.Queue() self.server_queue = queue.Queue()
self._set_proxy(deserialize_proxy(self.config.get('proxy'))) self._set_proxy(deserialize_proxy(self.config.get('proxy')))
self._set_oneserver(self.config.get('oneserver', False)) self._set_oneserver(self.config.get('oneserver', False))
@ -1147,9 +1156,9 @@ class Network(Logger):
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2) await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
except (asyncio.TimeoutError, asyncio.CancelledError) as e: except (asyncio.TimeoutError, asyncio.CancelledError) as e:
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}") self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
self.taskgroup = None # type: TaskGroup self.taskgroup = None
self.interface = None # type: Interface self.interface = None
self.interfaces = {} # type: Dict[str, Interface] self.interfaces = {}
self.connecting.clear() self.connecting.clear()
self.server_queue = None self.server_queue = None
if not full_shutdown: if not full_shutdown:
@ -1268,8 +1277,8 @@ class Network(Logger):
async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence): async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence):
responses = dict() responses = dict()
async def get_response(server): async def get_response(server: ServerAddr):
interface = Interface(self, server, self.proxy) interface = Interface(network=self, server=server, proxy=self.proxy)
timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent) timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent)
try: try:
await asyncio.wait_for(interface.ready, timeout) await asyncio.wait_for(interface.ready, timeout)
@ -1283,5 +1292,6 @@ class Network(Logger):
responses[interface.server] = res responses[interface.server] = res
async with TaskGroup() as group: async with TaskGroup() as group:
for server in servers: for server in servers:
server = ServerAddr.from_str(server)
await group.spawn(get_response(server)) await group.spawn(get_response(server))
return responses return responses

Loading…
Cancel
Save