SomberNight 7 years ago
parent
commit
afa4cbfcbb
  1. 2
      lib/plugins.py
  2. 42
      plugins/trezor/trezor.py

2
lib/plugins.py

@ -466,7 +466,7 @@ class DeviceMgr(ThreadJob, PrintError):
devices = [dev for dev in devices if not self.xpub_by_id(dev.id_)]
infos = []
for device in devices:
if not device.product_key in plugin.DEVICE_IDS:
if device.product_key not in plugin.DEVICE_IDS:
continue
client = self.create_client(device, handler, plugin)
if not client:

42
plugins/trezor/trezor.py

@ -126,7 +126,11 @@ class TrezorPlugin(HW_PluginBase):
self.device_manager().register_enumerate_func(self.enumerate)
def enumerate(self):
@staticmethod
def _all_transports():
"""Reimplemented trezorlib.transport.all_transports for old trezorlib.
Remove this when we start to require trezorlib 0.9.2
"""
try:
from trezorlib.transport import all_transports
except ImportError:
@ -154,9 +158,14 @@ class TrezorPlugin(HW_PluginBase):
except BaseException:
pass
return transports
return all_transports()
def _enumerate_devices(self):
"""Just like trezorlib.transport.enumerate_devices,
but with exception catching, so that transports can fail separately.
"""
devices = []
for transport in all_transports():
for transport in self._all_transports():
try:
new_devices = transport.enumerate()
except BaseException as e:
@ -164,14 +173,39 @@ class TrezorPlugin(HW_PluginBase):
.format(transport.__name__, str(e)))
else:
devices.extend(new_devices)
return devices
def enumerate(self):
devices = self._enumerate_devices()
return [Device(d.get_path(), -1, d.get_path(), 'TREZOR', 0) for d in devices]
def _get_transport(self, path=None):
"""Reimplemented trezorlib.transport.get_transport for old trezorlib.
Remove this when we start to require trezorlib 0.9.2
"""
try:
from trezorlib.transport import get_transport
except ImportError:
# compat for trezorlib < 0.9.2
def get_transport(path=None, prefix_search=False):
if path is None:
try:
return self._enumerate_devices()[0]
except IndexError:
raise Exception("No TREZOR device found") from None
def match_prefix(a, b):
return a.startswith(b) or b.startswith(a)
transports = [t for t in self._all_transports() if match_prefix(path, t.PATH_PREFIX)]
if transports:
return transports[0].find_by_path(path)
raise Exception("Unknown path prefix '%s'" % path)
return get_transport(path)
def create_client(self, device, handler):
from trezorlib.device import TrezorDevice
try:
self.print_error("connecting to device at", device.path)
transport = TrezorDevice.find_by_path(device.path)
transport = self._get_transport(device.path)
except BaseException as e:
self.print_error("cannot connect at", device.path, str(e))
return None

Loading…
Cancel
Save