Browse Source

Implement a merkle cache with tests

patch-2
Neil Booth 7 years ago
parent
commit
4168341857
  1. 81
      electrumx/lib/merkle.py
  2. 177
      tests/lib/test_merkle.py

81
electrumx/lib/merkle.py

@ -121,7 +121,8 @@ class Merkle(object):
return [root(hashes[n: n + size], depth_higher)
for n in range(0, len(hashes), size)]
def branch_from_level(self, level, leaf_hashes, index, depth_higher):
def branch_and_root_from_level(self, level, leaf_hashes, index,
depth_higher):
'''Return a (merkle branch, merkle_root) pair when a merkle-tree has a
level cached.
@ -142,7 +143,7 @@ class Merkle(object):
if not isinstance(level, list):
raise TypeError("level must be a list")
if not isinstance(leaf_hashes, list):
raise TypeError("level must be a list")
raise TypeError("leaf_hashes must be a list")
leaf_index = (index >> depth_higher) << depth_higher
leaf_branch, leaf_root = self.branch_and_root(
leaf_hashes, index - leaf_index, depth_higher)
@ -152,3 +153,79 @@ class Merkle(object):
if leaf_root != level[index]:
raise ValueError('leaf hashes inconsistent with level')
return leaf_branch + level_branch, root
class MerkleCache(object):
    '''A cache to calculate merkle branches efficiently.

    One level of the merkle tree, depth_higher rows above the leaf
    hashes, is held in memory.  A branch is then built from a single
    segment of leaf hashes fetched from the source plus that level.
    '''

    def __init__(self, merkle, source, length):
        '''Initialise a cache of length hashes taken from source.

        merkle must provide tree_depth(), level(), branch_and_root()
        and branch_and_root_from_level(); source must provide
        hashes(start, count) returning a list of hashes.
        '''
        self.merkle = merkle
        self.source = source
        self.length = length
        # Cache the level half way up the tree for the initial length.
        self.depth_higher = merkle.tree_depth(length) // 2
        self.level = self._level(source.hashes(0, length))

    def _segment_length(self):
        '''Return the number of leaf hashes under one cached level entry.'''
        return 1 << self.depth_higher

    def _leaf_start(self, index):
        '''Return the index of the first leaf hash of the segment that
        contains the leaf hash at the given index.
        '''
        depth_higher = self.depth_higher
        return (index >> depth_higher) << depth_higher

    def _level(self, hashes):
        '''Return the level entries, depth_higher rows up, for hashes.'''
        return self.merkle.level(hashes, self.depth_higher)

    def _extend_to(self, length):
        '''Extend the length of the cache if necessary.'''
        if length <= self.length:
            return
        # Start from the beginning of any final partial segment.
        # Retain the value of depth_higher; in practice this is fine
        start = self._leaf_start(self.length)
        # Fetch and hash before touching the cache so that a failing
        # source leaves level and length unchanged.
        hashes = self.source.hashes(start, length - start)
        tail = self._level(hashes)
        # Slice assignment: replace the final (possibly partial) entry
        # and append the new ones.  Assigning to a single index would
        # nest a list inside the level, or raise IndexError when the
        # old length was segment-aligned.
        self.level[start >> self.depth_higher:] = tail
        self.length = length

    def _level_for(self, length):
        '''Return the cached level as it would be for only the first
        length hashes.

        Length may be a truncation or an extension of the cached
        length; in either case the extra hashes needed are requested
        from the source and the cache itself is not modified.
        '''
        if length == self.length:
            return self.level
        if length > self.length:
            # Recompute from the start of the final, possibly partial,
            # cached segment so its entry covers the added hashes.
            start = self._leaf_start(self.length)
            hashes = self.source.hashes(start, length - start)
            return self.level[:start >> self.depth_higher] \
                + self._level(hashes)
        # Truncation: keep the entries for complete segments and
        # recompute the entry for the final partial segment, if any.
        level = self.level[:length >> self.depth_higher]
        leaf_start = self._leaf_start(length)
        count = min(self._segment_length(), length - leaf_start)
        hashes = self.source.hashes(leaf_start, count)
        level += self._level(hashes)
        return level

    def branch_and_root(self, length, index):
        '''Return a merkle branch and root.  Length is the number of
        hashes used to calculate the merkle root, index is the position
        of the hash to calculate the branch of.

        index must be non-negative and less than length, which must be
        at least 1.  Raises TypeError if either argument is not an
        integer and ValueError if either is out of range.
        '''
        if not isinstance(length, int):
            raise TypeError('length must be an integer')
        if not isinstance(index, int):
            raise TypeError('index must be an integer')
        if length <= 0:
            raise ValueError('length must be positive')
        if index < 0:
            # A negative index would otherwise yield a negative start
            # offset for the request to the source.
            raise ValueError('index must be non-negative')
        if index >= length:
            raise ValueError('index must be less than length')
        self._extend_to(length)
        leaf_start = self._leaf_start(index)
        count = min(self._segment_length(), length - leaf_start)
        leaf_hashes = self.source.hashes(leaf_start, count)
        if length < self._segment_length():
            # Everything fits in the first segment, so leaf_hashes is
            # the whole tree and the cached level is not needed.
            return self.merkle.branch_and_root(leaf_hashes, index)
        level = self._level_for(length)
        return self.merkle.branch_and_root_from_level(
            level, leaf_hashes, index, self.depth_higher)

177
tests/lib/test_merkle.py

@ -1,10 +1,11 @@
import os
import pytest
from electrumx.lib.merkle import Merkle
from electrumx.lib.merkle import Merkle, MerkleCache
Merkle = Merkle()
hashes = [Merkle.hash_func(bytes([x])) for x in range(8)]
merkle = Merkle()
hashes = [merkle.hash_func(bytes([x])) for x in range(8)]
roots = [
b'\x14\x06\xe0X\x81\xe2\x996wf\xd3\x13\xe2l\x05VN\xc9\x1b\xf7!\xd3\x17&\xbdnF\xe6\x06\x89S\x9a',
b'K\xbe\x83\xbc8\xeb\xe2\xbc\xc7R\r#A9\xdf\x1c\x0e\xb9\xff\xa5\x1f\x83\xea\xb1\xc5\x12\x9b[\x90kvU',
@ -19,125 +20,209 @@ roots = [
def test_branch_length():
    '''Check branch_length for hash counts 1 through 8.'''
    assert merkle.branch_length(1) == 0
    assert merkle.branch_length(2) == 1
    for n in range(3, 5):
        assert merkle.branch_length(n) == 2
    for n in range(5, 9):
        assert merkle.branch_length(n) == 3
def test_branch_length_bad():
    '''branch_length rejects non-integer and non-positive hash counts.'''
    with pytest.raises(TypeError):
        merkle.branch_length(1.0)
    for n in (-1, 0):
        with pytest.raises(ValueError):
            merkle.branch_length(n)
def test_tree_depth():
    '''tree_depth is one more than branch_length for the same count.'''
    for n in range(1, 10):
        assert merkle.tree_depth(n) == merkle.branch_length(n) + 1
def test_root():
    '''The root of each prefix of hashes matches the recorded root.'''
    for n in range(len(hashes)):
        assert merkle.root(hashes[:n + 1]) == roots[n]
def test_root_bad():
    '''root rejects a non-iterable argument and an empty hash list.'''
    with pytest.raises(TypeError):
        merkle.root(0)
    with pytest.raises(ValueError):
        merkle.root([])
def test_branch_and_root_from_proof():
    '''For every prefix of hashes and every index within it, the root
    from branch_and_root and the root rebuilt from the branch by
    root_from_proof both match the recorded root.
    '''
    for n in range(len(hashes)):
        for m in range(n + 1):
            branch, root = merkle.branch_and_root(hashes[:n + 1], m)
            assert root == roots[n]
            root = merkle.root_from_proof(hashes[m], branch, m)
            assert root == roots[n]
def test_branch_bad():
    '''branch_and_root rejects bad hashes, index and length arguments.'''
    with pytest.raises(TypeError):
        merkle.branch_and_root(0, 0)
    with pytest.raises(ValueError):
        merkle.branch_and_root([], 0)
    with pytest.raises(TypeError):
        merkle.branch_and_root(hashes, 0.0)
    with pytest.raises(ValueError):
        merkle.branch_and_root(hashes[:2], -1)
    with pytest.raises(ValueError):
        merkle.branch_and_root(hashes[:2], 2)
    # An explicit length of 3 is accepted for the full hash list; this
    # call must not raise.
    merkle.branch_and_root(hashes, 0, 3)
    with pytest.raises(TypeError):
        merkle.branch_and_root(hashes, 0, 3.0)
    with pytest.raises(ValueError):
        merkle.branch_and_root(hashes, 0, 2)
def test_root_from_proof_bad():
    '''root_from_proof rejects bad hash, branch and index arguments.'''
    with pytest.raises(TypeError):
        merkle.root_from_proof(0, hashes[:2], 0)
    with pytest.raises(TypeError):
        merkle.root_from_proof(hashes[0], hashes[0], 0)
    with pytest.raises(ValueError):
        merkle.root_from_proof(hashes[0], hashes[:3], -1)
    with pytest.raises(ValueError):
        merkle.root_from_proof(hashes[0], hashes[:3], 8)
def test_level():
    '''Check level() at every depth of every prefix of hashes.'''
    for n in range(len(hashes)):
        depth = merkle.tree_depth(n + 1)
        for depth_higher in range(0, depth):
            level = merkle.level(hashes[:n + 1], depth_higher)
            if depth_higher == 0:
                # The bottom level is the hashes themselves.
                assert level == hashes[:n + 1]
            # depth_higher only ranges up to depth - 1, so the top of
            # the tree must be tested against depth - 1; comparing
            # with depth itself could never be true.
            if depth_higher == depth - 1:
                assert level == [roots[n]]
            # Check raising from level to root works
            assert merkle.root(level) == roots[n]
def test_branch_and_root_from_level():
    '''branch_and_root_from_level agrees with branch_and_root for every
    prefix of hashes, every level depth and every index.
    '''
    # For all sub-trees
    for n in range(0, len(hashes)):
        part = hashes[:n + 1]
        # For all depths in sub-tree
        for depth_higher in range(0, merkle.tree_depth(len(part))):
            level = merkle.level(part, depth_higher)
            # For each hash in sub-tree
            for index in range(len(part)):
                # The leaf hashes of the segment containing index.
                leaf_index = (index >> depth_higher) << depth_higher
                leaf_hashes = part[leaf_index:
                                   leaf_index + (1 << depth_higher)]
                # Both calls return a (branch, root) pair.
                expected = merkle.branch_and_root(part, index)
                actual = merkle.branch_and_root_from_level(
                    level, leaf_hashes, index, depth_higher)
                assert expected == actual
def test_branch_and_root_from_level_bad():
    '''branch_and_root_from_level rejects bad or inconsistent arguments.'''
    with pytest.raises(TypeError):
        merkle.branch_and_root_from_level(hashes[0], hashes, 0, 0)
    with pytest.raises(TypeError):
        merkle.branch_and_root_from_level(hashes, hashes[0], 0, 0)
    # A consistent call at depth 0 must not raise.
    merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, 0)
    with pytest.raises(ValueError):
        merkle.branch_and_root_from_level(hashes, [hashes[0]], -1, 0)
    with pytest.raises(TypeError):
        merkle.branch_and_root_from_level(hashes, hashes, 0.0, 0)
    with pytest.raises(ValueError):
        merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, -1)
    with pytest.raises(ValueError):
        merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, 1)
    with pytest.raises(ValueError):
        # Inconsistent hash
        merkle.branch_and_root_from_level(hashes, [hashes[1]], 0, 0)
    with pytest.raises(ValueError):
        # Inconsistent hash
        merkle.branch_and_root_from_level(hashes, [hashes[0]], 1, 0)
class Source(object):
    '''An in-memory supplier of random 32-byte values, exposing the
    hashes(start, length) call that MerkleCache makes on its source.
    '''

    def __init__(self, length):
        self._hashes = [os.urandom(32) for _ in range(length)]

    def hashes(self, start, length):
        '''Return length stored values beginning at start.

        Asserts that the requested range lies within the stored values.
        '''
        end = start + length
        assert start >= 0
        assert end <= len(self._hashes)
        return self._hashes[start:end]
def test_merkle_cache():
    '''For a range of cache lengths, every checkpoint length up to the
    cache length and every index must give the same branch and root
    from the cache as from a direct calculation.
    '''
    lengths = tuple(range(1, 18)) + (31, 32, 33, 57)
    source = Source(max(lengths))
    for cache_length in lengths:
        cache = MerkleCache(merkle, source, cache_length)
        # Simulate all possible checkpoints
        for cp_length in range(1, cache_length + 1):
            cp_hashes = source.hashes(0, cp_length)
            # All possible indices
            for index in range(cp_length):
                # Direct calculation first, then the cached one.
                branch, root = merkle.branch_and_root(cp_hashes, index)
                branch2, root2 = cache.branch_and_root(cp_length, index)
                assert branch == branch2
                assert root == root2
def test_merkle_cache_extension():
    '''Branches and roots requested for lengths beyond the initial
    cache length must match a direct calculation.

    The function carries the test_ prefix so that pytest collects it;
    without the prefix it is never run.
    '''
    source = Source(64)
    for length in range(14, 18):
        for cp_length in range(30, 36):
            # A fresh cache each time so that the first request always
            # extends it from length to cp_length.
            cache = MerkleCache(merkle, source, length)
            cp_hashes = source.hashes(0, cp_length)
            # All possible indices
            for index in range(cp_length):
                # Compare correct answer with cache
                branch, root = merkle.branch_and_root(cp_hashes, index)
                branch2, root2 = cache.branch_and_root(cp_length, index)
                assert branch == branch2
                assert root == root2


# The previous, uncollected name is kept for any existing direct callers.
merkle_cache_extension = test_merkle_cache_extension
def test_markle_cache_bad():
    '''MerkleCache.branch_and_root rejects bad length and index values.'''
    size = 23
    cache = MerkleCache(merkle, Source(size), size)
    # A well-formed request must succeed before the bad ones are tried.
    cache.branch_and_root(5, 3)
    bad_calls = (
        (TypeError, 5.0, 3),
        (TypeError, 5, 3.0),
        (ValueError, 0, -1),
        (ValueError, 3, 3),
    )
    for expected, length, index in bad_calls:
        with pytest.raises(expected):
            cache.branch_and_root(length, index)
def test_bad_extension():
    '''A request that the source cannot satisfy must leave the cache
    exactly as it was.
    '''
    length = 5
    source = Source(length)
    cache = MerkleCache(merkle, source, length)
    # Snapshot a copy: keeping a reference to cache.level itself would
    # compare the list with itself and pass even after in-place
    # modification.
    level = list(cache.level)
    # Source asserts when asked for more hashes than it holds.
    with pytest.raises(AssertionError):
        cache.branch_and_root(8, 0)
    # The bad extension should not destroy the cache
    assert cache.level == level
    assert cache.length == length
def time_it():
    '''Ad-hoc timing harness for MerkleCache.branch_and_root.

    Not collected by pytest (no test_ prefix); call it by hand.  Prints
    the elapsed seconds for a sweep of indices at a fixed checkpoint
    length.
    '''
    import time

    total = 500000
    source = Source(total)
    # MerkleCache requires the initial length as its third argument.
    cache = MerkleCache(merkle, source, total)
    cp_length = 492000
    t1 = time.time()
    brs2 = [cache.branch_and_root(cp_length, index)
            for index in range(5, 400000, 500)]
    t2 = time.time()
    print(t2 - t1)
    # NOTE(review): deliberate failure, presumably so that a test runner
    # displays the printed timing - confirm before removing.
    assert False

Loading…
Cancel
Save