From 24037be99c8089fce89bd89a858513dfc4f3ee2d Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Wed, 20 Jan 2016 00:28:54 +0900 Subject: [PATCH] Clean up client caching and handling --- lib/plugins.py | 230 ++++++++++++++++------------------- plugins/keepkey/client.py | 4 +- plugins/trezor/client.py | 4 +- plugins/trezor/clientbase.py | 11 +- plugins/trezor/plugin.py | 121 +++++++++--------- plugins/trezor/qt_generic.py | 31 ++--- 6 files changed, 193 insertions(+), 208 deletions(-) diff --git a/lib/plugins.py b/lib/plugins.py index 40d39f29c..97698a3a9 100644 --- a/lib/plugins.py +++ b/lib/plugins.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from collections import namedtuple import traceback import sys import os @@ -226,6 +227,7 @@ class BasePlugin(PrintError): def settings_dialog(self): pass +Device = namedtuple("Device", "path id_ product_key") class DeviceMgr(PrintError): '''Manages hardware clients. A client communicates over a hardware @@ -262,82 +264,115 @@ class DeviceMgr(PrintError): def __init__(self): super(DeviceMgr, self).__init__() - # Keyed by wallet. The value is the hid_id if the wallet has - # been paired, and None otherwise. + # Keyed by wallet. The value is the device id if the wallet + # has been paired, and None otherwise. self.wallets = {} - # A list of clients. We create a client for every device present - # that is of a registered hardware type - self.clients = [] - # What we recognise. Keyed by (vendor_id, product_id) pairs, - # the value is a callback to create a client for those devices - self.recognised_hardware = {} + # A list of clients. The key is the client, the value is + # a (path, id_) pair. + self.clients = {} + # What we recognise. Each entry is a (vendor_id, product_id) + # pair. + self.recognised_hardware = set() # For synchronization self.lock = threading.RLock() - def register_devices(self, device_pairs, create_client): + def register_devices(self, device_pairs): for pair in device_pairs: - self.recognised_hardware[pair] = create_client + self.recognised_hardware.add(pair) - def unpair(self, hid_id): + def create_client(self, device, handler, plugin): + client = plugin.create_client(device, handler) + if client: + self.print_error("Registering", client) + with self.lock: + self.clients[client] = (device.path, device.id_) + return client + + def wallet_id(self, wallet): with self.lock: - wallet = self.wallet_by_hid_id(hid_id) - if wallet: - self.wallets[wallet] = None + return self.wallets.get(wallet) - def close_client(self, client): + def wallet_by_id(self, id_): with self.lock: - if client in self.clients: - self.clients.remove(client) + for wallet, wallet_id in self.wallets.items(): + if wallet_id == id_: + return wallet + return None + + def unpair_wallet(self, wallet): + with self.lock: + wallet_id = self.wallets.pop(wallet) + client = self.client_lookup(wallet_id) + self.clients.pop(client, None) + wallet.unpaired() if client: client.close() - def close_wallet(self, wallet): - # Remove the wallet from our list; close any client + def unpair_id(self, id_): with self.lock: - hid_id = self.wallets.pop(wallet, None) - self.close_client(self.client_by_hid_id(hid_id)) + wallet = self.wallet_by_id(id_) + if wallet: + self.unpair_wallet(wallet) - def unpaired_clients(self, handler, classinfo): - '''Returns all unpaired clients of the given type.''' - self.scan_devices(handler) + def pair_wallet(self, wallet, id_): with self.lock: - return [client for client in self.clients - if isinstance(client, classinfo) - and not self.wallet_by_hid_id(client.hid_id())] + self.wallets[wallet] = id_ + wallet.paired() - def client_by_hid_id(self, hid_id, handler=None): - '''Like get_client() but when we don't care about wallet pairing. If - a device is wiped or in bootloader mode pairing is impossible; - in such cases we communicate by device ID and not wallet.''' - if handler: - self.scan_devices(handler) + def paired_wallets(self): + return list(self.wallets.keys()) + + def client_lookup(self, id_): with self.lock: - for client in self.clients: - if client.hid_id() == hid_id: + for client, (path, client_id) in self.clients.items(): + if client_id == id_: return client - return None + return None - def wallet_hid_id(self, wallet): - with self.lock: - return self.wallets.get(wallet) + def client_by_id(self, id_, handler): + '''Returns a client for the device ID if one is registered. If + a device is wiped or in bootloader mode pairing is impossible; + in such cases we communicate by device ID and not wallet.''' + self.scan_devices(handler) + return self.client_lookup(id_) - def wallet_by_hid_id(self, hid_id): - with self.lock: - for wallet, wallet_hid_id in self.wallets.items(): - if wallet_hid_id == hid_id: - return wallet - return None + def client_for_wallet(self, plugin, wallet, force_pair): + assert wallet.handler - def paired_wallets(self): - with self.lock: - return [wallet for (wallet, hid_id) in self.wallets.items() - if hid_id is not None] + devices = self.scan_devices(wallet.handler) + wallet_id = self.wallet_id(wallet) + + client = self.client_lookup(wallet_id) + if client: + return client + + for device in devices: + if device.id_ == wallet_id: + return self.create_client(device, wallet.handler, plugin) + + if force_pair: + first_address, derivation = wallet.first_address() + # Wallets don't have a first address in the install wizard + # until account creation + if not first_address: + self.print_error("no first address for ", wallet) + return None + + # The wallet has not been previously paired, so get the + # first address of all unpaired clients and compare. + for device in devices: + # Skip already-paired devices + if self.wallet_by_id(device.id_): + continue + client = self.create_client(device, wallet.handler, plugin) + if client and not client.features.bootloader_mode: + # This will trigger a PIN/passphrase entry request + client_first_address = client.first_address(derivation) + if client_first_address == first_address: + self.pair_wallet(wallet, device.id_) + return client - def pair_wallet(self, wallet, client): - assert client in self.clients - self.print_error("paired:", wallet, client) - self.wallets[wallet] = client.hid_id() - wallet.connected() + return None def scan_devices(self, handler): # All currently supported hardware libraries use hid, so we @@ -349,76 +384,27 @@ class DeviceMgr(PrintError): self.print_error("scanning devices...") # First see what's connected that we know about - devices = {} + devices = [] for d in hid.enumerate(0, 0): product_key = (d['vendor_id'], d['product_id']) - create_client = self.recognised_hardware.get(product_key) - if create_client: - devices[d['serial_number']] = (create_client, d['path']) + if product_key in self.recognised_hardware: + devices.append(Device(d['path'], d['serial_number'], + product_key)) # Now find out what was disconnected + pairs = [(dev.path, dev.id_) for dev in devices] + disconnected_ids = [] with self.lock: - disconnected = [client for client in self.clients - if not client.hid_id() in devices] - - # Close disconnected clients after informing their wallets - for client in disconnected: - wallet = self.wallet_by_hid_id(client.hid_id()) - if wallet: - wallet.disconnected() - self.close_client(client) - - # Now see if any new devices are present. - for hid_id, (create_client, path) in devices.items(): - try: - client = create_client(path, handler, hid_id) - except BaseException as e: - self.print_error("could not create client", str(e)) - client = None - if client: - self.print_error("client created for", path) - with self.lock: - self.clients.append(client) - # Inform re-paired wallet - wallet = self.wallet_by_hid_id(hid_id) - if wallet: - self.pair_wallet(wallet, client) - - def get_client(self, wallet, force_pair=True): - '''Returns a client for the wallet, or None if one could not be found. - If force_pair is False then if an already paired client cannot - be found None is returned rather than requiring user - interaction.''' - # We must scan devices to get an up-to-date idea of which - # devices are present. Operating on a client when its device - # has been removed can cause the process to hang. - # Unfortunately there is no plugged / unplugged notification - # system. - self.scan_devices(wallet.handler) - - # Previously paired wallets only need look for matching HID IDs - hid_id = self.wallet_hid_id(wallet) - if hid_id: - return self.client_by_hid_id(hid_id) - - first_address, derivation = wallet.first_address() - # Wallets don't have a first address in the install wizard - # until account creation - if not first_address: - self.print_error("no first address for ", wallet) - return None - - with self.lock: - # The wallet has not been previously paired, so get the - # first address of all unpaired clients and compare. - for client in self.clients: - # If already paired skip it - if self.wallet_by_hid_id(client.hid_id()): - continue - # This will trigger a PIN/passphrase entry request - if client.first_address(derivation) == first_address: - self.pair_wallet(wallet, client) - return client - - # Not found - return None + connected = {} + for client, pair in self.clients.items(): + if pair in pairs: + connected[client] = pair + else: + disconnected_ids.append(pair[1]) + self.clients = connected + + # Unpair disconnected devices + for id_ in disconnected_ids: + self.unpair_id(id_) + + return devices diff --git a/plugins/keepkey/client.py b/plugins/keepkey/client.py index 00de03be5..4e70955bb 100644 --- a/plugins/keepkey/client.py +++ b/plugins/keepkey/client.py @@ -2,10 +2,10 @@ from keepkeylib.client import proto, BaseClient, ProtocolMixin from ..trezor.clientbase import TrezorClientBase class KeepKeyClient(TrezorClientBase, ProtocolMixin, BaseClient): - def __init__(self, transport, handler, plugin, hid_id): + def __init__(self, transport, handler, plugin): BaseClient.__init__(self, transport) ProtocolMixin.__init__(self, transport) - TrezorClientBase.__init__(self, handler, plugin, hid_id, proto) + TrezorClientBase.__init__(self, handler, plugin, proto) def recovery_device(self, *args): ProtocolMixin.recovery_device(self, True, *args) diff --git a/plugins/trezor/client.py b/plugins/trezor/client.py index 3d1fe7555..591ea3037 100644 --- a/plugins/trezor/client.py +++ b/plugins/trezor/client.py @@ -2,10 +2,10 @@ from trezorlib.client import proto, BaseClient, ProtocolMixin from clientbase import TrezorClientBase class TrezorClient(TrezorClientBase, ProtocolMixin, BaseClient): - def __init__(self, transport, handler, plugin, hid_id): + def __init__(self, transport, handler, plugin): BaseClient.__init__(self, transport) ProtocolMixin.__init__(self, transport) - TrezorClientBase.__init__(self, handler, plugin, hid_id, proto) + TrezorClientBase.__init__(self, handler, plugin, proto) TrezorClientBase.wrap_methods(TrezorClient) diff --git a/plugins/trezor/clientbase.py b/plugins/trezor/clientbase.py index 8500bc570..1752f8ed7 100644 --- a/plugins/trezor/clientbase.py +++ b/plugins/trezor/clientbase.py @@ -68,27 +68,22 @@ class GuiMixin(object): class TrezorClientBase(GuiMixin, PrintError): - def __init__(self, handler, plugin, hid_id, proto): + def __init__(self, handler, plugin, proto): assert hasattr(self, 'tx_api') # ProtocolMixin already constructed? self.proto = proto self.device = plugin.device self.handler = handler - self.hid_id_ = hid_id self.tx_api = plugin self.types = plugin.types self.msg_code_override = None def __str__(self): - return "%s/%s" % (self.label(), self.hid_id()) + return "%s/%s" % (self.label(), self.features.device_id) def label(self): '''The name given by the user to the device.''' return self.features.label - def hid_id(self): - '''The HID ID of the device.''' - return self.hid_id_ - def is_initialized(self): '''True if initialized, False if wiped.''' return self.features.initialized @@ -163,7 +158,7 @@ class TrezorClientBase(GuiMixin, PrintError): def close(self): '''Called when Our wallet was closed or the device removed.''' - self.print_error("disconnected") + self.print_error("closing client") self.clear_session() # Release the device self.transport.close() diff --git a/plugins/trezor/plugin.py b/plugins/trezor/plugin.py index 9ea849f3e..fe90a77ee 100644 --- a/plugins/trezor/plugin.py +++ b/plugins/trezor/plugin.py @@ -51,23 +51,26 @@ class TrezorCompatibleWallet(BIP44_Wallet): self.session_timeout = seconds self.storage.put('session_timeout', seconds) - def disconnected(self): - '''A device paired with the wallet was diconnected. Note this is - called in the context of the Plugins thread.''' - self.print_error("disconnected") + def unpaired(self): + '''A device paired with the wallet was diconnected. This can be + called in any thread context.''' + self.print_error("unpaired") self.force_watching_only = True self.handler.watching_only_changed() - def connected(self): - '''A device paired with the wallet was (re-)connected. Note this - is called in the context of the Plugins thread.''' - self.print_error("connected") + def paired(self): + '''A device paired with the wallet was (re-)connected. This can be + called in any thread context.''' + self.print_error("paired") self.force_watching_only = False self.handler.watching_only_changed() def timeout(self): - '''Informs the wallet it timed out. Note this is called from + '''Called when the wallet session times out. Note this is called from the Plugins thread.''' + client = self.get_client(force_pair=False) + if client: + client.clear_session() self.print_error("timed out") def get_action(self): @@ -178,8 +181,7 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob): self.wallet_class.plugin = self self.prevent_timeout = time.time() + 3600 * 24 * 365 if self.libraries_available: - self.device_manager().register_devices( - self.DEVICE_IDS, self.create_client) + self.device_manager().register_devices(self.DEVICE_IDS) def is_enabled(self): return self.libraries_available @@ -199,13 +201,11 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob): if (isinstance(wallet, self.wallet_class) and hasattr(wallet, 'last_operation') and now > wallet.last_operation + wallet.session_timeout): - client = self.get_client(wallet, force_pair=False) - if client: - client.clear_session() - wallet.last_operation = self.prevent_timeout - wallet.timeout() + wallet.timeout() + wallet.last_operation = self.prevent_timeout - def create_client(self, path, handler, hid_id): + def create_client(self, device, handler): + path = device.path pair = ((None, path) if self.HidTransport._detect_debuglink(path) else (path, None)) try: @@ -215,50 +215,48 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob): self.print_error("cannot connect at", path, str(e)) return None self.print_error("connected to device at", path) - return self.client_class(transport, handler, self, hid_id) - def get_client(self, wallet, force_pair=True, check_firmware=True): + client = self.client_class(transport, handler, self) + + # Try a ping for device sanity + try: + client.ping('t') + except BaseException as e: + self.print_error("ping failed", str(e)) + return None + + if not client.atleast_version(*self.minimum_firmware): + msg = (_('Outdated %s firmware for device labelled %s. Please ' + 'download the updated firmware from %s') % + (self.device, client.label(), self.firmware_URL)) + handler.show_error(msg) + return None + + return client + + def get_client(self, wallet, force_pair=True): + # All client interaction should not be in the main GUI thread assert self.main_thread != threading.current_thread() - '''check_firmware is ignored unless force_pair is True.''' - client = self.device_manager().get_client(wallet, force_pair) + devmgr = self.device_manager() + client = devmgr.client_for_wallet(self, wallet, force_pair) - # Try a ping for device sanity if client: self.print_error("set last_operation") wallet.last_operation = time.time() - try: - client.ping('t') - except BaseException as e: - self.print_error("ping failed", str(e)) - # Remove it from the manager's cache - self.device_manager().close_client(client) - client = None - - if force_pair: - assert wallet.handler - if not client: - msg = (_('Could not connect to your %s. Verify the ' - 'cable is connected and that no other app is ' - 'using it.\nContinuing in watching-only mode ' - 'until the device is re-connected.') % self.device) - wallet.handler.show_error(msg) - raise DeviceDisconnectedError(msg) - - if (check_firmware and not - client.atleast_version(*self.minimum_firmware)): - msg = (_('Outdated %s firmware for device labelled %s. Please ' - 'download the updated firmware from %s') % - (self.device, client.label(), self.firmware_URL)) - wallet.handler.show_error(msg) - raise OutdatedFirmwareError(msg) + elif force_pair: + msg = (_('Could not connect to your %s. Verify the ' + 'cable is connected and that no other app is ' + 'using it.\nContinuing in watching-only mode ' + 'until the device is re-connected.') % self.device) + raise DeviceDisconnectedError(msg) return client @hook def close_wallet(self, wallet): if isinstance(wallet, self.wallet_class): - self.device_manager().close_wallet(wallet) + self.device_manager().unpair_wallet(wallet) def initialize_device(self, wallet): # Prevent timeouts during initialization @@ -310,27 +308,32 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob): wallet.thread.add(initialize_device) - def unpaired_clients(self, handler): + def unpaired_devices(self, handler): '''Returns all connected, unpaired devices as a list of clients and a list of descriptions.''' devmgr = self.device_manager() - clients = devmgr.unpaired_clients(handler, self.client_class) - states = [_("wiped"), _("initialized")] - def client_desc(client): - label = client.label() or _("An unnamed device") + devices = devmgr.unpaired_devices(handler) + + good_devices, descrs = [], [] + for device in devices: + client = self.device_manager().create_client(device, handler, self) + if not client: + continue state = states[client.is_initialized()] - return ("%s: serial number %s (%s)" - % (label, client.hid_id(), state)) - return clients, list(map(client_desc, clients)) + label = device.info['label'] or _("An unnamed device") + good_devices.append(device) + descrs.append("%s: device ID %s (%s)" % (label, device.id_, state)) + + return good_devices, descrs def select_device(self, wallet): '''Called when creating a new wallet. Select the device to use. If the device is uninitialized, go through the intialization process.''' msg = _("Please select which %s device to use:") % self.device - clients, labels = self.unpaired_clients(wallet.handler) - client = clients[wallet.handler.query_choice(msg, labels)] - self.device_manager().pair_wallet(wallet, client) + devices, labels = self.unpaired_devices(wallet.handler) + device = devices[wallet.handler.query_choice(msg, labels)] + self.device_manager().pair_wallet(wallet, device.id_) if not client.is_initialized(): self.initialize_device(wallet) diff --git a/plugins/trezor/qt_generic.py b/plugins/trezor/qt_generic.py index 9dd54c0db..c9c2b7019 100644 --- a/plugins/trezor/qt_generic.py +++ b/plugins/trezor/qt_generic.py @@ -252,25 +252,25 @@ def qt_plugin_class(base_plugin_class): menu.addAction(_("Show on %s") % self.device, show_address) def settings_dialog(self, window): - hid_id = self.choose_device(window) - if hid_id: - SettingsDialog(window, self, hid_id).exec_() + device_id = self.choose_device(window) + if device_id: + SettingsDialog(window, self, device_id).exec_() def choose_device(self, window): '''This dialog box should be usable even if the user has forgotten their PIN or it is in bootloader mode.''' handler = window.wallet.handler - hid_id = self.device_manager().wallet_hid_id(window.wallet) - if not hid_id: - clients, labels = self.unpaired_clients(handler) - if clients: + device_id = self.device_manager().wallet_id(window.wallet) + if not device_id: + devices, labels = self.unpaired_devices(handler) + if devices: msg = _("Select a %s device:") % self.device choice = self.query_choice(window, msg, labels) if choice is not None: - hid_id = clients[choice].hid_id() + device_id = devices[choice].id_ else: handler.show_error(_("No devices found")) - return hid_id + return device_id def query_choice(self, window, msg, choices): dialog = WindowModalDialog(window) @@ -292,28 +292,29 @@ class SettingsDialog(WindowModalDialog): We want users to be able to wipe a device even if they've forgotten their PIN.''' - def __init__(self, window, plugin, hid_id): + def __init__(self, window, plugin, device_id): title = _("%s Settings") % plugin.device super(SettingsDialog, self).__init__(window, title) self.setMaximumWidth(540) devmgr = plugin.device_manager() handler = window.wallet.handler + thread = window.wallet.thread # wallet can be None, needn't be window.wallet - wallet = devmgr.wallet_by_hid_id(hid_id) + wallet = devmgr.wallet_by_id(device_id) hs_rows, hs_cols = (64, 128) self.current_label=None def invoke_client(method, *args, **kw_args): def task(): - client = plugin.get_client(wallet, False) + client = devmgr.client_by_id(device_id, handler) if not client: raise RuntimeError("Device not connected") if method: getattr(client, method)(*args, **kw_args) update(client.features) - wallet.thread.add(task) + thread.add(task) def update(features): self.current_label = features.label @@ -364,7 +365,7 @@ class SettingsDialog(WindowModalDialog): if not self.question(msg, title=title): return invoke_client('toggle_passphrase') - devmgr.unpair(hid_id) + devmgr.unpair(device_id) def change_homescreen(): from PIL import Image # FIXME @@ -402,7 +403,7 @@ class SettingsDialog(WindowModalDialog): icon=QMessageBox.Critical): return invoke_client('wipe_device') - devmgr.unpair(hid_id) + devmgr.unpair(device_id) def slider_moved(): mins = timeout_slider.sliderPosition()