diff --git a/electrum/interface.py b/electrum/interface.py index 4d6bd87af..b29fca558 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -503,11 +503,14 @@ class Interface(PrintError): # is assumed to be expensive; especially as forks below the max # checkpoint are ignored. self.print_error("new fork at bad height {}. conflict!!".format(bad)) + assert self.blockchain != branch ismocking = type(branch) is dict if ismocking: self.print_error("TODO replace blockchain") return 'fork_conflict', height self.print_error('forkpoint conflicts with existing fork', branch.path()) + self._raise_if_fork_conflicts_with_default_server(branch) + self._disconnect_from_interfaces_on_conflicting_blockchain(branch) branch.write(b'', 0) branch.save_header(bad_header) self.blockchain = branch @@ -524,6 +527,21 @@ class Interface(PrintError): assert b.forkpoint == bad return 'fork_noconflict', height + def _raise_if_fork_conflicts_with_default_server(self, chain_to_delete: Blockchain) -> None: + main_interface = self.network.interface + if not main_interface: return + if main_interface == self: return + chain_of_default_server = main_interface.blockchain + if not chain_of_default_server: return + if chain_to_delete == chain_of_default_server: + raise GracefulDisconnect('refusing to overwrite blockchain of default server') + + def _disconnect_from_interfaces_on_conflicting_blockchain(self, chain: Blockchain) -> None: + ifaces = self.network.disconnect_from_interfaces_on_given_blockchain(chain) + if not ifaces: return + servers = [interface.server for interface in ifaces] + self.print_error("forcing disconnect of other interfaces: {}".format(servers)) + async def _search_headers_backwards(self, height, header): async def iterate(): nonlocal height, header diff --git a/electrum/network.py b/electrum/network.py index a010588ad..4b76d7d30 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -32,7 +32,7 @@ import json import sys import ipaddress import asyncio -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Sequence import dns import dns.resolver @@ -43,6 +43,7 @@ from .util import PrintError, print_error, aiosafe, bfh from .bitcoin import COIN from . import constants from . import blockchain +from .blockchain import Blockchain from .interface import Interface, serialize_server, deserialize_server from .version import PROTOCOL_VERSION from .simple_config import SimpleConfig @@ -708,14 +709,22 @@ class Network(PrintError): @with_interface_lock def get_blockchains(self): - out = {} + out = {} # blockchain_id -> list(interfaces) with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items()) - for k, b in blockchain_items: - r = list(filter(lambda i: i.blockchain==b, list(self.interfaces.values()))) + for chain_id, bc in blockchain_items: + r = list(filter(lambda i: i.blockchain==bc, list(self.interfaces.values()))) if r: - out[k] = r + out[chain_id] = r return out + @with_interface_lock + def disconnect_from_interfaces_on_given_blockchain(self, chain: Blockchain) -> Sequence[Interface]: + chain_id = chain.forkpoint + ifaces = self.get_blockchains().get(chain_id) or [] + for interface in ifaces: + self.connection_down(interface.server) + return ifaces + def follow_chain(self, index): bc = blockchain.blockchains.get(index) if bc: