Browse Source

Clean up trezor client interface

283
Neil Booth 9 years ago
parent
commit
43d21de1b2
  1. 20
      plugins/trezor/client.py
  2. 43
      plugins/trezor/plugin.py

20
plugins/trezor/client.py

@ -1,6 +1,7 @@
from sys import stderr from sys import stderr
from electrum.i18n import _ from electrum.i18n import _
from electrum.util import PrintError
class GuiMixin(object): class GuiMixin(object):
@ -58,18 +59,31 @@ class GuiMixin(object):
def trezor_client_class(protocol_mixin, base_client, proto): def trezor_client_class(protocol_mixin, base_client, proto):
'''Returns a class dynamically.''' '''Returns a class dynamically.'''
class TrezorClient(protocol_mixin, GuiMixin, base_client): class TrezorClient(protocol_mixin, GuiMixin, base_client, PrintError):
def __init__(self, transport, device): def __init__(self, transport, plugin):
base_client.__init__(self, transport) base_client.__init__(self, transport)
protocol_mixin.__init__(self, transport) protocol_mixin.__init__(self, transport)
self.proto = proto self.proto = proto
self.device = device self.device = plugin.device
self.handler = plugin.handler
self.tx_api = plugin
self.bad = False
def firmware_version(self):
f = self.features
v = (f.major_version, f.minor_version, f.patch_version)
self.print_error('firmware version', v)
return v
def atleast_version(self, major, minor=0, patch=0):
return cmp(self.firmware_version(), (major, minor, patch))
def call_raw(self, msg): def call_raw(self, msg):
try: try:
return base_client.call_raw(self, msg) return base_client.call_raw(self, msg)
except: except:
self.print_error("Marking %s client bad" % self.device)
self.bad = True self.bad = True
raise raise

43
plugins/trezor/plugin.py

@ -224,37 +224,33 @@ class TrezorCompatiblePlugin(BasePlugin):
return False return False
return True return True
def get_client(self): def create_client(self):
if not self.libraries_available: if not self.libraries_available:
self.give_error(_('please install the %s libraries from %s') self.give_error(_('please install the %s libraries from %s')
% (self.device, self.libraries_URL)) % (self.device, self.libraries_URL))
devices = self.HidTransport.enumerate()
if not devices:
self.give_error(_('Could not connect to your %s. Please '
'verify the cable is connected and that no '
'other app is using it.' % self.device))
transport = self.HidTransport(devices[0])
client = self.client_class(transport, self)
if not client.atleast_version(*self.minimum_firmware):
self.give_error(_('Outdated %s firmware. Please update the '
'firmware from %s')
% (self.device, self.firmware_URL))
return client
def get_client(self):
if not self.client or self.client.bad: if not self.client or self.client.bad:
d = self.HidTransport.enumerate() self.client = self.create_client()
if not d:
self.give_error(_('Could not connect to your %s. Please '
'verify the cable is connected and that no '
'other app is using it.' % self.device))
transport = self.HidTransport(d[0])
self.client = self.client_class(transport, self.device)
self.client.handler = self.handler
self.client.set_tx_api(self)
self.client.bad = False
if not self.atleast_version(*self.minimum_firmware):
self.client = None
self.give_error(_('Outdated %s firmware. Please update the '
'firmware from %s') % (self.device,
self.firmware_URL))
return self.client
def compare_version(self, major, minor=0, patch=0): return self.client
f = self.get_client().features
v = [f.major_version, f.minor_version, f.patch_version]
self.print_error('firmware version', v)
return cmp(v, [major, minor, patch])
def atleast_version(self, major, minor=0, patch=0): def atleast_version(self, major, minor=0, patch=0):
return self.compare_version(major, minor, patch) >= 0 return self.get_client().atleast_version(major, minor, patch)
@hook @hook
def close_wallet(self): def close_wallet(self):
@ -395,6 +391,7 @@ class TrezorCompatiblePlugin(BasePlugin):
o.script_pubkey = vout['scriptPubKey'].decode('hex') o.script_pubkey = vout['scriptPubKey'].decode('hex')
return t return t
# This function is called from the trezor libraries (via tx_api)
def get_tx(self, tx_hash): def get_tx(self, tx_hash):
tx = self.prev_tx[tx_hash] tx = self.prev_tx[tx_hash]
tx.deserialize() tx.deserialize()

Loading…
Cancel
Save