From 141ff99580192c920bc6bb7f6bbc9d35449daea8 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Tue, 20 Nov 2018 18:57:16 +0100 Subject: [PATCH 1/2] blockchain.py: generalise fork ids to get rid of conflicts --- electrum/blockchain.py | 252 ++++++++++++++++++++---------- electrum/gui/kivy/main_window.py | 8 +- electrum/gui/qt/network_dialog.py | 21 +-- electrum/interface.py | 65 ++------ electrum/network.py | 38 ++--- electrum/storage.py | 2 +- electrum/tests/test_blockchain.py | 239 ++++++++++++++++++++++++++++ electrum/tests/test_network.py | 20 ++- electrum/tests/test_wallet.py | 1 - 9 files changed, 474 insertions(+), 172 deletions(-) create mode 100644 electrum/tests/test_blockchain.py diff --git a/electrum/blockchain.py b/electrum/blockchain.py index 5609c41e2..c7c075312 100644 --- a/electrum/blockchain.py +++ b/electrum/blockchain.py @@ -79,26 +79,67 @@ def hash_raw_header(header: str) -> str: return hash_encode(sha256d(bfh(header))) -blockchains = {} # type: Dict[int, Blockchain] -blockchains_lock = threading.Lock() - - -def read_blockchains(config: 'SimpleConfig') -> Dict[int, 'Blockchain']: - blockchains[0] = Blockchain(config, 0, None) +# key: blockhash hex at forkpoint +# the chain at some key is the best chain that includes the given hash +blockchains = {} # type: Dict[str, Blockchain] +blockchains_lock = threading.RLock() + + +def read_blockchains(config: 'SimpleConfig'): + blockchains[constants.net.GENESIS] = Blockchain(config=config, + forkpoint=0, + parent=None, + forkpoint_hash=constants.net.GENESIS, + prev_hash=None) fdir = os.path.join(util.get_headers_dir(config), 'forks') util.make_dir(fdir) - l = filter(lambda x: x.startswith('fork_'), os.listdir(fdir)) - l = sorted(l, key = lambda x: int(x.split('_')[1])) - for filename in l: - forkpoint = int(filename.split('_')[2]) - parent_id = int(filename.split('_')[1]) - b = Blockchain(config, forkpoint, parent_id) - h = b.read_header(b.forkpoint) - if b.parent().can_connect(h, check_height=False): - blockchains[b.forkpoint] = b + # files are named as: fork2_{forkpoint}_{prev_hash}_{first_hash} + l = filter(lambda x: x.startswith('fork2_') and '.' not in x, os.listdir(fdir)) + l = sorted(l, key=lambda x: int(x.split('_')[1])) # sort by forkpoint + + def delete_chain(filename, reason): + util.print_error("[blockchain]", reason, filename) + os.unlink(os.path.join(fdir, filename)) + + def instantiate_chain(filename): + __, forkpoint, prev_hash, first_hash = filename.split('_') + forkpoint = int(forkpoint) + prev_hash = (64-len(prev_hash)) * "0" + prev_hash # left-pad with zeroes + first_hash = (64-len(first_hash)) * "0" + first_hash + # forks below the max checkpoint are not allowed + if forkpoint <= constants.net.max_checkpoint(): + delete_chain(filename, "deleting fork below max checkpoint") + return + # find parent (sorting by forkpoint guarantees it's already instantiated) + for parent in blockchains.values(): + if parent.check_hash(forkpoint - 1, prev_hash): + break else: - util.print_error("cannot connect", filename) - return blockchains + delete_chain(filename, "cannot find parent for chain") + return + b = Blockchain(config=config, + forkpoint=forkpoint, + parent=parent, + forkpoint_hash=first_hash, + prev_hash=prev_hash) + # consistency checks + h = b.read_header(b.forkpoint) + if first_hash != hash_header(h): + delete_chain(filename, "incorrect first hash for chain") + return + if not b.parent.can_connect(h, check_height=False): + delete_chain(filename, "cannot connect chain to parent") + return + chain_id = b.get_id() + assert first_hash == chain_id, (first_hash, chain_id) + blockchains[chain_id] = b + + for filename in l: + instantiate_chain(filename) + + +def get_best_chain() -> 'Blockchain': + return blockchains[constants.net.GENESIS] class Blockchain(util.PrintError): @@ -106,15 +147,20 @@ class Blockchain(util.PrintError): Manages blockchain headers and their verification """ - def __init__(self, config: SimpleConfig, forkpoint: int, parent_id: Optional[int]): + def __init__(self, config: SimpleConfig, forkpoint: int, parent: Optional['Blockchain'], + forkpoint_hash: str, prev_hash: Optional[str]): + assert isinstance(forkpoint_hash, str) and len(forkpoint_hash) == 64, forkpoint_hash + assert (prev_hash is None) or (isinstance(prev_hash, str) and len(prev_hash) == 64), prev_hash + # assert (parent is None) == (forkpoint == 0) + if 0 < forkpoint <= constants.net.max_checkpoint(): + raise Exception(f"cannot fork below max checkpoint. forkpoint: {forkpoint}") self.config = config - self.forkpoint = forkpoint - self.checkpoints = constants.net.CHECKPOINTS - self.parent_id = parent_id - assert parent_id != forkpoint + self.forkpoint = forkpoint # height of first header + self.parent = parent + self._forkpoint_hash = forkpoint_hash # blockhash at forkpoint. "first hash" + self._prev_hash = prev_hash # blockhash immediately before forkpoint self.lock = threading.RLock() - with self.lock: - self.update_size() + self.update_size() def with_lock(func): def func_wrapper(self, *args, **kwargs): @@ -122,12 +168,13 @@ class Blockchain(util.PrintError): return func(self, *args, **kwargs) return func_wrapper - def parent(self) -> 'Blockchain': - return blockchains[self.parent_id] + @property + def checkpoints(self): + return constants.net.CHECKPOINTS def get_max_child(self) -> Optional[int]: with blockchains_lock: chains = list(blockchains.values()) - children = list(filter(lambda y: y.parent_id==self.forkpoint, chains)) + children = list(filter(lambda y: y.parent==self, chains)) return max([x.forkpoint for x in children]) if children else None def get_max_forkpoint(self) -> int: @@ -137,11 +184,12 @@ class Blockchain(util.PrintError): mc = self.get_max_child() return mc if mc is not None else self.forkpoint + @with_lock def get_branch_size(self) -> int: return self.height() - self.get_max_forkpoint() + 1 def get_name(self) -> str: - return self.get_hash(self.get_max_forkpoint()).lstrip('00')[0:10] + return self.get_hash(self.get_max_forkpoint()).lstrip('0')[0:10] def check_header(self, header: dict) -> bool: header_hash = hash_header(header) @@ -159,24 +207,38 @@ class Blockchain(util.PrintError): return False def fork(parent, header: dict) -> 'Blockchain': + if not parent.can_connect(header, check_height=False): + raise Exception("forking header does not connect to parent chain") forkpoint = header.get('block_height') - self = Blockchain(parent.config, forkpoint, parent.forkpoint) + self = Blockchain(config=parent.config, + forkpoint=forkpoint, + parent=parent, + forkpoint_hash=hash_header(header), + prev_hash=parent.get_hash(forkpoint-1)) open(self.path(), 'w+').close() self.save_header(header) + # put into global dict + chain_id = self.get_id() + with blockchains_lock: + assert chain_id not in blockchains, (chain_id, list(blockchains)) + blockchains[chain_id] = self return self + @with_lock def height(self) -> int: return self.forkpoint + self.size() - 1 + @with_lock def size(self) -> int: - with self.lock: - return self._size + return self._size + @with_lock def update_size(self) -> None: p = self.path() self._size = os.path.getsize(p)//HEADER_SIZE if os.path.exists(p) else 0 - def verify_header(self, header: dict, prev_hash: str, target: int, expected_header_hash: str=None) -> None: + @classmethod + def verify_header(cls, header: dict, prev_hash: str, target: int, expected_header_hash: str=None) -> None: _hash = hash_header(header) if expected_header_hash and expected_header_hash != _hash: raise Exception("hash mismatches with expected: {} vs {}".format(expected_header_hash, _hash)) @@ -184,7 +246,7 @@ class Blockchain(util.PrintError): raise Exception("prev hash mismatch: %s vs %s" % (prev_hash, header.get('prev_block_hash'))) if constants.net.TESTNET: return - bits = self.target_to_bits(target) + bits = cls.target_to_bits(target) if bits != header.get('bits'): raise Exception("bits mismatch: %s vs %s" % (bits, header.get('bits'))) if int('0x' + _hash, 16) > target: @@ -206,21 +268,26 @@ class Blockchain(util.PrintError): self.verify_header(header, prev_hash, target, expected_header_hash) prev_hash = hash_header(header) + @with_lock def path(self): d = util.get_headers_dir(self.config) - if self.parent_id is None: + if self.parent is None: filename = 'blockchain_headers' else: - basename = 'fork_%d_%d' % (self.parent_id, self.forkpoint) + assert self.forkpoint > 0, self.forkpoint + prev_hash = self._prev_hash.lstrip('0') + first_hash = self._forkpoint_hash.lstrip('0') + basename = f'fork2_{self.forkpoint}_{prev_hash}_{first_hash}' filename = os.path.join('forks', basename) return os.path.join(d, filename) @with_lock def save_chunk(self, index: int, chunk: bytes): + assert index >= 0, index chunk_within_checkpoint_region = index < len(self.checkpoints) # chunks in checkpoint region are the responsibility of the 'main chain' - if chunk_within_checkpoint_region and self.parent_id is not None: - main_chain = blockchains[0] + if chunk_within_checkpoint_region and self.parent is not None: + main_chain = get_best_chain() main_chain.save_chunk(index, chunk) return @@ -235,18 +302,36 @@ class Blockchain(util.PrintError): self.write(chunk, delta_bytes, truncate) self.swap_with_parent() - @with_lock def swap_with_parent(self) -> None: - if self.parent_id is None: - return - parent_branch_size = self.parent().height() - self.forkpoint + 1 - if parent_branch_size >= self.size(): - return - self.print_error("swap", self.forkpoint, self.parent_id) - parent_id = self.parent_id - forkpoint = self.forkpoint - parent = self.parent() + parent_lock = self.parent.lock if self.parent is not None else threading.Lock() + with parent_lock, self.lock, blockchains_lock: # this order should not deadlock + # do the swap; possibly multiple ones + cnt = 0 + while self._swap_with_parent(): + cnt += 1 + if cnt > len(blockchains): # make sure we are making progress + raise Exception(f'swapping fork with parent too many times: {cnt}') + + def _swap_with_parent(self) -> bool: + """Check if this chain became stronger than its parent, and swap + the underlying files if so. The Blockchain instances will keep + 'containing' the same headers, but their ids change and so + they will be stored in different files.""" + if self.parent is None: + return False + parent_branch_size = self.parent.height() - self.forkpoint + 1 + if parent_branch_size >= self.size(): # FIXME most work, not length + return False + self.print_error("swap", self.forkpoint, self.parent.forkpoint) + forkpoint = self.forkpoint # type: Optional[int] + parent = self.parent # type: Optional[Blockchain] + child_old_id = self.get_id() + parent_old_id = parent.get_id() + # swap files + # child takes parent's name + # parent's new name will be something new (not child's old name) self.assert_headers_file_available(self.path()) + child_old_name = self.path() with open(self.path(), 'rb') as f: my_data = f.read() self.assert_headers_file_available(parent.path()) @@ -255,24 +340,28 @@ class Blockchain(util.PrintError): parent_data = f.read(parent_branch_size*HEADER_SIZE) self.write(parent_data, 0) parent.write(my_data, (forkpoint - parent.forkpoint)*HEADER_SIZE) - # store file path - with blockchains_lock: chains = list(blockchains.values()) - for b in chains: - b.old_path = b.path() # swap parameters - self.parent_id = parent.parent_id; parent.parent_id = parent_id - self.forkpoint = parent.forkpoint; parent.forkpoint = forkpoint - self._size = parent._size; parent._size = parent_branch_size - # move files - for b in chains: - if b in [self, parent]: continue - if b.old_path != b.path(): - self.print_error("renaming", b.old_path, b.path()) - os.rename(b.old_path, b.path()) + self.parent, parent.parent = parent.parent, self # type: Optional[Blockchain], Optional[Blockchain] + self.forkpoint, parent.forkpoint = parent.forkpoint, self.forkpoint + self._forkpoint_hash, parent._forkpoint_hash = parent._forkpoint_hash, hash_raw_header(bh2u(parent_data[:HEADER_SIZE])) + self._prev_hash, parent._prev_hash = parent._prev_hash, self._prev_hash + # parent's new name + try: + os.rename(child_old_name, parent.path()) + except OSError: + os.remove(parent.path()) + os.rename(child_old_name, parent.path()) + self.update_size() + parent.update_size() # update pointers - with blockchains_lock: - blockchains[self.forkpoint] = self - blockchains[parent.forkpoint] = parent + blockchains.pop(child_old_id, None) + blockchains.pop(parent_old_id, None) + blockchains[self.get_id()] = self + blockchains[parent.get_id()] = parent + return True + + def get_id(self) -> str: + return self._forkpoint_hash def assert_headers_file_available(self, path): if os.path.exists(path): @@ -282,19 +371,19 @@ class Blockchain(util.PrintError): else: raise FileNotFoundError('Cannot find headers file but headers_dir is there. Should be at {}'.format(path)) + @with_lock def write(self, data: bytes, offset: int, truncate: bool=True) -> None: filename = self.path() - with self.lock: - self.assert_headers_file_available(filename) - with open(filename, 'rb+') as f: - if truncate and offset != self._size * HEADER_SIZE: - f.seek(offset) - f.truncate() + self.assert_headers_file_available(filename) + with open(filename, 'rb+') as f: + if truncate and offset != self._size * HEADER_SIZE: f.seek(offset) - f.write(data) - f.flush() - os.fsync(f.fileno()) - self.update_size() + f.truncate() + f.seek(offset) + f.write(data) + f.flush() + os.fsync(f.fileno()) + self.update_size() @with_lock def save_header(self, header: dict) -> None: @@ -306,12 +395,12 @@ class Blockchain(util.PrintError): self.write(data, delta*HEADER_SIZE) self.swap_with_parent() + @with_lock def read_header(self, height: int) -> Optional[dict]: - assert self.parent_id != self.forkpoint if height < 0: return if height < self.forkpoint: - return self.parent().read_header(height) + return self.parent.read_header(height) if height > self.height(): return delta = height - self.forkpoint @@ -371,16 +460,18 @@ class Blockchain(util.PrintError): new_target = self.bits_to_target(self.target_to_bits(new_target)) return new_target - def bits_to_target(self, bits: int) -> int: + @classmethod + def bits_to_target(cls, bits: int) -> int: bitsN = (bits >> 24) & 0xff - if not (bitsN >= 0x03 and bitsN <= 0x1d): + if not (0x03 <= bitsN <= 0x1d): raise Exception("First part of bits should be in [0x03, 0x1d]") bitsBase = bits & 0xffffff - if not (bitsBase >= 0x8000 and bitsBase <= 0x7fffff): + if not (0x8000 <= bitsBase <= 0x7fffff): raise Exception("Second part of bits should be in [0x8000, 0x7fffff]") return bitsBase << (8 * (bitsN-3)) - def target_to_bits(self, target: int) -> int: + @classmethod + def target_to_bits(cls, target: int) -> int: c = ("%064x" % target)[2:] while c[:2] == '00' and len(c) > 6: c = c[2:] @@ -416,6 +507,7 @@ class Blockchain(util.PrintError): return True def connect_chunk(self, idx: int, hexdata: str) -> bool: + assert idx >= 0, idx try: data = bfh(hexdata) self.verify_chunk(idx, data) @@ -423,7 +515,7 @@ class Blockchain(util.PrintError): self.save_chunk(idx, data) return True except BaseException as e: - self.print_error('verify_chunk %d failed'%idx, str(e)) + self.print_error(f'verify_chunk idx {idx} failed: {repr(e)}') return False def get_checkpoints(self): diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py index c34e5ef9f..73379bb22 100644 --- a/electrum/gui/kivy/main_window.py +++ b/electrum/gui/kivy/main_window.py @@ -126,10 +126,12 @@ class ElectrumWindow(App): chains = self.network.get_blockchains() def cb(name): with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items()) - for index, b in blockchain_items: + for chain_id, b in blockchain_items: if name == b.get_name(): - self.network.run_from_another_thread(self.network.follow_chain_given_id(index)) - names = [blockchain.blockchains[b].get_name() for b in chains] + self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id)) + chain_objects = [blockchain.blockchains.get(chain_id) for chain_id in chains] + chain_objects = filter(lambda b: b is not None, chain_objects) + names = [b.get_name() for b in chain_objects] if len(names) > 1: cur_chain = self.network.blockchain().get_name() ChoiceDialog(_('Choose your chain'), names, cur_chain, cb).open() diff --git a/electrum/gui/qt/network_dialog.py b/electrum/gui/qt/network_dialog.py index bef853830..94ae77735 100644 --- a/electrum/gui/qt/network_dialog.py +++ b/electrum/gui/qt/network_dialog.py @@ -82,8 +82,8 @@ class NodesListWidget(QTreeWidget): server = item.data(1, Qt.UserRole) menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server)) else: - index = item.data(1, Qt.UserRole) - menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(index)) + chain_id = item.data(1, Qt.UserRole) + menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id)) menu.exec_(self.viewport().mapToGlobal(position)) def keyPressEvent(self, event): @@ -103,22 +103,23 @@ class NodesListWidget(QTreeWidget): self.addChild = self.addTopLevelItem chains = network.get_blockchains() n_chains = len(chains) - for k, items in chains.items(): - b = blockchain.blockchains[k] + for chain_id, interfaces in chains.items(): + b = blockchain.blockchains.get(chain_id) + if b is None: continue name = b.get_name() - if n_chains >1: + if n_chains > 1: x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()]) x.setData(0, Qt.UserRole, 1) - x.setData(1, Qt.UserRole, b.forkpoint) + x.setData(1, Qt.UserRole, b.get_id()) else: x = self - for i in items: + for i in interfaces: star = ' *' if i == network.interface else '' item = QTreeWidgetItem([i.host + star, '%d'%i.tip]) item.setData(0, Qt.UserRole, 0) item.setData(1, Qt.UserRole, i.server) x.addChild(item) - if n_chains>1: + if n_chains > 1: self.addTopLevelItem(x) x.setExpanded(True) @@ -410,8 +411,8 @@ class NetworkChoiceLayout(object): self.set_protocol(p) self.set_server() - def follow_branch(self, index): - self.network.run_from_another_thread(self.network.follow_chain_given_id(index)) + def follow_branch(self, chain_id): + self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id)) self.update() def follow_server(self, server): diff --git a/electrum/interface.py b/electrum/interface.py index 68ede7554..99e2349e4 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -28,7 +28,7 @@ import ssl import sys import traceback import asyncio -from typing import Tuple, Union, List, TYPE_CHECKING +from typing import Tuple, Union, List, TYPE_CHECKING, Optional from collections import defaultdict import aiorpcx @@ -140,14 +140,14 @@ def serialize_server(host: str, port: Union[str, int], protocol: str) -> str: class Interface(PrintError): verbosity_filter = 'i' - def __init__(self, network: 'Network', server: str, config_path, proxy: dict): + def __init__(self, network: 'Network', server: str, proxy: Optional[dict]): self.ready = asyncio.Future() self.got_disconnected = asyncio.Future() self.server = server self.host, self.port, self.protocol = deserialize_server(self.server) self.port = int(self.port) - self.config_path = config_path - self.cert_path = os.path.join(self.config_path, 'certs', self.host) + assert network.config.path + self.cert_path = os.path.join(network.config.path, 'certs', self.host) self.blockchain = None self._requested_chunks = set() self.network = network @@ -281,7 +281,7 @@ class Interface(PrintError): assert self.tip_header chain = blockchain.check_header(self.tip_header) if not chain: - self.blockchain = blockchain.blockchains[0] + self.blockchain = blockchain.get_best_chain() else: self.blockchain = chain assert self.blockchain is not None @@ -502,7 +502,7 @@ class Interface(PrintError): # bad_header connects to good_header; bad_header itself is NOT in self.blockchain. bh = self.blockchain.height() - assert bh >= good + assert bh >= good, (bh, good) if bh == good: height = good + 1 self.print_error("catching up from {}".format(height)) @@ -510,53 +510,12 @@ class Interface(PrintError): # this is a new fork we don't yet have height = bad + 1 - branch = blockchain.blockchains.get(bad) - if branch is not None: - # Conflict!! As our fork handling is not completely general, - # we need to delete another fork to save this one. - # Note: This could be a potential DOS vector against Electrum. - # However, mining blocks that satisfy the difficulty requirements - # is assumed to be expensive; especially as forks below the max - # checkpoint are ignored. - self.print_error("new fork at bad height {}. conflict!!".format(bad)) - assert self.blockchain != branch - ismocking = type(branch) is dict - if ismocking: - self.print_error("TODO replace blockchain") - return 'fork_conflict', height - self.print_error('forkpoint conflicts with existing fork', branch.path()) - self._raise_if_fork_conflicts_with_default_server(branch) - await self._disconnect_from_interfaces_on_conflicting_blockchain(branch) - branch.write(b'', 0) - branch.save_header(bad_header) - self.blockchain = branch - return 'fork_conflict', height - else: - # No conflict. Just save the new fork. - self.print_error("new fork at bad height {}. NO conflict.".format(bad)) - forkfun = self.blockchain.fork if 'mock' not in bad_header else bad_header['mock']['fork'] - b = forkfun(bad_header) - with blockchain.blockchains_lock: - assert bad not in blockchain.blockchains, (bad, list(blockchain.blockchains)) - blockchain.blockchains[bad] = b - self.blockchain = b - assert b.forkpoint == bad - return 'fork_noconflict', height - - def _raise_if_fork_conflicts_with_default_server(self, chain_to_delete: Blockchain) -> None: - main_interface = self.network.interface - if not main_interface: return - if main_interface == self: return - chain_of_default_server = main_interface.blockchain - if not chain_of_default_server: return - if chain_to_delete == chain_of_default_server: - raise GracefulDisconnect('refusing to overwrite blockchain of default server') - - async def _disconnect_from_interfaces_on_conflicting_blockchain(self, chain: Blockchain) -> None: - ifaces = await self.network.disconnect_from_interfaces_on_given_blockchain(chain) - if not ifaces: return - servers = [interface.server for interface in ifaces] - self.print_error("forcing disconnect of other interfaces: {}".format(servers)) + self.print_error(f"new fork at bad height {bad}") + forkfun = self.blockchain.fork if 'mock' not in bad_header else bad_header['mock']['fork'] + b = forkfun(bad_header) # type: Blockchain + self.blockchain = b + assert b.forkpoint == bad + return 'fork', height async def _search_headers_backwards(self, height, header): async def iterate(): diff --git a/electrum/network.py b/electrum/network.py index c0218f9cc..85c8dbee1 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -177,10 +177,10 @@ class Network(PrintError): if config is None: config = {} # Do not use mutables as default values! self.config = SimpleConfig(config) if isinstance(config, dict) else config # type: SimpleConfig - blockchain.blockchains = blockchain.read_blockchains(self.config) - self.print_error("blockchains", list(blockchain.blockchains)) + blockchain.read_blockchains(self.config) + self.print_error("blockchains", list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))) self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict] - self._blockchain_index = 0 + self._blockchain = blockchain.get_best_chain() # Server for addresses and transactions self.default_server = self.config.get('server', None) # Sanitize default server @@ -559,17 +559,24 @@ class Network(PrintError): filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash), interfaces)) if filtered: + self.print_error("switching to preferred fork") chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) return - # try to switch to longest chain - if self.blockchain().parent_id is None: - return # already on longest chain - filtered = list(filter(lambda iface: iface.blockchain.parent_id is None, + else: + self.print_error("tried to switch to preferred fork but no interfaces are on it") + # try to switch to best chain + if self.blockchain().parent is None: + return # already on best chain + filtered = list(filter(lambda iface: iface.blockchain.parent is None, interfaces)) if filtered: + self.print_error("switching to best chain") chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) + else: + # FIXME switch to best available? + self.print_error("tried to switch to best chain but no interfaces are on it") async def switch_to_interface(self, server: str): """Switch to server as our main interface. If no connection exists, @@ -637,7 +644,7 @@ class Network(PrintError): @ignore_exceptions # do not kill main_taskgroup @log_exceptions async def _run_new_interface(self, server): - interface = Interface(self, server, self.config.path, self.proxy) + interface = Interface(self, server, self.proxy) timeout = 10 if not self.proxy else 20 try: await asyncio.wait_for(interface.ready, timeout) @@ -661,7 +668,7 @@ class Network(PrintError): self.trigger_callback('network_updated') async def _init_headers_file(self): - b = blockchain.blockchains[0] + b = blockchain.get_best_chain() filename = b.path() length = HEADER_SIZE * len(constants.net.CHECKPOINTS) * 2016 if not os.path.exists(filename) or os.path.getsize(filename) < length: @@ -739,8 +746,8 @@ class Network(PrintError): def blockchain(self) -> Blockchain: interface = self.interface if interface and interface.blockchain is not None: - self._blockchain_index = interface.blockchain.forkpoint - return blockchain.blockchains[self._blockchain_index] + self._blockchain = interface.blockchain + return self._blockchain def get_blockchains(self): out = {} # blockchain_id -> list(interfaces) @@ -752,13 +759,6 @@ class Network(PrintError): out[chain_id] = r return out - async def disconnect_from_interfaces_on_given_blockchain(self, chain: Blockchain) -> Sequence[Interface]: - chain_id = chain.forkpoint - ifaces = self.get_blockchains().get(chain_id) or [] - for interface in ifaces: - await self.connection_down(interface.server) - return ifaces - def _set_preferred_chain(self, chain: Blockchain): height = chain.get_max_forkpoint() header_hash = chain.get_hash(height) @@ -768,7 +768,7 @@ class Network(PrintError): } self.config.set_key('blockchain_preferred_block', self._blockchain_preferred_block) - async def follow_chain_given_id(self, chain_id: int) -> None: + async def follow_chain_given_id(self, chain_id: str) -> None: bc = blockchain.blockchains.get(chain_id) if not bc: raise Exception('blockchain {} not found'.format(chain_id)) diff --git a/electrum/storage.py b/electrum/storage.py index 16a4cc90d..ad3de4c6a 100644 --- a/electrum/storage.py +++ b/electrum/storage.py @@ -125,7 +125,7 @@ class JsonDB(PrintError): # perform atomic write on POSIX systems try: os.rename(temp_path, self.path) - except: + except OSError: os.remove(self.path) os.rename(temp_path, self.path) os.chmod(self.path, mode) diff --git a/electrum/tests/test_blockchain.py b/electrum/tests/test_blockchain.py new file mode 100644 index 000000000..be29c1b03 --- /dev/null +++ b/electrum/tests/test_blockchain.py @@ -0,0 +1,239 @@ +import shutil +import tempfile +import os + +from electrum import constants, blockchain +from electrum.simple_config import SimpleConfig +from electrum.blockchain import Blockchain, deserialize_header, hash_header +from electrum.util import bh2u, bfh, make_dir + +from . import SequentialTestCase + + +class TestBlockchain(SequentialTestCase): + + HEADERS = { + 'A': deserialize_header(bfh("0100000000000000000000000000000000000000000000000000000000000000000000003ba3edfd7a7b12b27ac72c3e67768f617fc81bc3888a51323a9fb8aa4b1e5e4adae5494dffff7f2002000000"), 0), + 'B': deserialize_header(bfh("0000002006226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f186c8dfd970a4545f79916bc1d75c9d00432f57c89209bf3bb115b7612848f509c25f45bffff7f2000000000"), 1), + 'C': deserialize_header(bfh("00000020686bdfc6a3db73d5d93e8c9663a720a26ecb1ef20eb05af11b36cdbc57c19f7ebf2cbf153013a1c54abaf70e95198fcef2f3059cc6b4d0f7e876808e7d24d11cc825f45bffff7f2000000000"), 2), + 'D': deserialize_header(bfh("00000020122baa14f3ef54985ae546d1611559e3f487bd2a0f46e8dbb52fbacc9e237972e71019d7feecd9b8596eca9a67032c5f4641b23b5d731dc393e37de7f9c2f299e725f45bffff7f2000000000"), 3), + 'E': deserialize_header(bfh("00000020f8016f7ef3a17d557afe05d4ea7ab6bde1b2247b7643896c1b63d43a1598b747a3586da94c71753f27c075f57f44faf913c31177a0957bbda42e7699e3a2141aed25f45bffff7f2001000000"), 4), + 'F': deserialize_header(bfh("000000201d589c6643c1d121d73b0573e5ee58ab575b8fdf16d507e7e915c5fbfbbfd05e7aee1d692d1615c3bdf52c291032144ce9e3b258a473c17c745047f3431ff8e2ee25f45bffff7f2000000000"), 5), + 'O': deserialize_header(bfh("00000020b833ed46eea01d4c980f59feee44a66aa1162748b6801029565d1466790c405c3a141ce635cbb1cd2b3a4fcdd0a3380517845ba41736c82a79cab535d31128066526f45bffff7f2001000000"), 6), + 'P': deserialize_header(bfh("00000020abe8e119d1877c9dc0dc502d1a253fb9a67967c57732d2f71ee0280e8381ff0a9690c2fe7c1a4450c74dc908fe94dd96c3b0637d51475e9e06a78e944a0c7fe28126f45bffff7f2000000000"), 7), + 'Q': deserialize_header(bfh("000000202ce41d94eb70e1518bc1f72523f84a903f9705d967481e324876e1f8cf4d3452148be228a4c3f2061bafe7efdfc4a8d5a94759464b9b5c619994d45dfcaf49e1a126f45bffff7f2000000000"), 8), + 'R': deserialize_header(bfh("00000020552755b6c59f3d51e361d16281842a4e166007799665b5daed86a063dd89857415681cb2d00ff889193f6a68a93f5096aeb2d84ca0af6185a462555822552221a626f45bffff7f2000000000"), 9), + 'S': deserialize_header(bfh("00000020a13a491cbefc93cd1bb1938f19957e22a134faf14c7dee951c45533e2c750f239dc087fc977b06c24a69c682d1afd1020e6dc1f087571ccec66310a786e1548fab26f45bffff7f2000000000"), 10), + 'T': deserialize_header(bfh("00000020dbf3a9b55dfefbaf8b6e43a89cf833fa2e208bbc0c1c5d76c0d71b9e4a65337803b243756c25053253aeda309604363460a3911015929e68705bd89dff6fe064b026f45bffff7f2002000000"), 11), + 'U': deserialize_header(bfh("000000203d0932b3b0c78eccb39a595a28ae4a7c966388648d7783fd1305ec8d40d4fe5fd67cb902a7d807cee7676cb543feec3e053aa824d5dfb528d5b94f9760313d9db726f45bffff7f2001000000"), 12), + 'G': deserialize_header(bfh("00000020b833ed46eea01d4c980f59feee44a66aa1162748b6801029565d1466790c405c3a141ce635cbb1cd2b3a4fcdd0a3380517845ba41736c82a79cab535d31128066928f45bffff7f2001000000"), 6), + 'H': deserialize_header(bfh("00000020e19e687f6e7f83ca394c114144dbbbc4f3f9c9450f66331a125413702a2e1a719690c2fe7c1a4450c74dc908fe94dd96c3b0637d51475e9e06a78e944a0c7fe26a28f45bffff7f2002000000"), 7), + 'I': deserialize_header(bfh("0000002009dcb3b158293c89d7cf7ceeb513add122ebc3880a850f47afbb2747f5e48c54148be228a4c3f2061bafe7efdfc4a8d5a94759464b9b5c619994d45dfcaf49e16a28f45bffff7f2000000000"), 8), + 'J': deserialize_header(bfh("000000206a65f3bdd3374a5a6c4538008ba0b0a560b8566291f9ef4280ab877627a1742815681cb2d00ff889193f6a68a93f5096aeb2d84ca0af6185a462555822552221c928f45bffff7f2000000000"), 9), + 'K': deserialize_header(bfh("00000020bb3b421653548991998f96f8ba486b652fdb07ca16e9cee30ece033547cd1a6e9dc087fc977b06c24a69c682d1afd1020e6dc1f087571ccec66310a786e1548fca28f45bffff7f2000000000"), 10), + 'L': deserialize_header(bfh("00000020c391d74d37c24a130f4bf4737932bdf9e206dd4fad22860ec5408978eb55d46303b243756c25053253aeda309604363460a3911015929e68705bd89dff6fe064ca28f45bffff7f2000000000"), 11), + 'M': deserialize_header(bfh("000000206a65f3bdd3374a5a6c4538008ba0b0a560b8566291f9ef4280ab877627a1742815681cb2d00ff889193f6a68a93f5096aeb2d84ca0af6185a4625558225522214229f45bffff7f2000000000"), 9), + 'N': deserialize_header(bfh("00000020383dab38b57f98aa9b4f0d5ff868bc674b4828d76766bf048296f4c45fff680a9dc087fc977b06c24a69c682d1afd1020e6dc1f087571ccec66310a786e1548f4329f45bffff7f2003000000"), 10), + 'X': deserialize_header(bfh("0000002067f1857f54b7fef732cb4940f7d1b339472b3514660711a820330fd09d8fba6b03b243756c25053253aeda309604363460a3911015929e68705bd89dff6fe0649b29f45bffff7f2002000000"), 11), + 'Y': deserialize_header(bfh("00000020db33c9768a9e5f7c37d0f09aad88d48165946c87d08f7d63793f07b5c08c527fd67cb902a7d807cee7676cb543feec3e053aa824d5dfb528d5b94f9760313d9d9b29f45bffff7f2000000000"), 12), + 'Z': deserialize_header(bfh("0000002047822b67940e337fda38be6f13390b3596e4dea2549250256879722073824e7f0f2596c29203f8a0f71ae94193092dc8f113be3dbee4579f1e649fa3d6dcc38c622ef45bffff7f2003000000"), 13), + } + # tree of headers: + # - M <- N <- X <- Y <- Z + # / + # - G <- H <- I <- J <- K <- L + # / + # A <- B <- C <- D <- E <- F <- O <- P <- Q <- R <- S <- T <- U + + @classmethod + def setUpClass(cls): + super().setUpClass() + constants.set_regtest() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + constants.set_mainnet() + + def setUp(self): + super().setUp() + self.data_dir = tempfile.mkdtemp() + make_dir(os.path.join(self.data_dir, 'forks')) + self.config = SimpleConfig({'electrum_path': self.data_dir}) + blockchain.blockchains = {} + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.data_dir) + + def _append_header(self, chain: Blockchain, header: dict): + self.assertTrue(chain.can_connect(header)) + chain.save_header(header) + + def test_forking_and_swapping(self): + blockchain.blockchains[constants.net.GENESIS] = chain_u = Blockchain( + config=self.config, forkpoint=0, parent=None, + forkpoint_hash=constants.net.GENESIS, prev_hash=None) + open(chain_u.path(), 'w+').close() + + self._append_header(chain_u, self.HEADERS['A']) + self._append_header(chain_u, self.HEADERS['B']) + self._append_header(chain_u, self.HEADERS['C']) + self._append_header(chain_u, self.HEADERS['D']) + self._append_header(chain_u, self.HEADERS['E']) + self._append_header(chain_u, self.HEADERS['F']) + self._append_header(chain_u, self.HEADERS['O']) + self._append_header(chain_u, self.HEADERS['P']) + self._append_header(chain_u, self.HEADERS['Q']) + self._append_header(chain_u, self.HEADERS['R']) + + chain_l = chain_u.fork(self.HEADERS['G']) + self._append_header(chain_l, self.HEADERS['H']) + self._append_header(chain_l, self.HEADERS['I']) + self._append_header(chain_l, self.HEADERS['J']) + + # do checks + self.assertEqual(2, len(blockchain.blockchains)) + self.assertEqual(1, len(os.listdir(os.path.join(self.data_dir, "forks")))) + self.assertEqual(0, chain_u.forkpoint) + self.assertEqual(None, chain_u.parent) + self.assertEqual(constants.net.GENESIS, chain_u._forkpoint_hash) + self.assertEqual(None, chain_u._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "blockchain_headers"), chain_u.path()) + self.assertEqual(10 * 80, os.stat(chain_u.path()).st_size) + self.assertEqual(6, chain_l.forkpoint) + self.assertEqual(chain_u, chain_l.parent) + self.assertEqual(hash_header(self.HEADERS['G']), chain_l._forkpoint_hash) + self.assertEqual(hash_header(self.HEADERS['F']), chain_l._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "forks", "fork2_6_5c400c7966145d56291080b6482716a16aa644eefe590f984c1da0ee46ed33b8_711a2e2a701354121a33660f45c9f9f3c4bbdb4441114c39ca837f6e7f689ee1"), chain_l.path()) + self.assertEqual(4 * 80, os.stat(chain_l.path()).st_size) + + self._append_header(chain_l, self.HEADERS['K']) + + # chains were swapped, do checks + self.assertEqual(2, len(blockchain.blockchains)) + self.assertEqual(1, len(os.listdir(os.path.join(self.data_dir, "forks")))) + self.assertEqual(6, chain_u.forkpoint) + self.assertEqual(chain_l, chain_u.parent) + self.assertEqual(hash_header(self.HEADERS['O']), chain_u._forkpoint_hash) + self.assertEqual(hash_header(self.HEADERS['F']), chain_u._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "forks", "fork2_6_5c400c7966145d56291080b6482716a16aa644eefe590f984c1da0ee46ed33b8_aff81830e28e01ef7d23277c56779a6b93f251a2d50dcc09d7c87d119e1e8ab"), chain_u.path()) + self.assertEqual(4 * 80, os.stat(chain_u.path()).st_size) + self.assertEqual(0, chain_l.forkpoint) + self.assertEqual(None, chain_l.parent) + self.assertEqual(constants.net.GENESIS, chain_l._forkpoint_hash) + self.assertEqual(None, chain_l._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "blockchain_headers"), chain_l.path()) + self.assertEqual(11 * 80, os.stat(chain_l.path()).st_size) + for b in (chain_u, chain_l): + self.assertTrue(all([b.can_connect(b.read_header(i), False) for i in range(b.height())])) + + self._append_header(chain_u, self.HEADERS['S']) + self._append_header(chain_u, self.HEADERS['T']) + self._append_header(chain_u, self.HEADERS['U']) + self._append_header(chain_l, self.HEADERS['L']) + + chain_z = chain_l.fork(self.HEADERS['M']) + self._append_header(chain_z, self.HEADERS['N']) + self._append_header(chain_z, self.HEADERS['X']) + self._append_header(chain_z, self.HEADERS['Y']) + self._append_header(chain_z, self.HEADERS['Z']) + + # chain_z became best chain, do checks + self.assertEqual(3, len(blockchain.blockchains)) + self.assertEqual(2, len(os.listdir(os.path.join(self.data_dir, "forks")))) + self.assertEqual(0, chain_z.forkpoint) + self.assertEqual(None, chain_z.parent) + self.assertEqual(constants.net.GENESIS, chain_z._forkpoint_hash) + self.assertEqual(None, chain_z._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "blockchain_headers"), chain_z.path()) + self.assertEqual(14 * 80, os.stat(chain_z.path()).st_size) + self.assertEqual(9, chain_l.forkpoint) + self.assertEqual(chain_z, chain_l.parent) + self.assertEqual(hash_header(self.HEADERS['J']), chain_l._forkpoint_hash) + self.assertEqual(hash_header(self.HEADERS['I']), chain_l._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "forks", "fork2_9_2874a1277687ab8042eff9916256b860a5b0a08b0038456c5a4a37d3bdf3656a_6e1acd473503ce0ee3cee916ca07db2f656b48baf8968f999189545316423bbb"), chain_l.path()) + self.assertEqual(3 * 80, os.stat(chain_l.path()).st_size) + self.assertEqual(6, chain_u.forkpoint) + self.assertEqual(chain_z, chain_u.parent) + self.assertEqual(hash_header(self.HEADERS['O']), chain_u._forkpoint_hash) + self.assertEqual(hash_header(self.HEADERS['F']), chain_u._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "forks", "fork2_6_5c400c7966145d56291080b6482716a16aa644eefe590f984c1da0ee46ed33b8_aff81830e28e01ef7d23277c56779a6b93f251a2d50dcc09d7c87d119e1e8ab"), chain_u.path()) + self.assertEqual(7 * 80, os.stat(chain_u.path()).st_size) + for b in (chain_u, chain_l, chain_z): + self.assertTrue(all([b.can_connect(b.read_header(i), False) for i in range(b.height())])) + + self.assertEqual(constants.net.GENESIS, chain_z.get_hash(0)) + self.assertEqual(hash_header(self.HEADERS['F']), chain_z.get_hash(5)) + self.assertEqual(hash_header(self.HEADERS['G']), chain_z.get_hash(6)) + self.assertEqual(hash_header(self.HEADERS['I']), chain_z.get_hash(8)) + self.assertEqual(hash_header(self.HEADERS['M']), chain_z.get_hash(9)) + self.assertEqual(hash_header(self.HEADERS['Z']), chain_z.get_hash(13)) + + def test_doing_multiple_swaps_after_single_new_header(self): + blockchain.blockchains[constants.net.GENESIS] = chain_u = Blockchain( + config=self.config, forkpoint=0, parent=None, + forkpoint_hash=constants.net.GENESIS, prev_hash=None) + open(chain_u.path(), 'w+').close() + + self._append_header(chain_u, self.HEADERS['A']) + self._append_header(chain_u, self.HEADERS['B']) + self._append_header(chain_u, self.HEADERS['C']) + self._append_header(chain_u, self.HEADERS['D']) + self._append_header(chain_u, self.HEADERS['E']) + self._append_header(chain_u, self.HEADERS['F']) + self._append_header(chain_u, self.HEADERS['O']) + self._append_header(chain_u, self.HEADERS['P']) + self._append_header(chain_u, self.HEADERS['Q']) + self._append_header(chain_u, self.HEADERS['R']) + self._append_header(chain_u, self.HEADERS['S']) + + self.assertEqual(1, len(blockchain.blockchains)) + self.assertEqual(0, len(os.listdir(os.path.join(self.data_dir, "forks")))) + + chain_l = chain_u.fork(self.HEADERS['G']) + self._append_header(chain_l, self.HEADERS['H']) + self._append_header(chain_l, self.HEADERS['I']) + self._append_header(chain_l, self.HEADERS['J']) + self._append_header(chain_l, self.HEADERS['K']) + # now chain_u is best chain, but it's tied with chain_l + + self.assertEqual(2, len(blockchain.blockchains)) + self.assertEqual(1, len(os.listdir(os.path.join(self.data_dir, "forks")))) + + chain_z = chain_l.fork(self.HEADERS['M']) + self._append_header(chain_z, self.HEADERS['N']) + self._append_header(chain_z, self.HEADERS['X']) + + self.assertEqual(3, len(blockchain.blockchains)) + self.assertEqual(2, len(os.listdir(os.path.join(self.data_dir, "forks")))) + + # chain_z became best chain, do checks + self.assertEqual(0, chain_z.forkpoint) + self.assertEqual(None, chain_z.parent) + self.assertEqual(constants.net.GENESIS, chain_z._forkpoint_hash) + self.assertEqual(None, chain_z._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "blockchain_headers"), chain_z.path()) + self.assertEqual(12 * 80, os.stat(chain_z.path()).st_size) + self.assertEqual(9, chain_l.forkpoint) + self.assertEqual(chain_z, chain_l.parent) + self.assertEqual(hash_header(self.HEADERS['J']), chain_l._forkpoint_hash) + self.assertEqual(hash_header(self.HEADERS['I']), chain_l._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "forks", "fork2_9_2874a1277687ab8042eff9916256b860a5b0a08b0038456c5a4a37d3bdf3656a_6e1acd473503ce0ee3cee916ca07db2f656b48baf8968f999189545316423bbb"), chain_l.path()) + self.assertEqual(2 * 80, os.stat(chain_l.path()).st_size) + self.assertEqual(6, chain_u.forkpoint) + self.assertEqual(chain_z, chain_u.parent) + self.assertEqual(hash_header(self.HEADERS['O']), chain_u._forkpoint_hash) + self.assertEqual(hash_header(self.HEADERS['F']), chain_u._prev_hash) + self.assertEqual(os.path.join(self.data_dir, "forks", "fork2_6_5c400c7966145d56291080b6482716a16aa644eefe590f984c1da0ee46ed33b8_aff81830e28e01ef7d23277c56779a6b93f251a2d50dcc09d7c87d119e1e8ab"), chain_u.path()) + self.assertEqual(5 * 80, os.stat(chain_u.path()).st_size) + + self.assertEqual(constants.net.GENESIS, chain_z.get_hash(0)) + self.assertEqual(hash_header(self.HEADERS['F']), chain_z.get_hash(5)) + self.assertEqual(hash_header(self.HEADERS['G']), chain_z.get_hash(6)) + self.assertEqual(hash_header(self.HEADERS['I']), chain_z.get_hash(8)) + self.assertEqual(hash_header(self.HEADERS['M']), chain_z.get_hash(9)) + self.assertEqual(hash_header(self.HEADERS['X']), chain_z.get_hash(11)) + + for b in (chain_u, chain_l, chain_z): + self.assertTrue(all([b.can_connect(b.read_header(i), False) for i in range(b.height())])) diff --git a/electrum/tests/test_network.py b/electrum/tests/test_network.py index c69375bd6..ece54056f 100644 --- a/electrum/tests/test_network.py +++ b/electrum/tests/test_network.py @@ -6,6 +6,9 @@ from electrum import constants from electrum.simple_config import SimpleConfig from electrum import blockchain from electrum.interface import Interface +from electrum.crypto import sha256 +from electrum.util import bh2u + class MockTaskGroup: async def spawn(self, x): return @@ -17,10 +20,14 @@ class MockNetwork: class MockInterface(Interface): def __init__(self, config): self.config = config - super().__init__(MockNetwork(), 'mock-server:50000:t', self.config.electrum_path(), None) + network = MockNetwork() + network.config = config + super().__init__(network, 'mock-server:50000:t', None) self.q = asyncio.Queue() - self.blockchain = blockchain.Blockchain(self.config, 2002, None) + self.blockchain = blockchain.Blockchain(config=self.config, forkpoint=0, + parent=None, forkpoint_hash=constants.net.GENESIS, prev_hash=None) self.tip = 12 + self.blockchain._size = self.tip + 1 async def get_block_header(self, height, assert_mode): assert self.q.qsize() > 0, (height, assert_mode) item = await self.q.get() @@ -56,7 +63,7 @@ class TestNetwork(unittest.TestCase): self.interface.q.put_nowait({'block_height': 5, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) self.interface.q.put_nowait({'block_height': 6, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) ifa = self.interface - self.assertEqual(('fork_noconflict', 8), asyncio.get_event_loop().run_until_complete(ifa.sync_until(8, next_height=7))) + self.assertEqual(('fork', 8), asyncio.get_event_loop().run_until_complete(ifa.sync_until(8, next_height=7))) self.assertEqual(self.interface.q.qsize(), 0) def test_fork_conflict(self): @@ -70,7 +77,7 @@ class TestNetwork(unittest.TestCase): self.interface.q.put_nowait({'block_height': 5, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) self.interface.q.put_nowait({'block_height': 6, 'mock': {'binary':1,'check':lambda x: True, 'connect': lambda x: True}}) ifa = self.interface - self.assertEqual(('fork_conflict', 8), asyncio.get_event_loop().run_until_complete(ifa.sync_until(8, next_height=7))) + self.assertEqual(('fork', 8), asyncio.get_event_loop().run_until_complete(ifa.sync_until(8, next_height=7))) self.assertEqual(self.interface.q.qsize(), 0) def test_can_connect_during_backward(self): @@ -87,7 +94,10 @@ class TestNetwork(unittest.TestCase): self.assertEqual(self.interface.q.qsize(), 0) def mock_fork(self, bad_header): - return blockchain.Blockchain(self.config, bad_header['block_height'], None) + forkpoint = bad_header['block_height'] + b = blockchain.Blockchain(config=self.config, forkpoint=forkpoint, parent=None, + forkpoint_hash=bh2u(sha256(str(forkpoint))), prev_hash=bh2u(sha256(str(forkpoint-1)))) + return b def test_chain_false_during_binary(self): blockchain.blockchains = {} diff --git a/electrum/tests/test_wallet.py b/electrum/tests/test_wallet.py index 9117392ea..c6366f3e3 100644 --- a/electrum/tests/test_wallet.py +++ b/electrum/tests/test_wallet.py @@ -64,7 +64,6 @@ class TestWalletStorage(WalletTestCase): storage.put(key, value) storage.write() - contents = "" with open(self.wallet_path, "r") as f: contents = f.read() self.assertEqual(some_dict, json.loads(contents)) From 65ce3deeaa33828407cf3a873c7ce5c48fa0b6d4 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Thu, 22 Nov 2018 17:13:43 +0100 Subject: [PATCH 2/2] blockchain: chain hierarchy based on most work, not length --- electrum/blockchain.py | 43 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/electrum/blockchain.py b/electrum/blockchain.py index c7c075312..d1238a2e8 100644 --- a/electrum/blockchain.py +++ b/electrum/blockchain.py @@ -141,6 +141,11 @@ def read_blockchains(config: 'SimpleConfig'): def get_best_chain() -> 'Blockchain': return blockchains[constants.net.GENESIS] +# block hash -> chain work; up to and including that block +_CHAINWORK_CACHE = { + "0000000000000000000000000000000000000000000000000000000000000000": 0, # virtual block at height -1 +} # type: Dict[str, int] + class Blockchain(util.PrintError): """ @@ -319,10 +324,10 @@ class Blockchain(util.PrintError): they will be stored in different files.""" if self.parent is None: return False - parent_branch_size = self.parent.height() - self.forkpoint + 1 - if parent_branch_size >= self.size(): # FIXME most work, not length + if self.parent.get_chainwork() >= self.get_chainwork(): return False self.print_error("swap", self.forkpoint, self.parent.forkpoint) + parent_branch_size = self.parent.height() - self.forkpoint + 1 forkpoint = self.forkpoint # type: Optional[int] parent = self.parent # type: Optional[Blockchain] child_old_id = self.get_id() @@ -481,6 +486,40 @@ class Blockchain(util.PrintError): bitsBase >>= 8 return bitsN << 24 | bitsBase + def chainwork_of_header_at_height(self, height: int) -> int: + """work done by single header at given height""" + chunk_idx = height // 2016 - 1 + target = self.get_target(chunk_idx) + work = ((2 ** 256 - target - 1) // (target + 1)) + 1 + return work + + @with_lock + def get_chainwork(self, height=None) -> int: + if height is None: + height = max(0, self.height()) + if constants.net.TESTNET: + # On testnet/regtest, difficulty works somewhat different. + # It's out of scope to properly implement that. + return height + last_retarget = height // 2016 * 2016 - 1 + cached_height = last_retarget + while _CHAINWORK_CACHE.get(self.get_hash(cached_height)) is None: + if cached_height <= -1: + break + cached_height -= 2016 + assert cached_height >= -1, cached_height + running_total = _CHAINWORK_CACHE[self.get_hash(cached_height)] + while cached_height < last_retarget: + cached_height += 2016 + work_in_single_header = self.chainwork_of_header_at_height(cached_height) + work_in_chunk = 2016 * work_in_single_header + running_total += work_in_chunk + _CHAINWORK_CACHE[self.get_hash(cached_height)] = running_total + cached_height += 2016 + work_in_single_header = self.chainwork_of_header_at_height(cached_height) + work_in_last_partial_chunk = (height % 2016 + 1) * work_in_single_header + return running_total + work_in_last_partial_chunk + def can_connect(self, header: dict, check_height: bool=True) -> bool: if header is None: return False