diff --git a/lib/tx.py b/lib/tx.py
index 8e1878f..79a6360 100644
--- a/lib/tx.py
+++ b/lib/tx.py
@@ -71,7 +71,8 @@ class TxOutput(namedtuple("TxOutput", "value pk_script")):
 class Deserializer(object):
     '''Deserializes blocks into transactions.
 
-    External entry points are read_tx() and read_block().
+    External entry points are read_tx(), read_tx_and_hash(),
+    read_tx_and_vsize() and read_block().
 
     This code is performance sensitive as it is executed 100s of
     millions of times during sync.
@@ -84,25 +85,32 @@ class Deserializer(object):
         self.cursor = start
 
     def read_tx(self):
-        '''Return a (Deserialized TX, TX_HASH) pair.
-
-        The hash needs to be reversed for human display; for efficiency
-        we process it in the natural serialized order.
-        '''
-        start = self.cursor
+        '''Return a deserialized transaction.'''
         return Tx(
             self._read_le_int32(),  # version
             self._read_inputs(),    # inputs
             self._read_outputs(),   # outputs
             self._read_le_uint32()  # locktime
-        ), double_sha256(self.binary[start:self.cursor])
+        )
+
+    def read_tx_and_hash(self):
+        '''Return a (deserialized TX, tx_hash) pair.
+
+        The hash needs to be reversed for human display; for efficiency
+        we process it in the natural serialized order.
+        '''
+        start = self.cursor
+        return self.read_tx(), double_sha256(self.binary[start:self.cursor])
+
+    def read_tx_and_vsize(self):
+        '''Return a (deserialized TX, vsize) pair.'''
+        return self.read_tx(), self.binary_length
 
     def read_tx_block(self):
         '''Returns a list of (deserialized_tx, tx_hash) pairs.'''
-        read_tx = self.read_tx
-        txs = [read_tx() for _ in range(self._read_varint())]
+        read = self.read_tx_and_hash
         # Some coins have excess data beyond the end of the transactions
-        return txs
+        return [read() for _ in range(self._read_varint())]
 
     def _read_inputs(self):
         read_input = self._read_input
@@ -198,15 +206,17 @@ class DeserializerSegWit(Deserializer):
         read_varbytes = self._read_varbytes
         return [read_varbytes() for i in range(self._read_varint())]
 
-    def read_tx(self):
-        '''Return a (Deserialized TX, TX_HASH) pair.
-
-        The hash needs to be reversed for human display; for efficiency
-        we process it in the natural serialized order.
-        '''
+    def _read_tx_parts(self):
+        '''Return a (deserialized TX, tx_hash, vsize) tuple.'''
         marker = self.binary[self.cursor + 4]
         if marker:
-            return super().read_tx()
+            # Non-zero marker byte: the base-class layout applies.  Call the
+            # base read_tx() directly; going via read_tx_and_hash() would
+            # dispatch back to our read_tx() override and recurse forever.
+            start = self.cursor
+            tx = super().read_tx()
+            tx_hash = double_sha256(self.binary[start:self.cursor])
+            return tx, tx_hash, self.binary_length
 
         # Ugh, this is nasty.
         start = self.cursor
@@ -221,14 +231,27 @@ class DeserializerSegWit(Deserializer):
         outputs = self._read_outputs()
         orig_ser += self.binary[start:self.cursor]
 
+        base_size = self.cursor - start
         witness = self._read_witness(len(inputs))
 
         start = self.cursor
         locktime = self._read_le_uint32()
         orig_ser += self.binary[start:self.cursor]
+        vsize = (3 * base_size + self.binary_length) // 4
+
+        return TxSegWit(version, marker, flag, inputs, outputs, witness,
+                        locktime), double_sha256(orig_ser), vsize
+
+    def read_tx(self):
+        return self._read_tx_parts()[0]
+
+    def read_tx_and_hash(self):
+        tx, tx_hash, vsize = self._read_tx_parts()
+        return tx, tx_hash
 
-        return TxSegWit(version, marker, flag, inputs,
-                        outputs, witness, locktime), double_sha256(orig_ser)
+    def read_tx_and_vsize(self):
+        tx, tx_hash, vsize = self._read_tx_parts()
+        return tx, vsize
 
 
 class DeserializerAuxPow(Deserializer):
@@ -289,7 +312,6 @@ class TxJoinSplit(namedtuple("Tx", "version inputs outputs locktime")):
 
 class DeserializerZcash(DeserializerEquihash):
     def read_tx(self):
-        start = self.cursor
         base_tx = TxJoinSplit(
             self._read_le_int32(),  # version
             self._read_inputs(),    # inputs
@@ -302,7 +324,7 @@ class DeserializerZcash(DeserializerEquihash):
                 self.cursor += joinsplit_size * 1802  # JSDescription
                 self.cursor += 32  # joinSplitPubKey
                 self.cursor += 64  # joinSplitSig
-        return base_tx, double_sha256(self.binary[start:self.cursor])
+        return base_tx
 
 
 class TxTime(namedtuple("Tx", "version time inputs outputs locktime")):
@@ -315,21 +337,17 @@ class TxTime(namedtuple("Tx", "version time inputs outputs locktime")):
 
 class DeserializerTxTime(Deserializer):
     def read_tx(self):
-        start = self.cursor
-
         return TxTime(
             self._read_le_int32(),   # version
             self._read_le_uint32(),  # time
             self._read_inputs(),     # inputs
             self._read_outputs(),    # outputs
             self._read_le_uint32(),  # locktime
-        ), double_sha256(self.binary[start:self.cursor])
+        )
 
 
 class DeserializerReddcoin(Deserializer):
     def read_tx(self):
-        start = self.cursor
-
         version = self._read_le_int32()
         inputs = self._read_inputs()
         outputs = self._read_outputs()
@@ -339,13 +357,7 @@ class DeserializerReddcoin(Deserializer):
         else:
             time = 0
 
-        return TxTime(
-            version,
-            time,
-            inputs,
-            outputs,
-            locktime,
-        ), double_sha256(self.binary[start:self.cursor])
+        return TxTime(version, time, inputs, outputs, locktime)
 
 
 class DeserializerTxTimeAuxPow(DeserializerTxTime):
diff --git a/server/controller.py b/server/controller.py
index e0c3428..fecb4b9 100644
--- a/server/controller.py
+++ b/server/controller.py
@@ -873,7 +873,7 @@ class Controller(ServerBase):
         if not raw_tx:
             return None
         raw_tx = util.hex_to_bytes(raw_tx)
-        tx, tx_hash = self.coin.DESERIALIZER(raw_tx).read_tx()
+        tx = self.coin.DESERIALIZER(raw_tx).read_tx()
         if index >= len(tx.outputs):
             return None
         return self.coin.address_from_script(tx.outputs[index].pk_script)
diff --git a/server/mempool.py b/server/mempool.py
index c5e35d3..219a38c 100644
--- a/server/mempool.py
+++ b/server/mempool.py
@@ -217,7 +217,7 @@ class MemPool(util.LoggedClass):
         for tx_hash, raw_tx in raw_tx_map.items():
             if tx_hash not in txs:
                 continue
-            tx, _tx_hash = deserializer(raw_tx).read_tx()
+            tx = deserializer(raw_tx).read_tx()
 
             # Convert the tx outputs into (hashX, value) pairs
             txout_pairs = [(script_hashX(txout.pk_script), txout.value)
@@ -301,7 +301,7 @@ class MemPool(util.LoggedClass):
             txin_pairs, txout_pairs = item
             tx_fee = (sum(v for hashX, v in txin_pairs) -
                       sum(v for hashX, v in txout_pairs))
-            tx, tx_hash = deserializer(raw_tx).read_tx()
+            tx = deserializer(raw_tx).read_tx()
             unconfirmed = any(hash_to_str(txin.prev_hash) in self.txs
                               for txin in tx.inputs)
             result.append((hex_hash, tx_fee, unconfirmed))
@@ -319,7 +319,7 @@ class MemPool(util.LoggedClass):
         for hex_hash, raw_tx in pairs:
             if not raw_tx:
                 continue
-            tx, tx_hash = deserializer(raw_tx).read_tx()
+            tx = deserializer(raw_tx).read_tx()
             for txin in tx.inputs:
                 spends.add((txin.prev_hash, txin.prev_idx))
         return spends