From d390b38acfeaf9ce24b473f3c5cce92fd6763603 Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Mon, 16 Jul 2018 12:42:53 +0800 Subject: [PATCH] Add cache truncation and tests --- electrumx/lib/merkle.py | 12 ++++++++++++ tests/lib/test_merkle.py | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/electrumx/lib/merkle.py b/electrumx/lib/merkle.py index f9bd505..1d235f4 100644 --- a/electrumx/lib/merkle.py +++ b/electrumx/lib/merkle.py @@ -202,6 +202,18 @@ class MerkleCache(object): level += self._level(hashes) return level + def truncate(self, length): + '''Truncate the cache so it is no longer than length.''' + if not isinstance(length, int): + raise TypeError('length must be an integer') + if length <= 0: + raise ValueError('length must be positive') + if length >= self.length: + return + length = self._leaf_start(length) + self.length = length + self.level[length >> self.depth_higher:] = [] + def branch_and_root(self, length, index): '''Return a merkle branch and root. Length is the number of hashes used to calculate the merkle root, index is the position diff --git a/tests/lib/test_merkle.py b/tests/lib/test_merkle.py index e5859a2..af095d4 100644 --- a/tests/lib/test_merkle.py +++ b/tests/lib/test_merkle.py @@ -187,6 +187,41 @@ def test_merkle_cache_extension(): assert root == root2 +def test_merkle_cache_truncation(): + max_length = 33 + source = Source(max_length) + for length in range(max_length - 2, max_length + 1): + for trunc_length in range(1, 20, 3): + cache = MerkleCache(merkle, source, length) + cache.truncate(trunc_length) + assert cache.length <= trunc_length + for cp_length in range(1, length + 1, 3): + cp_hashes = source.hashes(0, cp_length) + # All possible indices + for index in range(cp_length): + # Compare correct answer with cache + branch, root = merkle.branch_and_root(cp_hashes, index) + branch2, root2 = cache.branch_and_root(cp_length, index) + assert branch == branch2 + assert root == root2 + + # Truncation is a no-op if longer + cache = MerkleCache(merkle, source, 10) + level = cache.level.copy() + for length in range(10, 13): + cache.truncate(length) + assert cache.level == level + assert cache.length == 10 + +def test_truncation_bad(): + cache = MerkleCache(merkle, Source(10), 10) + with pytest.raises(TypeError): + cache.truncate(1.0) + for n in (-1, 0): + with pytest.raises(ValueError): + cache.truncate(n) + + def test_markle_cache_bad(): length = 23 source = Source(length) @@ -206,7 +241,7 @@ def test_bad_extension(): length = 5 source = Source(length) cache = MerkleCache(merkle, source, length) - level = cache.level + level = cache.level.copy() with pytest.raises(AssertionError): cache.branch_and_root(8, 0) # The bad extension should not destroy the cache