diff --git a/lib/blockchain.py b/lib/blockchain.py index d9434c3cf..e048a445e 100644 --- a/lib/blockchain.py +++ b/lib/blockchain.py @@ -62,26 +62,30 @@ def hash_header(header): class Blockchain(util.PrintError): + '''Manages blockchain headers and their verification''' - def __init__(self, config, checkpoint): + + def __init__(self, config, filename, fork_point): self.config = config - self.checkpoint = checkpoint - self.filename = 'blockchain_headers' if checkpoint == 0 else 'blockchain_fork_%d'%checkpoint - self.set_local_height() + self.filename = filename self.catch_up = None # interface catching up + if fork_point is None: + self.is_saved = True + self.checkpoint = int(filename[16:]) if filename.startswith('blockchain_fork_') else 0 + else: + self.is_saved = False + self.checkpoint = fork_point + self.headers = [] + self.set_local_height() def height(self): - return self.local_height + return self.local_height + len(self.headers) def verify_header(self, header, prev_header, bits, target): prev_hash = hash_header(prev_header) _hash = hash_header(header) if prev_hash != header.get('prev_block_hash'): raise BaseException("prev hash mismatch: %s vs %s" % (prev_hash, header.get('prev_block_hash'))) - #if not self.pass_checkpoint(header): - # raise BaseException('failed checkpoint') - #if self.checkpoint_height == header.get('block_height'): - # self.print_error("validated checkpoint", self.checkpoint_height) if bitcoin.TESTNET: return if bits != header.get('bits'): @@ -115,24 +119,50 @@ class Blockchain(util.PrintError): return os.path.join(d, self.filename) def save_chunk(self, index, chunk): + if not self.is_saved: + self.fork_and_save() filename = self.path() - f = open(filename, 'rb+') - f.seek(index * 2016 * 80) - f.truncate() - h = f.write(chunk) - f.close() + with open(filename, 'rb+') as f: + f.seek(index * 2016 * 80) + f.truncate() + h = f.write(chunk) self.set_local_height() + def fork_and_save(self): + import shutil + self.print_error("save fork") + height = self.checkpoint + filename = "blockchain_fork_%d"%height + new_path = os.path.join(util.get_headers_dir(self.config), filename) + shutil.copy(self.path(), new_path) + with open(new_path, 'rb+') as f: + f.seek((height) * 80) + f.truncate() + self.filename = filename + self.is_saved = True + for h in self.headers: + self.write_header(h) + self.headers = [] + def save_header(self, header): + height = header.get('block_height') + if not self.is_saved: + assert height == self.checkpoint + len(self.headers) + 1 + self.headers.append(header) + if len(self.headers) > 10: + self.fork_and_save() + return + self.write_header(header) + + def write_header(self, header): + height = header.get('block_height') data = serialize_header(header).decode('hex') assert len(data) == 80 - height = header.get('block_height') filename = self.path() - f = open(filename, 'rb+') - f.seek(height * 80) - f.truncate() - h = f.write(data) - f.close() + with open(filename, 'rb+') as f: + f.seek(height * 80) + f.truncate() + h = f.write(data) self.set_local_height() def set_local_height(self): @@ -143,15 +173,22 @@ class Blockchain(util.PrintError): if self.local_height != h: self.local_height = h - def read_header(self, block_height): + def read_header(self, height): + if not self.is_saved and height >= self.checkpoint: + i = height - self.checkpoint + if i >= len(self.headers): + return None + header = self.headers[i] + assert header.get('block_height') == height + return header name = self.path() if os.path.exists(name): f = open(name, 'rb') - f.seek(block_height * 80) + f.seek(height * 80) h = f.read(80) f.close() if len(h) == 80: - h = deserialize_header(h, block_height) + h = deserialize_header(h, height) return h def get_hash(self, height): @@ -173,17 +210,6 @@ class Blockchain(util.PrintError): f.truncate() f.close() - def fork(self, height): - import shutil - filename = "blockchain_fork_%d"%height - new_path = os.path.join(util.get_headers_dir(self.config), filename) - shutil.copy(self.path(), new_path) - with open(new_path, 'rb+') as f: - f.seek((height) * 80) - f.truncate() - f.close() - return filename - def get_target(self, index, chain=None): if bitcoin.TESTNET: return 0, 0 diff --git a/lib/network.py b/lib/network.py index 66d009feb..745df4844 100644 --- a/lib/network.py +++ b/lib/network.py @@ -205,13 +205,12 @@ class Network(util.DaemonThread): config = {} # Do not use mutables as default values! util.DaemonThread.__init__(self) self.config = SimpleConfig(config) if type(config) == type({}) else config - self.num_server = 8 if not self.config.get('oneserver') else 0 - self.blockchains = { 0:Blockchain(self.config, 0) } + self.num_server = 18 if not self.config.get('oneserver') else 0 + self.blockchains = { 0:Blockchain(self.config, 'blockchain_headers', None) } for x in os.listdir(self.config.path): if x.startswith('blockchain_fork_'): - n = int(x[16:]) - b = Blockchain(self.config, n) - self.blockchains[n] = b + b = Blockchain(self.config, x, None) + self.blockchains[b.checkpoint] = b self.print_error("blockchains", self.blockchains.keys()) self.blockchain_index = config.get('blockchain_index', 0) if self.blockchain_index not in self.blockchains.keys(): @@ -864,23 +863,16 @@ class Network(util.DaemonThread): if interface.bad != interface.good + 1: next_height = (interface.bad + interface.good) // 2 else: - interface.print_error("found connection at %d"% interface.good) delta1 = interface.blockchain.height() - interface.good delta2 = interface.tip - interface.good - threshold = self.config.get('fork_threshold', 5) - if delta1 > threshold and delta2 > threshold: - interface.print_error("chain split detected: %d (%d %d)"% (interface.good, delta1, delta2)) - interface.blockchain.fork(interface.bad) - interface.blockchain = Blockchain(self.config, interface.bad) - self.blockchains[interface.bad] = interface.blockchain - if interface.blockchain.catch_up is None: - interface.blockchain.catch_up = interface.server - interface.print_error("catching up") - interface.mode = 'catch_up' - next_height = interface.good - else: - # todo: if current catch_up is too slow, queue others - next_height = None + interface.print_error("chain split detected at %d"%interface.good, delta1, delta2) + interface.blockchain = Blockchain(self.config, False, interface.bad) + interface.blockchain.catch_up = interface.server + self.blockchains[interface.bad] = interface.blockchain + interface.print_error("catching up") + interface.mode = 'catch_up' + next_height = interface.good + elif interface.mode == 'catch_up': if can_connect: interface.blockchain.save_header(header)