Browse Source

network: smarter switch_unwanted_fork_interface

Previously this function would not switch to a different chain if the
current chain contained the preferred block. This was not the intended
behaviour: if there is a *stronger* chain that *also* contains the
preferred block, we should jump to that.

Note that with this commit there will now always be a preferred block
(defaults to genesis). Previously, it might seem that often there was none,
but actually in practice if the user used the GUI context menu to switch
servers even once, there was one (usually genesis).

Hence, with the old code, if an attacker mined a single header which
then got reorged, auto_connect clients which were connected to the
attacker's server would never switch servers (jump chains) even
without the user explicitly configuring preference for the stale branch.
bip39-recovery
SomberNight 5 years ago
parent
commit
2eec7e1600
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 13
      electrum/blockchain.py
  2. 56
      electrum/network.py
  3. 48
      electrum/tests/test_blockchain.py

13
electrum/blockchain.py

@ -646,6 +646,7 @@ class Blockchain(Logger):
def check_header(header: dict) -> Optional[Blockchain]:
"""Returns any Blockchain that contains header, or None."""
if type(header) is not dict:
return None
with blockchains_lock: chains = list(blockchains.values())
@ -656,8 +657,20 @@ def check_header(header: dict) -> Optional[Blockchain]:
def can_connect(header: dict) -> Optional[Blockchain]:
"""Returns the Blockchain that has a tip that directly links up
with header, or None.
"""
with blockchains_lock: chains = list(blockchains.values())
for b in chains:
if b.can_connect(header):
return b
return None
def get_chains_that_contain_header(height: int, header_hash: str) -> Sequence[Blockchain]:
"""Returns a list of Blockchains that contain header, best chain first."""
with blockchains_lock: chains = list(blockchains.values())
chains = [chain for chain in chains
if chain.check_hash(height=height, header_hash=header_hash)]
chains = sorted(chains, key=lambda x: x.get_chainwork(), reverse=True)
return chains

56
electrum/network.py

@ -32,7 +32,7 @@ import socket
import json
import sys
import asyncio
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set, Any
import traceback
import concurrent
from concurrent import futures
@ -276,7 +276,9 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
blockchain.read_blockchains(self.config)
blockchain.init_headers_file_for_best_chain()
self.logger.info(f"blockchains {list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))}")
self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict]
self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Dict[str, Any]
if self._blockchain_preferred_block is None:
self._set_preferred_chain(None)
self._blockchain = blockchain.get_best_chain()
self._allowed_protocols = {PREFERRED_NETWORK_PROTOCOL}
@ -624,7 +626,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
await self.switch_to_interface(random.choice(servers))
async def switch_lagging_interface(self):
'''If auto_connect and lagging, switch interface'''
"""If auto_connect and lagging, switch interface (only within fork)."""
if self.auto_connect and await self._server_is_lagging():
# switch to one that has the correct header (not height)
best_header = self.blockchain().header_at_tip()
@ -634,40 +636,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
chosen_iface = random.choice(filtered)
await self.switch_to_interface(chosen_iface.server)
async def switch_unwanted_fork_interface(self):
"""If auto_connect and main interface is not on preferred fork,
try to switch to preferred fork.
"""
async def switch_unwanted_fork_interface(self) -> None:
"""If auto_connect, maybe switch to another fork/chain."""
if not self.auto_connect or not self.interface:
return
with self.interfaces_lock: interfaces = list(self.interfaces.values())
# try to switch to preferred fork
if self._blockchain_preferred_block:
pref_height = self._blockchain_preferred_block['height']
pref_hash = self._blockchain_preferred_block['hash']
if self.interface.blockchain.check_hash(pref_height, pref_hash):
return # already on preferred fork
filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash),
interfaces))
if filtered:
self.logger.info("switching to preferred fork")
chosen_iface = random.choice(filtered)
await self.switch_to_interface(chosen_iface.server)
# shortcut for common case
if pref_height == 0:
return
else:
self.logger.info("tried to switch to preferred fork but no interfaces are on it")
# try to switch to best chain
if self.blockchain().parent is None:
return # already on best chain
filtered = list(filter(lambda iface: iface.blockchain.parent is None,
interfaces))
# maybe try switching chains; starting with most desirable first
matching_chains = blockchain.get_chains_that_contain_header(pref_height, pref_hash)
chains_to_try = list(matching_chains) + [blockchain.get_best_chain()]
for rank, chain in enumerate(chains_to_try):
# check if main interface is already on this fork
if self.interface.blockchain == chain:
return
# switch to another random interface that is on this fork, if any
filtered = [iface for iface in interfaces
if iface.blockchain == chain]
if filtered:
self.logger.info("switching to best chain")
self.logger.info(f"switching to (more) preferred fork (rank {rank})")
chosen_iface = random.choice(filtered)
await self.switch_to_interface(chosen_iface.server)
else:
# FIXME switch to best available?
self.logger.info("tried to switch to best chain but no interfaces are on it")
return
self.logger.info("tried to switch to (more) preferred fork but no interfaces are on any")
async def switch_to_interface(self, server: ServerAddr):
"""Switch to server as our main interface. If no connection exists,
@ -1083,9 +1077,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
out[chain_id] = r
return out
def _set_preferred_chain(self, chain: Blockchain):
def _set_preferred_chain(self, chain: Optional[Blockchain]):
if chain:
height = chain.get_max_forkpoint()
header_hash = chain.get_hash(height)
else:
height = 0
header_hash = constants.net.GENESIS
self._blockchain_preferred_block = {
'height': height,
'hash': header_hash,

48
electrum/tests/test_blockchain.py

@ -336,6 +336,54 @@ class TestBlockchain(ElectrumTestCase):
for b in (chain_u, chain_l, chain_z):
self.assertTrue(all([b.can_connect(b.read_header(i), False) for i in range(b.height())]))
def get_chains_that_contain_header_helper(self, header: dict):
height = header['block_height']
header_hash = hash_header(header)
return blockchain.get_chains_that_contain_header(height, header_hash)
def test_get_chains_that_contain_header(self):
blockchain.blockchains[constants.net.GENESIS] = chain_u = Blockchain(
config=self.config, forkpoint=0, parent=None,
forkpoint_hash=constants.net.GENESIS, prev_hash=None)
open(chain_u.path(), 'w+').close()
self._append_header(chain_u, self.HEADERS['A'])
self._append_header(chain_u, self.HEADERS['B'])
self._append_header(chain_u, self.HEADERS['C'])
self._append_header(chain_u, self.HEADERS['D'])
self._append_header(chain_u, self.HEADERS['E'])
self._append_header(chain_u, self.HEADERS['F'])
self._append_header(chain_u, self.HEADERS['O'])
self._append_header(chain_u, self.HEADERS['P'])
self._append_header(chain_u, self.HEADERS['Q'])
chain_l = chain_u.fork(self.HEADERS['G'])
self._append_header(chain_l, self.HEADERS['H'])
self._append_header(chain_l, self.HEADERS['I'])
self._append_header(chain_l, self.HEADERS['J'])
self._append_header(chain_l, self.HEADERS['K'])
self._append_header(chain_l, self.HEADERS['L'])
chain_z = chain_l.fork(self.HEADERS['M'])
self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A']))
self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C']))
self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F']))
self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['G']))
self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['I']))
self.assertEqual([chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['M']))
self.assertEqual([chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['K']))
self._append_header(chain_z, self.HEADERS['N'])
self._append_header(chain_z, self.HEADERS['X'])
self._append_header(chain_z, self.HEADERS['Y'])
self._append_header(chain_z, self.HEADERS['Z'])
self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A']))
self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C']))
self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F']))
self.assertEqual([chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['O']))
self.assertEqual([chain_z, chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['I']))
class TestVerifyHeader(ElectrumTestCase):

Loading…
Cancel
Save