From 2eec7e16004449d70a081e80f782fdc6c90f6b02 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Sun, 21 Jun 2020 11:31:54 +0200 Subject: [PATCH] network: smarter switch_unwanted_fork_interface Previously this function would not switch to a different chain if the current chain contained the preferred block. This was not the intended behaviour: if there is a *stronger* chain that *also* contains the preferred block, we should jump to that. Note that with this commit there will now always be a preferred block (defaults to genesis). Previously, it might seem that often there was none, but actually in practice if the user used the GUI context menu to switch servers even once, there was one (usually genesis). Hence, with the old code, if an attacker mined a single header which then got reorged, auto_connect clients which were connected to the attacker's server would never switch servers (jump chains) even without the user explicitly configuring preference for the stale branch. --- electrum/blockchain.py | 13 +++++++ electrum/network.py | 64 +++++++++++++++---------------- electrum/tests/test_blockchain.py | 48 +++++++++++++++++++++++ 3 files changed, 92 insertions(+), 33 deletions(-) diff --git a/electrum/blockchain.py b/electrum/blockchain.py index 15884c511..6dfe3573e 100644 --- a/electrum/blockchain.py +++ b/electrum/blockchain.py @@ -646,6 +646,7 @@ class Blockchain(Logger): def check_header(header: dict) -> Optional[Blockchain]: + """Returns any Blockchain that contains header, or None.""" if type(header) is not dict: return None with blockchains_lock: chains = list(blockchains.values()) @@ -656,8 +657,20 @@ def check_header(header: dict) -> Optional[Blockchain]: def can_connect(header: dict) -> Optional[Blockchain]: + """Returns the Blockchain that has a tip that directly links up + with header, or None. + """ with blockchains_lock: chains = list(blockchains.values()) for b in chains: if b.can_connect(header): return b return None + + +def get_chains_that_contain_header(height: int, header_hash: str) -> Sequence[Blockchain]: + """Returns a list of Blockchains that contain header, best chain first.""" + with blockchains_lock: chains = list(blockchains.values()) + chains = [chain for chain in chains + if chain.check_hash(height=height, header_hash=header_hash)] + chains = sorted(chains, key=lambda x: x.get_chainwork(), reverse=True) + return chains diff --git a/electrum/network.py b/electrum/network.py index aa0d34bca..c38688e72 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -32,7 +32,7 @@ import socket import json import sys import asyncio -from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set +from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set, Any import traceback import concurrent from concurrent import futures @@ -276,7 +276,9 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): blockchain.read_blockchains(self.config) blockchain.init_headers_file_for_best_chain() self.logger.info(f"blockchains {list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))}") - self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict] + self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Dict[str, Any] + if self._blockchain_preferred_block is None: + self._set_preferred_chain(None) self._blockchain = blockchain.get_best_chain() self._allowed_protocols = {PREFERRED_NETWORK_PROTOCOL} @@ -624,7 +626,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): await self.switch_to_interface(random.choice(servers)) async def switch_lagging_interface(self): - '''If auto_connect and lagging, switch interface''' + """If auto_connect and lagging, switch interface (only within fork).""" if self.auto_connect and await self._server_is_lagging(): # switch to one that has the correct header (not height) best_header = self.blockchain().header_at_tip() @@ -634,40 +636,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) - async def switch_unwanted_fork_interface(self): - """If auto_connect and main interface is not on preferred fork, - try to switch to preferred fork. - """ + async def switch_unwanted_fork_interface(self) -> None: + """If auto_connect, maybe switch to another fork/chain.""" if not self.auto_connect or not self.interface: return with self.interfaces_lock: interfaces = list(self.interfaces.values()) - # try to switch to preferred fork - if self._blockchain_preferred_block: - pref_height = self._blockchain_preferred_block['height'] - pref_hash = self._blockchain_preferred_block['hash'] - if self.interface.blockchain.check_hash(pref_height, pref_hash): - return # already on preferred fork - filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash), - interfaces)) + pref_height = self._blockchain_preferred_block['height'] + pref_hash = self._blockchain_preferred_block['hash'] + # shortcut for common case + if pref_height == 0: + return + # maybe try switching chains; starting with most desirable first + matching_chains = blockchain.get_chains_that_contain_header(pref_height, pref_hash) + chains_to_try = list(matching_chains) + [blockchain.get_best_chain()] + for rank, chain in enumerate(chains_to_try): + # check if main interface is already on this fork + if self.interface.blockchain == chain: + return + # switch to another random interface that is on this fork, if any + filtered = [iface for iface in interfaces + if iface.blockchain == chain] if filtered: - self.logger.info("switching to preferred fork") + self.logger.info(f"switching to (more) preferred fork (rank {rank})") chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) return - else: - self.logger.info("tried to switch to preferred fork but no interfaces are on it") - # try to switch to best chain - if self.blockchain().parent is None: - return # already on best chain - filtered = list(filter(lambda iface: iface.blockchain.parent is None, - interfaces)) - if filtered: - self.logger.info("switching to best chain") - chosen_iface = random.choice(filtered) - await self.switch_to_interface(chosen_iface.server) - else: - # FIXME switch to best available? - self.logger.info("tried to switch to best chain but no interfaces are on it") + self.logger.info("tried to switch to (more) preferred fork but no interfaces are on any") async def switch_to_interface(self, server: ServerAddr): """Switch to server as our main interface. If no connection exists, @@ -1083,9 +1077,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): out[chain_id] = r return out - def _set_preferred_chain(self, chain: Blockchain): - height = chain.get_max_forkpoint() - header_hash = chain.get_hash(height) + def _set_preferred_chain(self, chain: Optional[Blockchain]): + if chain: + height = chain.get_max_forkpoint() + header_hash = chain.get_hash(height) + else: + height = 0 + header_hash = constants.net.GENESIS self._blockchain_preferred_block = { 'height': height, 'hash': header_hash, diff --git a/electrum/tests/test_blockchain.py b/electrum/tests/test_blockchain.py index 206a471c9..17d0d836f 100644 --- a/electrum/tests/test_blockchain.py +++ b/electrum/tests/test_blockchain.py @@ -336,6 +336,54 @@ class TestBlockchain(ElectrumTestCase): for b in (chain_u, chain_l, chain_z): self.assertTrue(all([b.can_connect(b.read_header(i), False) for i in range(b.height())])) + def get_chains_that_contain_header_helper(self, header: dict): + height = header['block_height'] + header_hash = hash_header(header) + return blockchain.get_chains_that_contain_header(height, header_hash) + + def test_get_chains_that_contain_header(self): + blockchain.blockchains[constants.net.GENESIS] = chain_u = Blockchain( + config=self.config, forkpoint=0, parent=None, + forkpoint_hash=constants.net.GENESIS, prev_hash=None) + open(chain_u.path(), 'w+').close() + self._append_header(chain_u, self.HEADERS['A']) + self._append_header(chain_u, self.HEADERS['B']) + self._append_header(chain_u, self.HEADERS['C']) + self._append_header(chain_u, self.HEADERS['D']) + self._append_header(chain_u, self.HEADERS['E']) + self._append_header(chain_u, self.HEADERS['F']) + self._append_header(chain_u, self.HEADERS['O']) + self._append_header(chain_u, self.HEADERS['P']) + self._append_header(chain_u, self.HEADERS['Q']) + + chain_l = chain_u.fork(self.HEADERS['G']) + self._append_header(chain_l, self.HEADERS['H']) + self._append_header(chain_l, self.HEADERS['I']) + self._append_header(chain_l, self.HEADERS['J']) + self._append_header(chain_l, self.HEADERS['K']) + self._append_header(chain_l, self.HEADERS['L']) + + chain_z = chain_l.fork(self.HEADERS['M']) + + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A'])) + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C'])) + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F'])) + self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['G'])) + self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['I'])) + self.assertEqual([chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['M'])) + self.assertEqual([chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['K'])) + + self._append_header(chain_z, self.HEADERS['N']) + self._append_header(chain_z, self.HEADERS['X']) + self._append_header(chain_z, self.HEADERS['Y']) + self._append_header(chain_z, self.HEADERS['Z']) + + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A'])) + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C'])) + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F'])) + self.assertEqual([chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['O'])) + self.assertEqual([chain_z, chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['I'])) + class TestVerifyHeader(ElectrumTestCase):