Browse Source

interface: refuse to overwrite blockchain of main interface

in case of conflicting forks
3.3.3.1
SomberNight 6 years ago
parent
commit
9161e8c8f4
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 18
      electrum/interface.py
  2. 19
      electrum/network.py

18
electrum/interface.py

@ -503,11 +503,14 @@ class Interface(PrintError):
# is assumed to be expensive; especially as forks below the max # is assumed to be expensive; especially as forks below the max
# checkpoint are ignored. # checkpoint are ignored.
self.print_error("new fork at bad height {}. conflict!!".format(bad)) self.print_error("new fork at bad height {}. conflict!!".format(bad))
assert self.blockchain != branch
ismocking = type(branch) is dict ismocking = type(branch) is dict
if ismocking: if ismocking:
self.print_error("TODO replace blockchain") self.print_error("TODO replace blockchain")
return 'fork_conflict', height return 'fork_conflict', height
self.print_error('forkpoint conflicts with existing fork', branch.path()) self.print_error('forkpoint conflicts with existing fork', branch.path())
self._raise_if_fork_conflicts_with_default_server(branch)
self._disconnect_from_interfaces_on_conflicting_blockchain(branch)
branch.write(b'', 0) branch.write(b'', 0)
branch.save_header(bad_header) branch.save_header(bad_header)
self.blockchain = branch self.blockchain = branch
@ -524,6 +527,21 @@ class Interface(PrintError):
assert b.forkpoint == bad assert b.forkpoint == bad
return 'fork_noconflict', height return 'fork_noconflict', height
def _raise_if_fork_conflicts_with_default_server(self, chain_to_delete: Blockchain) -> None:
main_interface = self.network.interface
if not main_interface: return
if main_interface == self: return
chain_of_default_server = main_interface.blockchain
if not chain_of_default_server: return
if chain_to_delete == chain_of_default_server:
raise GracefulDisconnect('refusing to overwrite blockchain of default server')
def _disconnect_from_interfaces_on_conflicting_blockchain(self, chain: Blockchain) -> None:
ifaces = self.network.disconnect_from_interfaces_on_given_blockchain(chain)
if not ifaces: return
servers = [interface.server for interface in ifaces]
self.print_error("forcing disconnect of other interfaces: {}".format(servers))
async def _search_headers_backwards(self, height, header): async def _search_headers_backwards(self, height, header):
async def iterate(): async def iterate():
nonlocal height, header nonlocal height, header

19
electrum/network.py

@ -32,7 +32,7 @@ import json
import sys import sys
import ipaddress import ipaddress
import asyncio import asyncio
from typing import NamedTuple, Optional from typing import NamedTuple, Optional, Sequence
import dns import dns
import dns.resolver import dns.resolver
@ -43,6 +43,7 @@ from .util import PrintError, print_error, aiosafe, bfh
from .bitcoin import COIN from .bitcoin import COIN
from . import constants from . import constants
from . import blockchain from . import blockchain
from .blockchain import Blockchain
from .interface import Interface, serialize_server, deserialize_server from .interface import Interface, serialize_server, deserialize_server
from .version import PROTOCOL_VERSION from .version import PROTOCOL_VERSION
from .simple_config import SimpleConfig from .simple_config import SimpleConfig
@ -708,14 +709,22 @@ class Network(PrintError):
@with_interface_lock @with_interface_lock
def get_blockchains(self): def get_blockchains(self):
out = {} out = {} # blockchain_id -> list(interfaces)
with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items()) with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items())
for k, b in blockchain_items: for chain_id, bc in blockchain_items:
r = list(filter(lambda i: i.blockchain==b, list(self.interfaces.values()))) r = list(filter(lambda i: i.blockchain==bc, list(self.interfaces.values())))
if r: if r:
out[k] = r out[chain_id] = r
return out return out
@with_interface_lock
def disconnect_from_interfaces_on_given_blockchain(self, chain: Blockchain) -> Sequence[Interface]:
chain_id = chain.forkpoint
ifaces = self.get_blockchains().get(chain_id) or []
for interface in ifaces:
self.connection_down(interface.server)
return ifaces
def follow_chain(self, index): def follow_chain(self, index):
bc = blockchain.blockchains.get(index) bc = blockchain.blockchains.get(index)
if bc: if bc:

Loading…
Cancel
Save