From 1387a4590852efe0ae07295592c87de9f38872be Mon Sep 17 00:00:00 2001 From: SomberNight Date: Tue, 9 Jan 2018 17:09:58 +0100 Subject: [PATCH] trezor plugin: native segwit and bip84 --- plugins/trezor/plugin.py | 64 ++++++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/plugins/trezor/plugin.py b/plugins/trezor/plugin.py index b0f4d1681..a6116c639 100644 --- a/plugins/trezor/plugin.py +++ b/plugins/trezor/plugin.py @@ -16,13 +16,27 @@ from ..hw_wallet import HW_PluginBase # TREZOR initialization methods TIM_NEW, TIM_RECOVER, TIM_MNEMONIC, TIM_PRIVKEY = range(0, 4) +# script "generation" +SCRIPT_GEN_LEGACY, SCRIPT_GEN_P2SH_SEGWIT, SCRIPT_GEN_NATIVE_SEGWIT = range(0, 3) + class TrezorCompatibleKeyStore(Hardware_KeyStore): def get_derivation(self): return self.derivation - def is_segwit(self): - return self.derivation.startswith("m/49'/") + def get_script_gen(self): + def is_p2sh_segwit(): + return self.derivation.startswith("m/49'/") + + def is_native_segwit(): + return self.derivation.startswith("m/84'/") + + if is_native_segwit(): + return SCRIPT_GEN_NATIVE_SEGWIT + elif is_p2sh_segwit(): + return SCRIPT_GEN_P2SH_SEGWIT + else: + return SCRIPT_GEN_LEGACY def get_client(self, force_pair=True): return self.plugin.get_client(self, force_pair) @@ -226,8 +240,8 @@ class TrezorCompatiblePlugin(HW_PluginBase): self.prev_tx = prev_tx self.xpub_path = xpub_path client = self.get_client(keystore) - inputs = self.tx_inputs(tx, True, keystore.is_segwit()) - outputs = self.tx_outputs(keystore.get_derivation(), tx, keystore.is_segwit()) + inputs = self.tx_inputs(tx, True, keystore.get_script_gen()) + outputs = self.tx_outputs(keystore.get_derivation(), tx, keystore.get_script_gen()) signed_tx = client.sign_tx(self.get_coin_name(), inputs, outputs, lock_time=tx.locktime)[1] raw = bh2u(signed_tx) tx.update_signatures(raw) @@ -241,11 +255,16 @@ class TrezorCompatiblePlugin(HW_PluginBase): derivation = wallet.keystore.derivation address_path = "%s/%d/%d"%(derivation, change, index) address_n = client.expand_path(address_path) - segwit = wallet.keystore.is_segwit() - script_type = self.types.InputScriptType.SPENDP2SHWITNESS if segwit else self.types.InputScriptType.SPENDADDRESS + script_gen = wallet.keystore.get_script_gen() + if script_gen == SCRIPT_GEN_NATIVE_SEGWIT: + script_type = self.types.InputScriptType.SPENDWITNESS + elif script_gen == SCRIPT_GEN_P2SH_SEGWIT: + script_type = self.types.InputScriptType.SPENDP2SHWITNESS + else: + script_type = self.types.InputScriptType.SPENDADDRESS client.get_address(self.get_coin_name(), address_n, True, script_type=script_type) - def tx_inputs(self, tx, for_sig=False, segwit=False): + def tx_inputs(self, tx, for_sig=False, script_gen=SCRIPT_GEN_LEGACY): inputs = [] for txin in tx.inputs(): txinputtype = self.types.TxInputType() @@ -260,7 +279,12 @@ class TrezorCompatiblePlugin(HW_PluginBase): xpub, s = parse_xpubkey(x_pubkey) xpub_n = self.client_class.expand_path(self.xpub_path[xpub]) txinputtype._extend_address_n(xpub_n + s) - txinputtype.script_type = self.types.InputScriptType.SPENDP2SHWITNESS if segwit else self.types.InputScriptType.SPENDADDRESS + if script_gen == SCRIPT_GEN_NATIVE_SEGWIT: + txinputtype.script_type = self.types.InputScriptType.SPENDWITNESS + elif script_gen == SCRIPT_GEN_P2SH_SEGWIT: + txinputtype.script_type = self.types.InputScriptType.SPENDP2SHWITNESS + else: + txinputtype.script_type = self.types.InputScriptType.SPENDADDRESS else: def f(x_pubkey): if is_xpubkey(x_pubkey): @@ -276,7 +300,12 @@ class TrezorCompatiblePlugin(HW_PluginBase): signatures=map(lambda x: bfh(x)[:-1] if x else b'', txin.get('signatures')), m=txin.get('num_sig'), ) - script_type = self.types.InputScriptType.SPENDP2SHWITNESS if segwit else self.types.InputScriptType.SPENDMULTISIG + if script_gen == SCRIPT_GEN_NATIVE_SEGWIT: + script_type = self.types.InputScriptType.SPENDWITNESS + elif script_gen == SCRIPT_GEN_P2SH_SEGWIT: + script_type = self.types.InputScriptType.SPENDP2SHWITNESS + else: + script_type = self.types.InputScriptType.SPENDMULTISIG txinputtype = self.types.TxInputType( script_type=script_type, multisig=multisig @@ -308,7 +337,7 @@ class TrezorCompatiblePlugin(HW_PluginBase): return inputs - def tx_outputs(self, derivation, tx, segwit=False): + def tx_outputs(self, derivation, tx, script_gen=SCRIPT_GEN_LEGACY): outputs = [] has_change = False @@ -316,10 +345,14 @@ class TrezorCompatiblePlugin(HW_PluginBase): info = tx.output_info.get(address) if info is not None and not has_change: has_change = True # no more than one change address - addrtype, hash_160 = b58_address_to_hash160(address) index, xpubs, m = info if len(xpubs) == 1: - script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS if segwit else self.types.OutputScriptType.PAYTOADDRESS + if script_gen == SCRIPT_GEN_NATIVE_SEGWIT: + script_type = self.types.OutputScriptType.PAYTOWITNESS + elif script_gen == SCRIPT_GEN_P2SH_SEGWIT: + script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS + else: + script_type = self.types.OutputScriptType.PAYTOADDRESS address_n = self.client_class.expand_path(derivation + "/%d/%d"%index) txoutputtype = self.types.TxOutputType( amount = amount, @@ -327,7 +360,12 @@ class TrezorCompatiblePlugin(HW_PluginBase): address_n = address_n, ) else: - script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS if segwit else self.types.OutputScriptType.PAYTOMULTISIG + if script_gen == SCRIPT_GEN_NATIVE_SEGWIT: + script_type = self.types.OutputScriptType.PAYTOWITNESS + elif script_gen == SCRIPT_GEN_P2SH_SEGWIT: + script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS + else: + script_type = self.types.OutputScriptType.PAYTOMULTISIG address_n = self.client_class.expand_path("/%d/%d"%index) nodes = map(self.ckd_public.deserialize, xpubs) pubkeys = [ self.types.HDNodePathType(node=node, address_n=address_n) for node in nodes]