diff --git a/electrum/blockchain.py b/electrum/blockchain.py index 15884c511..6dfe3573e 100644 --- a/electrum/blockchain.py +++ b/electrum/blockchain.py @@ -646,6 +646,7 @@ class Blockchain(Logger): def check_header(header: dict) -> Optional[Blockchain]: + """Returns any Blockchain that contains header, or None.""" if type(header) is not dict: return None with blockchains_lock: chains = list(blockchains.values()) @@ -656,8 +657,20 @@ def check_header(header: dict) -> Optional[Blockchain]: def can_connect(header: dict) -> Optional[Blockchain]: + """Returns the Blockchain that has a tip that directly links up + with header, or None. + """ with blockchains_lock: chains = list(blockchains.values()) for b in chains: if b.can_connect(header): return b return None + + +def get_chains_that_contain_header(height: int, header_hash: str) -> Sequence[Blockchain]: + """Returns a list of Blockchains that contain header, best chain first.""" + with blockchains_lock: chains = list(blockchains.values()) + chains = [chain for chain in chains + if chain.check_hash(height=height, header_hash=header_hash)] + chains = sorted(chains, key=lambda x: x.get_chainwork(), reverse=True) + return chains diff --git a/electrum/network.py b/electrum/network.py index aa0d34bca..c38688e72 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -32,7 +32,7 @@ import socket import json import sys import asyncio -from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set +from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set, Any import traceback import concurrent from concurrent import futures @@ -276,7 +276,9 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): blockchain.read_blockchains(self.config) blockchain.init_headers_file_for_best_chain() self.logger.info(f"blockchains {list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))}") - self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict] + self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Dict[str, Any] + if self._blockchain_preferred_block is None: + self._set_preferred_chain(None) self._blockchain = blockchain.get_best_chain() self._allowed_protocols = {PREFERRED_NETWORK_PROTOCOL} @@ -624,7 +626,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): await self.switch_to_interface(random.choice(servers)) async def switch_lagging_interface(self): - '''If auto_connect and lagging, switch interface''' + """If auto_connect and lagging, switch interface (only within fork).""" if self.auto_connect and await self._server_is_lagging(): # switch to one that has the correct header (not height) best_header = self.blockchain().header_at_tip() @@ -634,40 +636,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) - async def switch_unwanted_fork_interface(self): - """If auto_connect and main interface is not on preferred fork, - try to switch to preferred fork. - """ + async def switch_unwanted_fork_interface(self) -> None: + """If auto_connect, maybe switch to another fork/chain.""" if not self.auto_connect or not self.interface: return with self.interfaces_lock: interfaces = list(self.interfaces.values()) - # try to switch to preferred fork - if self._blockchain_preferred_block: - pref_height = self._blockchain_preferred_block['height'] - pref_hash = self._blockchain_preferred_block['hash'] - if self.interface.blockchain.check_hash(pref_height, pref_hash): - return # already on preferred fork - filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash), - interfaces)) + pref_height = self._blockchain_preferred_block['height'] + pref_hash = self._blockchain_preferred_block['hash'] + # shortcut for common case + if pref_height == 0: + return + # maybe try switching chains; starting with most desirable first + matching_chains = blockchain.get_chains_that_contain_header(pref_height, pref_hash) + chains_to_try = list(matching_chains) + [blockchain.get_best_chain()] + for rank, chain in enumerate(chains_to_try): + # check if main interface is already on this fork + if self.interface.blockchain == chain: + return + # switch to another random interface that is on this fork, if any + filtered = [iface for iface in interfaces + if iface.blockchain == chain] if filtered: - self.logger.info("switching to preferred fork") + self.logger.info(f"switching to (more) preferred fork (rank {rank})") chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) return - else: - self.logger.info("tried to switch to preferred fork but no interfaces are on it") - # try to switch to best chain - if self.blockchain().parent is None: - return # already on best chain - filtered = list(filter(lambda iface: iface.blockchain.parent is None, - interfaces)) - if filtered: - self.logger.info("switching to best chain") - chosen_iface = random.choice(filtered) - await self.switch_to_interface(chosen_iface.server) - else: - # FIXME switch to best available? - self.logger.info("tried to switch to best chain but no interfaces are on it") + self.logger.info("tried to switch to (more) preferred fork but no interfaces are on any") async def switch_to_interface(self, server: ServerAddr): """Switch to server as our main interface. If no connection exists, @@ -1083,9 +1077,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): out[chain_id] = r return out - def _set_preferred_chain(self, chain: Blockchain): - height = chain.get_max_forkpoint() - header_hash = chain.get_hash(height) + def _set_preferred_chain(self, chain: Optional[Blockchain]): + if chain: + height = chain.get_max_forkpoint() + header_hash = chain.get_hash(height) + else: + height = 0 + header_hash = constants.net.GENESIS self._blockchain_preferred_block = { 'height': height, 'hash': header_hash, diff --git a/electrum/tests/test_blockchain.py b/electrum/tests/test_blockchain.py index 206a471c9..17d0d836f 100644 --- a/electrum/tests/test_blockchain.py +++ b/electrum/tests/test_blockchain.py @@ -336,6 +336,54 @@ class TestBlockchain(ElectrumTestCase): for b in (chain_u, chain_l, chain_z): self.assertTrue(all([b.can_connect(b.read_header(i), False) for i in range(b.height())])) + def get_chains_that_contain_header_helper(self, header: dict): + height = header['block_height'] + header_hash = hash_header(header) + return blockchain.get_chains_that_contain_header(height, header_hash) + + def test_get_chains_that_contain_header(self): + blockchain.blockchains[constants.net.GENESIS] = chain_u = Blockchain( + config=self.config, forkpoint=0, parent=None, + forkpoint_hash=constants.net.GENESIS, prev_hash=None) + open(chain_u.path(), 'w+').close() + self._append_header(chain_u, self.HEADERS['A']) + self._append_header(chain_u, self.HEADERS['B']) + self._append_header(chain_u, self.HEADERS['C']) + self._append_header(chain_u, self.HEADERS['D']) + self._append_header(chain_u, self.HEADERS['E']) + self._append_header(chain_u, self.HEADERS['F']) + self._append_header(chain_u, self.HEADERS['O']) + self._append_header(chain_u, self.HEADERS['P']) + self._append_header(chain_u, self.HEADERS['Q']) + + chain_l = chain_u.fork(self.HEADERS['G']) + self._append_header(chain_l, self.HEADERS['H']) + self._append_header(chain_l, self.HEADERS['I']) + self._append_header(chain_l, self.HEADERS['J']) + self._append_header(chain_l, self.HEADERS['K']) + self._append_header(chain_l, self.HEADERS['L']) + + chain_z = chain_l.fork(self.HEADERS['M']) + + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A'])) + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C'])) + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F'])) + self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['G'])) + self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['I'])) + self.assertEqual([chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['M'])) + self.assertEqual([chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['K'])) + + self._append_header(chain_z, self.HEADERS['N']) + self._append_header(chain_z, self.HEADERS['X']) + self._append_header(chain_z, self.HEADERS['Y']) + self._append_header(chain_z, self.HEADERS['Z']) + + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A'])) + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C'])) + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F'])) + self.assertEqual([chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['O'])) + self.assertEqual([chain_z, chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['I'])) + class TestVerifyHeader(ElectrumTestCase):