Browse Source

network: replace "server" strings with ServerAddr objects

patch-3
SomberNight 5 years ago
parent
commit
cf1f2ba4dc
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/daemon.py
  2. 2
      electrum/exchange_rate.py
  3. 40
      electrum/gui/qt/network_dialog.py
  4. 24
      electrum/gui/text.py
  5. 83
      electrum/interface.py
  6. 124
      electrum/network.py

2
electrum/daemon.py

@ -270,6 +270,8 @@ class AuthenticationCredentialsInvalid(AuthenticationError):
class Daemon(Logger):
network: Optional[Network]
@profiler
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
Logger.__init__(self)

2
electrum/exchange_rate.py

@ -453,7 +453,7 @@ def get_exchanges_by_ccy(history=True):
class FxThread(ThreadJob):
def __init__(self, config: SimpleConfig, network: Network):
def __init__(self, config: SimpleConfig, network: Optional[Network]):
ThreadJob.__init__(self)
self.config = config
self.network = network

40
electrum/gui/qt/network_dialog.py

@ -36,7 +36,7 @@ from PyQt5.QtGui import QFontMetrics
from electrum.i18n import _
from electrum import constants, blockchain, util
from electrum.interface import serialize_server, deserialize_server
from electrum.interface import ServerAddr
from electrum.network import Network
from electrum.logging import get_logger
@ -72,10 +72,13 @@ class NetworkDialog(QDialog):
class NodesListWidget(QTreeWidget):
SERVER_ADDR_ROLE = Qt.UserRole + 100
CHAIN_ID_ROLE = Qt.UserRole + 101
IS_SERVER_ROLE = Qt.UserRole + 102
def __init__(self, parent):
QTreeWidget.__init__(self)
self.parent = parent
self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Connected node'), _('Height')])
self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu)
@ -84,13 +87,13 @@ class NodesListWidget(QTreeWidget):
item = self.currentItem()
if not item:
return
is_server = not bool(item.data(0, Qt.UserRole))
is_server = bool(item.data(0, self.IS_SERVER_ROLE))
menu = QMenu()
if is_server:
server = item.data(1, Qt.UserRole)
server = item.data(0, self.SERVER_ADDR_ROLE) # type: ServerAddr
menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server))
else:
chain_id = item.data(1, Qt.UserRole)
chain_id = item.data(0, self.CHAIN_ID_ROLE)
menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id))
menu.exec_(self.viewport().mapToGlobal(position))
@ -117,15 +120,15 @@ class NodesListWidget(QTreeWidget):
name = b.get_name()
if n_chains > 1:
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
x.setData(0, Qt.UserRole, 1)
x.setData(1, Qt.UserRole, b.get_id())
x.setData(0, self.IS_SERVER_ROLE, 0)
x.setData(0, self.CHAIN_ID_ROLE, b.get_id())
else:
x = self
for i in interfaces:
star = ' *' if i == network.interface else ''
item = QTreeWidgetItem([i.host + star, '%d'%i.tip])
item.setData(0, Qt.UserRole, 0)
item.setData(1, Qt.UserRole, i.server)
item.setData(0, self.IS_SERVER_ROLE, 1)
item.setData(0, self.SERVER_ADDR_ROLE, i.server)
x.addChild(item)
if n_chains > 1:
self.addTopLevelItem(x)
@ -144,11 +147,11 @@ class ServerListWidget(QTreeWidget):
HOST = 0
PORT = 1
SERVER_STR_ROLE = Qt.UserRole + 100
SERVER_ADDR_ROLE = Qt.UserRole + 100
def __init__(self, parent):
QTreeWidget.__init__(self)
self.parent = parent
self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Host'), _('Port')])
self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu)
@ -158,14 +161,13 @@ class ServerListWidget(QTreeWidget):
if not item:
return
menu = QMenu()
server = item.data(self.Columns.HOST, self.SERVER_STR_ROLE)
server = item.data(self.Columns.HOST, self.SERVER_ADDR_ROLE)
menu.addAction(_("Use as server"), lambda: self.set_server(server))
menu.exec_(self.viewport().mapToGlobal(position))
def set_server(self, s):
host, port, protocol = deserialize_server(s)
self.parent.server_host.setText(host)
self.parent.server_port.setText(port)
def set_server(self, server: ServerAddr):
self.parent.server_host.setText(server.host)
self.parent.server_port.setText(str(server.port))
self.parent.set_server()
def keyPressEvent(self, event):
@ -188,8 +190,8 @@ class ServerListWidget(QTreeWidget):
port = d.get(protocol)
if port:
x = QTreeWidgetItem([_host, port])
server = serialize_server(_host, port, protocol)
x.setData(self.Columns.HOST, self.SERVER_STR_ROLE, server)
server = ServerAddr(_host, port, protocol=protocol)
x.setData(self.Columns.HOST, self.SERVER_ADDR_ROLE, server)
self.addTopLevelItem(x)
h = self.header()
@ -431,7 +433,7 @@ class NetworkChoiceLayout(object):
self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id))
self.update()
def follow_server(self, server):
def follow_server(self, server: ServerAddr):
self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
self.update()

24
electrum/gui/text.py

@ -6,6 +6,7 @@ import locale
from decimal import Decimal
import getpass
import logging
from typing import TYPE_CHECKING
import electrum
from electrum import util
@ -15,15 +16,21 @@ from electrum.transaction import PartialTxOutput
from electrum.wallet import Wallet
from electrum.storage import WalletStorage
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
from electrum.interface import deserialize_server
from electrum.interface import ServerAddr
from electrum.logging import console_stderr_handler
if TYPE_CHECKING:
from electrum.daemon import Daemon
from electrum.simple_config import SimpleConfig
from electrum.plugin import Plugins
_ = lambda x:x # i18n
class ElectrumGui:
def __init__(self, config, daemon, plugins):
def __init__(self, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'):
self.config = config
self.network = daemon.network
@ -404,21 +411,24 @@ class ElectrumGui:
net_params = self.network.get_parameters()
host, port, protocol = net_params.host, net_params.port, net_params.protocol
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
srv = 'auto-connect' if auto_connect else self.network.default_server
srv = 'auto-connect' if auto_connect else str(self.network.default_server)
out = self.run_dialog('Network', [
{'label':'server', 'type':'str', 'value':srv},
{'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')},
], buttons = 1)
if out:
if out.get('server'):
server = out.get('server')
auto_connect = server == 'auto-connect'
server_str = out.get('server')
auto_connect = server_str == 'auto-connect'
if not auto_connect:
try:
host, port, protocol = deserialize_server(server)
server_addr = ServerAddr.from_str(server_str)
except Exception:
self.show_message("Error:" + server + "\nIn doubt, type \"auto-connect\"")
self.show_message("Error:" + server_str + "\nIn doubt, type \"auto-connect\"")
return False
host = server_addr.host
port = str(server_addr.port)
protocol = server_addr.protocol
if out.get('server') or out.get('proxy'):
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect)

83
electrum/interface.py

@ -29,7 +29,7 @@ import sys
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
import itertools
@ -198,22 +198,57 @@ class _RSClient(RSClient):
raise ConnectError(e) from e
def deserialize_server(server_str: str) -> Tuple[str, str, str]:
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(server_str).rsplit(':', 2)
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
if protocol not in ('s', 't'):
raise ValueError('invalid network protocol: {}'.format(protocol))
net_addr = NetAddress(host, port) # this validates host and port
host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
return host, port, protocol
class ServerAddr:
def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
assert isinstance(host, str), repr(host)
if protocol is None:
protocol = 's'
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
try:
net_addr = NetAddress(host, port) # this validates host and port
except Exception as e:
raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
if protocol not in ('s', 't'):
raise ValueError(f"invalid network protocol: {protocol}")
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
self.port = int(net_addr.port)
self.protocol = protocol
self._net_addr_str = str(net_addr)
@classmethod
def from_str(cls, s: str) -> 'ServerAddr':
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(s).rsplit(':', 2)
return ServerAddr(host=host, port=port, protocol=protocol)
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str:
return str(':'.join([host, str(port), protocol]))
def __str__(self):
return '{}:{}'.format(self.net_addr_str(), self.protocol)
def to_json(self) -> str:
return str(self)
def __repr__(self):
return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
def net_addr_str(self) -> str:
return self._net_addr_str
def __eq__(self, other):
if not isinstance(other, ServerAddr):
return False
return (self.host == other.host
and self.port == other.port
and self.protocol == other.protocol)
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.host, self.port, self.protocol))
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
@ -232,12 +267,10 @@ class Interface(Logger):
LOGGING_SHORTCUT = 'i'
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]):
def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
self.ready = asyncio.Future()
self.got_disconnected = asyncio.Future()
self.server = server
self.host, self.port, self.protocol = deserialize_server(self.server)
self.port = int(self.port)
Logger.__init__(self)
assert network.config.path
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
@ -259,8 +292,20 @@ class Interface(Logger):
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
self.taskgroup = SilentTaskGroup()
@property
def host(self):
return self.server.host
@property
def port(self):
return self.server.port
@property
def protocol(self):
return self.server.protocol
def diagnostic_name(self):
return str(NetAddress(self.host, self.port))
return self.server.net_addr_str()
def __str__(self):
return f"<Interface {self.diagnostic_name()}>"

124
electrum/network.py

@ -32,7 +32,7 @@ import socket
import json
import sys
import asyncio
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set
import traceback
import concurrent
from concurrent import futures
@ -44,7 +44,7 @@ from aiohttp import ClientResponse
from . import util
from .util import (log_exceptions, ignore_exceptions,
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
is_hash256_str, is_non_negative_integer)
is_hash256_str, is_non_negative_integer, MyEncoder)
from .bitcoin import COIN
from . import constants
@ -53,9 +53,9 @@ from . import bitcoin
from . import dns_hacks
from .transaction import Transaction
from .blockchain import Blockchain, HEADER_SIZE
from .interface import (Interface, serialize_server, deserialize_server,
from .interface import (Interface,
RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS,
NetworkException, RequestCorrupted)
NetworkException, RequestCorrupted, ServerAddr)
from .version import PROTOCOL_VERSION
from .simple_config import SimpleConfig
from .i18n import _
@ -117,18 +117,18 @@ def filter_noonion(servers):
return {k: v for k, v in servers.items() if not k.endswith('.onion')}
def filter_protocol(hostmap, protocol='s'):
'''Filters the hostmap for those implementing protocol.
The result is a list in serialized form.'''
def filter_protocol(hostmap, protocol='s') -> Sequence[ServerAddr]:
"""Filters the hostmap for those implementing protocol."""
eligible = []
for host, portmap in hostmap.items():
port = portmap.get(protocol)
if port:
eligible.append(serialize_server(host, port, protocol))
eligible.append(ServerAddr(host, port, protocol=protocol))
return eligible
def pick_random_server(hostmap=None, protocol='s', exclude_set=None):
def pick_random_server(hostmap=None, *, protocol='s',
exclude_set: Set[ServerAddr] = None) -> Optional[ServerAddr]:
if hostmap is None:
hostmap = constants.net.DEFAULT_SERVERS
if exclude_set is None:
@ -240,6 +240,14 @@ class Network(Logger):
LOGGING_SHORTCUT = 'n'
taskgroup: Optional[TaskGroup]
interface: Optional[Interface]
interfaces: Dict[ServerAddr, Interface]
connecting: Set[ServerAddr]
server_queue: 'Optional[queue.Queue[ServerAddr]]'
disconnected_servers: Set[ServerAddr]
default_server: ServerAddr
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
global _INSTANCE
assert _INSTANCE is None, "Network is a singleton!"
@ -266,14 +274,15 @@ class Network(Logger):
# Sanitize default server
if self.default_server:
try:
deserialize_server(self.default_server)
self.default_server = ServerAddr.from_str(self.default_server)
except:
self.logger.warning('failed to parse server-string; falling back to localhost.')
self.default_server = "localhost:50002:s"
if not self.default_server:
self.default_server = ServerAddr.from_str("localhost:50002:s")
else:
self.default_server = pick_random_server()
assert isinstance(self.default_server, ServerAddr), f"invalid type for default_server: {self.default_server!r}"
self.taskgroup = None # type: TaskGroup
self.taskgroup = None
# locks
self.restart_lock = asyncio.Lock()
@ -295,10 +304,10 @@ class Network(Logger):
self.server_retry_time = time.time()
self.nodes_retry_time = time.time()
# the main server we are currently communicating with
self.interface = None # type: Optional[Interface]
self.interface = None
self.default_server_changed_event = asyncio.Event()
# set of servers we have an ongoing connection with
self.interfaces = {} # type: Dict[str, Interface]
self.interfaces = {}
self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set()
self.server_queue = None
@ -347,14 +356,15 @@ class Network(Logger):
return func(self, *args, **kwargs)
return func_wrapper
def _read_recent_servers(self):
def _read_recent_servers(self) -> List[ServerAddr]:
if not self.config.path:
return []
path = os.path.join(self.config.path, "recent_servers")
try:
with open(path, "r", encoding='utf-8') as f:
data = f.read()
return json.loads(data)
servers_list = json.loads(data)
return [ServerAddr.from_str(s) for s in servers_list]
except:
return []
@ -363,7 +373,7 @@ class Network(Logger):
if not self.config.path:
return
path = os.path.join(self.config.path, "recent_servers")
s = json.dumps(self.recent_servers, indent=4, sort_keys=True)
s = json.dumps(self.recent_servers, indent=4, sort_keys=True, cls=MyEncoder)
try:
with open(path, "w", encoding='utf-8') as f:
f.write(s)
@ -462,10 +472,10 @@ class Network(Logger):
util.trigger_callback(key, self.get_status_value(key))
def get_parameters(self) -> NetworkParameters:
host, port, protocol = deserialize_server(self.default_server)
return NetworkParameters(host=host,
port=port,
protocol=protocol,
server = self.default_server
return NetworkParameters(host=server.host,
port=str(server.port),
protocol=server.protocol,
proxy=self.proxy,
auto_connect=self.auto_connect,
oneserver=self.oneserver)
@ -474,7 +484,7 @@ class Network(Logger):
if self.is_connected():
return self.donation_address
def get_interfaces(self) -> List[str]:
def get_interfaces(self) -> List[ServerAddr]:
"""The list of servers for the connected interfaces."""
with self.interfaces_lock:
return list(self.interfaces)
@ -516,21 +526,18 @@ class Network(Logger):
# hardcoded servers
out.update(constants.net.DEFAULT_SERVERS)
# add recent servers
for s in self.recent_servers:
try:
host, port, protocol = deserialize_server(s)
except:
continue
if host in out:
out[host].update({protocol: port})
for server in self.recent_servers:
port = str(server.port)
if server.host in out:
out[server.host].update({server.protocol: port})
else:
out[host] = {protocol: port}
out[server.host] = {server.protocol: port}
# potentially filter out some
if self.config.get('noonion'):
out = filter_noonion(out)
return out
def _start_interface(self, server: str):
def _start_interface(self, server: ServerAddr):
if server not in self.interfaces and server not in self.connecting:
if server == self.default_server:
self.logger.info(f"connecting to {server} as new interface")
@ -538,10 +545,10 @@ class Network(Logger):
self.connecting.add(server)
self.server_queue.put(server)
def _start_random_interface(self):
def _start_random_interface(self) -> Optional[ServerAddr]:
with self.interfaces_lock:
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
server = pick_random_server(self.get_servers(), protocol=self.protocol, exclude_set=exclude_set)
if server:
self._start_interface(server)
return server
@ -557,10 +564,9 @@ class Network(Logger):
proxy = net_params.proxy
proxy_str = serialize_proxy(proxy)
host, port, protocol = net_params.host, net_params.port, net_params.protocol
server_str = serialize_server(host, port, protocol)
# sanitize parameters
try:
deserialize_server(serialize_server(host, port, protocol))
server = ServerAddr(host, port, protocol=protocol)
if proxy:
proxy_modes.index(proxy['mode']) + 1
int(proxy['port'])
@ -569,9 +575,9 @@ class Network(Logger):
self.config.set_key('auto_connect', net_params.auto_connect, False)
self.config.set_key('oneserver', net_params.oneserver, False)
self.config.set_key('proxy', proxy_str, False)
self.config.set_key('server', server_str, True)
self.config.set_key('server', str(server), True)
# abort if changes were not allowed by config
if self.config.get('server') != server_str \
if self.config.get('server') != str(server) \
or self.config.get('proxy') != proxy_str \
or self.config.get('oneserver') != net_params.oneserver:
return
@ -581,10 +587,10 @@ class Network(Logger):
if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver:
# Restart the network defaulting to the given server
await self._stop()
self.default_server = server_str
self.default_server = server
await self._start()
elif self.default_server != server_str:
await self.switch_to_interface(server_str)
elif self.default_server != server:
await self.switch_to_interface(server)
else:
await self.switch_lagging_interface()
@ -646,7 +652,7 @@ class Network(Logger):
# FIXME switch to best available?
self.logger.info("tried to switch to best chain but no interfaces are on it")
async def switch_to_interface(self, server: str):
async def switch_to_interface(self, server: ServerAddr):
"""Switch to server as our main interface. If no connection exists,
queue interface to be started. The actual switch will
happen when the interface becomes ready.
@ -722,8 +728,8 @@ class Network(Logger):
@ignore_exceptions # do not kill main_taskgroup
@log_exceptions
async def _run_new_interface(self, server):
interface = Interface(self, server, self.proxy)
async def _run_new_interface(self, server: ServerAddr):
interface = Interface(network=self, server=server, proxy=self.proxy)
# note: using longer timeouts here as DNS can sometimes be slow!
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
try:
@ -1070,23 +1076,26 @@ class Network(Logger):
with self.interfaces_lock: interfaces = list(self.interfaces.values())
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
if len(interfaces_on_selected_chain) == 0: return
chosen_iface = random.choice(interfaces_on_selected_chain)
chosen_iface = random.choice(interfaces_on_selected_chain) # type: Interface
# switch to server (and save to config)
net_params = self.get_parameters()
host, port, protocol = deserialize_server(chosen_iface.server)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
server = chosen_iface.server
net_params = net_params._replace(host=server.host,
port=str(server.port),
protocol=server.protocol)
await self.set_parameters(net_params)
async def follow_chain_given_server(self, server_str: str) -> None:
async def follow_chain_given_server(self, server: ServerAddr) -> None:
# note that server_str should correspond to a connected interface
iface = self.interfaces.get(server_str)
iface = self.interfaces.get(server)
if iface is None:
return
self._set_preferred_chain(iface.blockchain)
# switch to server (and save to config)
net_params = self.get_parameters()
host, port, protocol = deserialize_server(server_str)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
net_params = net_params._replace(host=server.host,
port=str(server.port),
protocol=server.protocol)
await self.set_parameters(net_params)
def get_local_height(self):
@ -1107,7 +1116,7 @@ class Network(Logger):
assert not self.connecting and not self.server_queue
self.logger.info('starting network')
self.disconnected_servers = set([])
self.protocol = deserialize_server(self.default_server)[2]
self.protocol = self.default_server.protocol
self.server_queue = queue.Queue()
self._set_proxy(deserialize_proxy(self.config.get('proxy')))
self._set_oneserver(self.config.get('oneserver', False))
@ -1147,9 +1156,9 @@ class Network(Logger):
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
self.taskgroup = None # type: TaskGroup
self.interface = None # type: Interface
self.interfaces = {} # type: Dict[str, Interface]
self.taskgroup = None
self.interface = None
self.interfaces = {}
self.connecting.clear()
self.server_queue = None
if not full_shutdown:
@ -1268,8 +1277,8 @@ class Network(Logger):
async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence):
responses = dict()
async def get_response(server):
interface = Interface(self, server, self.proxy)
async def get_response(server: ServerAddr):
interface = Interface(network=self, server=server, proxy=self.proxy)
timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent)
try:
await asyncio.wait_for(interface.ready, timeout)
@ -1283,5 +1292,6 @@ class Network(Logger):
responses[interface.server] = res
async with TaskGroup() as group:
for server in servers:
server = ServerAddr.from_str(server)
await group.spawn(get_response(server))
return responses

Loading…
Cancel
Save