diff --git a/electrumx/lib/merkle.py b/electrumx/lib/merkle.py index 65c4089..f9bd505 100644 --- a/electrumx/lib/merkle.py +++ b/electrumx/lib/merkle.py @@ -187,24 +187,20 @@ class MerkleCache(object): # Retain the value of depth_higher; in practice this is fine start = self._leaf_start(self.length) hashes = self.source.hashes(start, length - start) - self.level[start >> self.depth_higher] = self._level(hashes) + self.level[start >> self.depth_higher:] = self._level(hashes) self.length = length def _level_for(self, length): '''Return a (level_length, final_hash) pair for a truncation - of the hashes to the given length. Length may be an extension, - in which case extra hashes are requested from the source.''' - if length > self.length: - hashes = self.source.hashes(self.length, length - self.length) - return self.level + self._level(hashes) - if length < self.length: - level = self.level[:length >> self.depth_higher] - leaf_start = self._leaf_start(length) - count = min(self._segment_length(), length - leaf_start) - hashes = self.source.hashes(leaf_start, count) - level += self._level(hashes) - return level - return self.level + of the hashes to the given length.''' + if length == self.length: + return self.level + level = self.level[:length >> self.depth_higher] + leaf_start = self._leaf_start(length) + count = min(self._segment_length(), length - leaf_start) + hashes = self.source.hashes(leaf_start, count) + level += self._level(hashes) + return level def branch_and_root(self, length, index): '''Return a merkle branch and root. Length is the number of diff --git a/tests/lib/test_merkle.py b/tests/lib/test_merkle.py index c860020..e5859a2 100644 --- a/tests/lib/test_merkle.py +++ b/tests/lib/test_merkle.py @@ -172,7 +172,7 @@ def test_merkle_cache(): assert root == root2 -def merkle_cache_extension(): +def test_merkle_cache_extension(): source = Source(64) for length in range(14, 18): for cp_length in range(30, 36): @@ -201,6 +201,7 @@ def test_markle_cache_bad(): with pytest.raises(ValueError): cache.branch_and_root(3, 3) + def test_bad_extension(): length = 5 source = Source(length)