Browse Source

Clean up client caching and handling

283
Neil Booth 9 years ago
parent
commit
24037be99c
  1. 230
      lib/plugins.py
  2. 4
      plugins/keepkey/client.py
  3. 4
      plugins/trezor/client.py
  4. 11
      plugins/trezor/clientbase.py
  5. 121
      plugins/trezor/plugin.py
  6. 31
      plugins/trezor/qt_generic.py

230
lib/plugins.py

@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import namedtuple
import traceback import traceback
import sys import sys
import os import os
@ -226,6 +227,7 @@ class BasePlugin(PrintError):
def settings_dialog(self): def settings_dialog(self):
pass pass
Device = namedtuple("Device", "path id_ product_key")
class DeviceMgr(PrintError): class DeviceMgr(PrintError):
'''Manages hardware clients. A client communicates over a hardware '''Manages hardware clients. A client communicates over a hardware
@ -262,82 +264,115 @@ class DeviceMgr(PrintError):
def __init__(self): def __init__(self):
super(DeviceMgr, self).__init__() super(DeviceMgr, self).__init__()
# Keyed by wallet. The value is the hid_id if the wallet has # Keyed by wallet. The value is the device id if the wallet
# been paired, and None otherwise. # has been paired, and None otherwise.
self.wallets = {} self.wallets = {}
# A list of clients. We create a client for every device present # A list of clients. The key is the client, the value is
# that is of a registered hardware type # a (path, id_) pair.
self.clients = [] self.clients = {}
# What we recognise. Keyed by (vendor_id, product_id) pairs, # What we recognise. Each entry is a (vendor_id, product_id)
# the value is a callback to create a client for those devices # pair.
self.recognised_hardware = {} self.recognised_hardware = set()
# For synchronization # For synchronization
self.lock = threading.RLock() self.lock = threading.RLock()
def register_devices(self, device_pairs, create_client): def register_devices(self, device_pairs):
for pair in device_pairs: for pair in device_pairs:
self.recognised_hardware[pair] = create_client self.recognised_hardware.add(pair)
def unpair(self, hid_id): def create_client(self, device, handler, plugin):
client = plugin.create_client(device, handler)
if client:
self.print_error("Registering", client)
with self.lock:
self.clients[client] = (device.path, device.id_)
return client
def wallet_id(self, wallet):
with self.lock: with self.lock:
wallet = self.wallet_by_hid_id(hid_id) return self.wallets.get(wallet)
if wallet:
self.wallets[wallet] = None
def close_client(self, client): def wallet_by_id(self, id_):
with self.lock: with self.lock:
if client in self.clients: for wallet, wallet_id in self.wallets.items():
self.clients.remove(client) if wallet_id == id_:
return wallet
return None
def unpair_wallet(self, wallet):
with self.lock:
wallet_id = self.wallets.pop(wallet)
client = self.client_lookup(wallet_id)
self.clients.pop(client, None)
wallet.unpaired()
if client: if client:
client.close() client.close()
def close_wallet(self, wallet): def unpair_id(self, id_):
# Remove the wallet from our list; close any client
with self.lock: with self.lock:
hid_id = self.wallets.pop(wallet, None) wallet = self.wallet_by_id(id_)
self.close_client(self.client_by_hid_id(hid_id)) if wallet:
self.unpair_wallet(wallet)
def unpaired_clients(self, handler, classinfo): def pair_wallet(self, wallet, id_):
'''Returns all unpaired clients of the given type.'''
self.scan_devices(handler)
with self.lock: with self.lock:
return [client for client in self.clients self.wallets[wallet] = id_
if isinstance(client, classinfo) wallet.paired()
and not self.wallet_by_hid_id(client.hid_id())]
def client_by_hid_id(self, hid_id, handler=None): def paired_wallets(self):
'''Like get_client() but when we don't care about wallet pairing. If return list(self.wallets.keys())
a device is wiped or in bootloader mode pairing is impossible;
in such cases we communicate by device ID and not wallet.''' def client_lookup(self, id_):
if handler:
self.scan_devices(handler)
with self.lock: with self.lock:
for client in self.clients: for client, (path, client_id) in self.clients.items():
if client.hid_id() == hid_id: if client_id == id_:
return client return client
return None return None
def wallet_hid_id(self, wallet): def client_by_id(self, id_, handler):
with self.lock: '''Returns a client for the device ID if one is registered. If
return self.wallets.get(wallet) a device is wiped or in bootloader mode pairing is impossible;
in such cases we communicate by device ID and not wallet.'''
self.scan_devices(handler)
return self.client_lookup(id_)
def wallet_by_hid_id(self, hid_id): def client_for_wallet(self, plugin, wallet, force_pair):
with self.lock: assert wallet.handler
for wallet, wallet_hid_id in self.wallets.items():
if wallet_hid_id == hid_id:
return wallet
return None
def paired_wallets(self): devices = self.scan_devices(wallet.handler)
with self.lock: wallet_id = self.wallet_id(wallet)
return [wallet for (wallet, hid_id) in self.wallets.items()
if hid_id is not None] client = self.client_lookup(wallet_id)
if client:
return client
for device in devices:
if device.id_ == wallet_id:
return self.create_client(device, wallet.handler, plugin)
if force_pair:
first_address, derivation = wallet.first_address()
# Wallets don't have a first address in the install wizard
# until account creation
if not first_address:
self.print_error("no first address for ", wallet)
return None
# The wallet has not been previously paired, so get the
# first address of all unpaired clients and compare.
for device in devices:
# Skip already-paired devices
if self.wallet_by_id(device.id_):
continue
client = self.create_client(device, wallet.handler, plugin)
if client and not client.features.bootloader_mode:
# This will trigger a PIN/passphrase entry request
client_first_address = client.first_address(derivation)
if client_first_address == first_address:
self.pair_wallet(wallet, device.id_)
return client
def pair_wallet(self, wallet, client): return None
assert client in self.clients
self.print_error("paired:", wallet, client)
self.wallets[wallet] = client.hid_id()
wallet.connected()
def scan_devices(self, handler): def scan_devices(self, handler):
# All currently supported hardware libraries use hid, so we # All currently supported hardware libraries use hid, so we
@ -349,76 +384,27 @@ class DeviceMgr(PrintError):
self.print_error("scanning devices...") self.print_error("scanning devices...")
# First see what's connected that we know about # First see what's connected that we know about
devices = {} devices = []
for d in hid.enumerate(0, 0): for d in hid.enumerate(0, 0):
product_key = (d['vendor_id'], d['product_id']) product_key = (d['vendor_id'], d['product_id'])
create_client = self.recognised_hardware.get(product_key) if product_key in self.recognised_hardware:
if create_client: devices.append(Device(d['path'], d['serial_number'],
devices[d['serial_number']] = (create_client, d['path']) product_key))
# Now find out what was disconnected # Now find out what was disconnected
pairs = [(dev.path, dev.id_) for dev in devices]
disconnected_ids = []
with self.lock: with self.lock:
disconnected = [client for client in self.clients connected = {}
if not client.hid_id() in devices] for client, pair in self.clients.items():
if pair in pairs:
# Close disconnected clients after informing their wallets connected[client] = pair
for client in disconnected: else:
wallet = self.wallet_by_hid_id(client.hid_id()) disconnected_ids.append(pair[1])
if wallet: self.clients = connected
wallet.disconnected()
self.close_client(client) # Unpair disconnected devices
for id_ in disconnected_ids:
# Now see if any new devices are present. self.unpair_id(id_)
for hid_id, (create_client, path) in devices.items():
try: return devices
client = create_client(path, handler, hid_id)
except BaseException as e:
self.print_error("could not create client", str(e))
client = None
if client:
self.print_error("client created for", path)
with self.lock:
self.clients.append(client)
# Inform re-paired wallet
wallet = self.wallet_by_hid_id(hid_id)
if wallet:
self.pair_wallet(wallet, client)
def get_client(self, wallet, force_pair=True):
'''Returns a client for the wallet, or None if one could not be found.
If force_pair is False then if an already paired client cannot
be found None is returned rather than requiring user
interaction.'''
# We must scan devices to get an up-to-date idea of which
# devices are present. Operating on a client when its device
# has been removed can cause the process to hang.
# Unfortunately there is no plugged / unplugged notification
# system.
self.scan_devices(wallet.handler)
# Previously paired wallets only need look for matching HID IDs
hid_id = self.wallet_hid_id(wallet)
if hid_id:
return self.client_by_hid_id(hid_id)
first_address, derivation = wallet.first_address()
# Wallets don't have a first address in the install wizard
# until account creation
if not first_address:
self.print_error("no first address for ", wallet)
return None
with self.lock:
# The wallet has not been previously paired, so get the
# first address of all unpaired clients and compare.
for client in self.clients:
# If already paired skip it
if self.wallet_by_hid_id(client.hid_id()):
continue
# This will trigger a PIN/passphrase entry request
if client.first_address(derivation) == first_address:
self.pair_wallet(wallet, client)
return client
# Not found
return None

4
plugins/keepkey/client.py

@ -2,10 +2,10 @@ from keepkeylib.client import proto, BaseClient, ProtocolMixin
from ..trezor.clientbase import TrezorClientBase from ..trezor.clientbase import TrezorClientBase
class KeepKeyClient(TrezorClientBase, ProtocolMixin, BaseClient): class KeepKeyClient(TrezorClientBase, ProtocolMixin, BaseClient):
def __init__(self, transport, handler, plugin, hid_id): def __init__(self, transport, handler, plugin):
BaseClient.__init__(self, transport) BaseClient.__init__(self, transport)
ProtocolMixin.__init__(self, transport) ProtocolMixin.__init__(self, transport)
TrezorClientBase.__init__(self, handler, plugin, hid_id, proto) TrezorClientBase.__init__(self, handler, plugin, proto)
def recovery_device(self, *args): def recovery_device(self, *args):
ProtocolMixin.recovery_device(self, True, *args) ProtocolMixin.recovery_device(self, True, *args)

4
plugins/trezor/client.py

@ -2,10 +2,10 @@ from trezorlib.client import proto, BaseClient, ProtocolMixin
from clientbase import TrezorClientBase from clientbase import TrezorClientBase
class TrezorClient(TrezorClientBase, ProtocolMixin, BaseClient): class TrezorClient(TrezorClientBase, ProtocolMixin, BaseClient):
def __init__(self, transport, handler, plugin, hid_id): def __init__(self, transport, handler, plugin):
BaseClient.__init__(self, transport) BaseClient.__init__(self, transport)
ProtocolMixin.__init__(self, transport) ProtocolMixin.__init__(self, transport)
TrezorClientBase.__init__(self, handler, plugin, hid_id, proto) TrezorClientBase.__init__(self, handler, plugin, proto)
TrezorClientBase.wrap_methods(TrezorClient) TrezorClientBase.wrap_methods(TrezorClient)

11
plugins/trezor/clientbase.py

@ -68,27 +68,22 @@ class GuiMixin(object):
class TrezorClientBase(GuiMixin, PrintError): class TrezorClientBase(GuiMixin, PrintError):
def __init__(self, handler, plugin, hid_id, proto): def __init__(self, handler, plugin, proto):
assert hasattr(self, 'tx_api') # ProtocolMixin already constructed? assert hasattr(self, 'tx_api') # ProtocolMixin already constructed?
self.proto = proto self.proto = proto
self.device = plugin.device self.device = plugin.device
self.handler = handler self.handler = handler
self.hid_id_ = hid_id
self.tx_api = plugin self.tx_api = plugin
self.types = plugin.types self.types = plugin.types
self.msg_code_override = None self.msg_code_override = None
def __str__(self): def __str__(self):
return "%s/%s" % (self.label(), self.hid_id()) return "%s/%s" % (self.label(), self.features.device_id)
def label(self): def label(self):
'''The name given by the user to the device.''' '''The name given by the user to the device.'''
return self.features.label return self.features.label
def hid_id(self):
'''The HID ID of the device.'''
return self.hid_id_
def is_initialized(self): def is_initialized(self):
'''True if initialized, False if wiped.''' '''True if initialized, False if wiped.'''
return self.features.initialized return self.features.initialized
@ -163,7 +158,7 @@ class TrezorClientBase(GuiMixin, PrintError):
def close(self): def close(self):
'''Called when Our wallet was closed or the device removed.''' '''Called when Our wallet was closed or the device removed.'''
self.print_error("disconnected") self.print_error("closing client")
self.clear_session() self.clear_session()
# Release the device # Release the device
self.transport.close() self.transport.close()

121
plugins/trezor/plugin.py

@ -51,23 +51,26 @@ class TrezorCompatibleWallet(BIP44_Wallet):
self.session_timeout = seconds self.session_timeout = seconds
self.storage.put('session_timeout', seconds) self.storage.put('session_timeout', seconds)
def disconnected(self): def unpaired(self):
'''A device paired with the wallet was diconnected. Note this is '''A device paired with the wallet was diconnected. This can be
called in the context of the Plugins thread.''' called in any thread context.'''
self.print_error("disconnected") self.print_error("unpaired")
self.force_watching_only = True self.force_watching_only = True
self.handler.watching_only_changed() self.handler.watching_only_changed()
def connected(self): def paired(self):
'''A device paired with the wallet was (re-)connected. Note this '''A device paired with the wallet was (re-)connected. This can be
is called in the context of the Plugins thread.''' called in any thread context.'''
self.print_error("connected") self.print_error("paired")
self.force_watching_only = False self.force_watching_only = False
self.handler.watching_only_changed() self.handler.watching_only_changed()
def timeout(self): def timeout(self):
'''Informs the wallet it timed out. Note this is called from '''Called when the wallet session times out. Note this is called from
the Plugins thread.''' the Plugins thread.'''
client = self.get_client(force_pair=False)
if client:
client.clear_session()
self.print_error("timed out") self.print_error("timed out")
def get_action(self): def get_action(self):
@ -178,8 +181,7 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
self.wallet_class.plugin = self self.wallet_class.plugin = self
self.prevent_timeout = time.time() + 3600 * 24 * 365 self.prevent_timeout = time.time() + 3600 * 24 * 365
if self.libraries_available: if self.libraries_available:
self.device_manager().register_devices( self.device_manager().register_devices(self.DEVICE_IDS)
self.DEVICE_IDS, self.create_client)
def is_enabled(self): def is_enabled(self):
return self.libraries_available return self.libraries_available
@ -199,13 +201,11 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
if (isinstance(wallet, self.wallet_class) if (isinstance(wallet, self.wallet_class)
and hasattr(wallet, 'last_operation') and hasattr(wallet, 'last_operation')
and now > wallet.last_operation + wallet.session_timeout): and now > wallet.last_operation + wallet.session_timeout):
client = self.get_client(wallet, force_pair=False) wallet.timeout()
if client: wallet.last_operation = self.prevent_timeout
client.clear_session()
wallet.last_operation = self.prevent_timeout
wallet.timeout()
def create_client(self, path, handler, hid_id): def create_client(self, device, handler):
path = device.path
pair = ((None, path) if self.HidTransport._detect_debuglink(path) pair = ((None, path) if self.HidTransport._detect_debuglink(path)
else (path, None)) else (path, None))
try: try:
@ -215,50 +215,48 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
self.print_error("cannot connect at", path, str(e)) self.print_error("cannot connect at", path, str(e))
return None return None
self.print_error("connected to device at", path) self.print_error("connected to device at", path)
return self.client_class(transport, handler, self, hid_id)
def get_client(self, wallet, force_pair=True, check_firmware=True): client = self.client_class(transport, handler, self)
# Try a ping for device sanity
try:
client.ping('t')
except BaseException as e:
self.print_error("ping failed", str(e))
return None
if not client.atleast_version(*self.minimum_firmware):
msg = (_('Outdated %s firmware for device labelled %s. Please '
'download the updated firmware from %s') %
(self.device, client.label(), self.firmware_URL))
handler.show_error(msg)
return None
return client
def get_client(self, wallet, force_pair=True):
# All client interaction should not be in the main GUI thread
assert self.main_thread != threading.current_thread() assert self.main_thread != threading.current_thread()
'''check_firmware is ignored unless force_pair is True.''' devmgr = self.device_manager()
client = self.device_manager().get_client(wallet, force_pair) client = devmgr.client_for_wallet(self, wallet, force_pair)
# Try a ping for device sanity
if client: if client:
self.print_error("set last_operation") self.print_error("set last_operation")
wallet.last_operation = time.time() wallet.last_operation = time.time()
try: elif force_pair:
client.ping('t') msg = (_('Could not connect to your %s. Verify the '
except BaseException as e: 'cable is connected and that no other app is '
self.print_error("ping failed", str(e)) 'using it.\nContinuing in watching-only mode '
# Remove it from the manager's cache 'until the device is re-connected.') % self.device)
self.device_manager().close_client(client) raise DeviceDisconnectedError(msg)
client = None
if force_pair:
assert wallet.handler
if not client:
msg = (_('Could not connect to your %s. Verify the '
'cable is connected and that no other app is '
'using it.\nContinuing in watching-only mode '
'until the device is re-connected.') % self.device)
wallet.handler.show_error(msg)
raise DeviceDisconnectedError(msg)
if (check_firmware and not
client.atleast_version(*self.minimum_firmware)):
msg = (_('Outdated %s firmware for device labelled %s. Please '
'download the updated firmware from %s') %
(self.device, client.label(), self.firmware_URL))
wallet.handler.show_error(msg)
raise OutdatedFirmwareError(msg)
return client return client
@hook @hook
def close_wallet(self, wallet): def close_wallet(self, wallet):
if isinstance(wallet, self.wallet_class): if isinstance(wallet, self.wallet_class):
self.device_manager().close_wallet(wallet) self.device_manager().unpair_wallet(wallet)
def initialize_device(self, wallet): def initialize_device(self, wallet):
# Prevent timeouts during initialization # Prevent timeouts during initialization
@ -310,27 +308,32 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
wallet.thread.add(initialize_device) wallet.thread.add(initialize_device)
def unpaired_clients(self, handler): def unpaired_devices(self, handler):
'''Returns all connected, unpaired devices as a list of clients and a '''Returns all connected, unpaired devices as a list of clients and a
list of descriptions.''' list of descriptions.'''
devmgr = self.device_manager() devmgr = self.device_manager()
clients = devmgr.unpaired_clients(handler, self.client_class) devices = devmgr.unpaired_devices(handler)
states = [_("wiped"), _("initialized")]
def client_desc(client): good_devices, descrs = [], []
label = client.label() or _("An unnamed device") for device in devices:
client = self.device_manager().create_client(device, handler, self)
if not client:
continue
state = states[client.is_initialized()] state = states[client.is_initialized()]
return ("%s: serial number %s (%s)" label = device.info['label'] or _("An unnamed device")
% (label, client.hid_id(), state)) good_devices.append(device)
return clients, list(map(client_desc, clients)) descrs.append("%s: device ID %s (%s)" % (label, device.id_, state))
return good_devices, descrs
def select_device(self, wallet): def select_device(self, wallet):
'''Called when creating a new wallet. Select the device to use. If '''Called when creating a new wallet. Select the device to use. If
the device is uninitialized, go through the intialization the device is uninitialized, go through the intialization
process.''' process.'''
msg = _("Please select which %s device to use:") % self.device msg = _("Please select which %s device to use:") % self.device
clients, labels = self.unpaired_clients(wallet.handler) devices, labels = self.unpaired_devices(wallet.handler)
client = clients[wallet.handler.query_choice(msg, labels)] device = devices[wallet.handler.query_choice(msg, labels)]
self.device_manager().pair_wallet(wallet, client) self.device_manager().pair_wallet(wallet, device.id_)
if not client.is_initialized(): if not client.is_initialized():
self.initialize_device(wallet) self.initialize_device(wallet)

31
plugins/trezor/qt_generic.py

@ -252,25 +252,25 @@ def qt_plugin_class(base_plugin_class):
menu.addAction(_("Show on %s") % self.device, show_address) menu.addAction(_("Show on %s") % self.device, show_address)
def settings_dialog(self, window): def settings_dialog(self, window):
hid_id = self.choose_device(window) device_id = self.choose_device(window)
if hid_id: if device_id:
SettingsDialog(window, self, hid_id).exec_() SettingsDialog(window, self, device_id).exec_()
def choose_device(self, window): def choose_device(self, window):
'''This dialog box should be usable even if the user has '''This dialog box should be usable even if the user has
forgotten their PIN or it is in bootloader mode.''' forgotten their PIN or it is in bootloader mode.'''
handler = window.wallet.handler handler = window.wallet.handler
hid_id = self.device_manager().wallet_hid_id(window.wallet) device_id = self.device_manager().wallet_id(window.wallet)
if not hid_id: if not device_id:
clients, labels = self.unpaired_clients(handler) devices, labels = self.unpaired_devices(handler)
if clients: if devices:
msg = _("Select a %s device:") % self.device msg = _("Select a %s device:") % self.device
choice = self.query_choice(window, msg, labels) choice = self.query_choice(window, msg, labels)
if choice is not None: if choice is not None:
hid_id = clients[choice].hid_id() device_id = devices[choice].id_
else: else:
handler.show_error(_("No devices found")) handler.show_error(_("No devices found"))
return hid_id return device_id
def query_choice(self, window, msg, choices): def query_choice(self, window, msg, choices):
dialog = WindowModalDialog(window) dialog = WindowModalDialog(window)
@ -292,28 +292,29 @@ class SettingsDialog(WindowModalDialog):
We want users to be able to wipe a device even if they've forgotten We want users to be able to wipe a device even if they've forgotten
their PIN.''' their PIN.'''
def __init__(self, window, plugin, hid_id): def __init__(self, window, plugin, device_id):
title = _("%s Settings") % plugin.device title = _("%s Settings") % plugin.device
super(SettingsDialog, self).__init__(window, title) super(SettingsDialog, self).__init__(window, title)
self.setMaximumWidth(540) self.setMaximumWidth(540)
devmgr = plugin.device_manager() devmgr = plugin.device_manager()
handler = window.wallet.handler handler = window.wallet.handler
thread = window.wallet.thread
# wallet can be None, needn't be window.wallet # wallet can be None, needn't be window.wallet
wallet = devmgr.wallet_by_hid_id(hid_id) wallet = devmgr.wallet_by_id(device_id)
hs_rows, hs_cols = (64, 128) hs_rows, hs_cols = (64, 128)
self.current_label=None self.current_label=None
def invoke_client(method, *args, **kw_args): def invoke_client(method, *args, **kw_args):
def task(): def task():
client = plugin.get_client(wallet, False) client = devmgr.client_by_id(device_id, handler)
if not client: if not client:
raise RuntimeError("Device not connected") raise RuntimeError("Device not connected")
if method: if method:
getattr(client, method)(*args, **kw_args) getattr(client, method)(*args, **kw_args)
update(client.features) update(client.features)
wallet.thread.add(task) thread.add(task)
def update(features): def update(features):
self.current_label = features.label self.current_label = features.label
@ -364,7 +365,7 @@ class SettingsDialog(WindowModalDialog):
if not self.question(msg, title=title): if not self.question(msg, title=title):
return return
invoke_client('toggle_passphrase') invoke_client('toggle_passphrase')
devmgr.unpair(hid_id) devmgr.unpair(device_id)
def change_homescreen(): def change_homescreen():
from PIL import Image # FIXME from PIL import Image # FIXME
@ -402,7 +403,7 @@ class SettingsDialog(WindowModalDialog):
icon=QMessageBox.Critical): icon=QMessageBox.Critical):
return return
invoke_client('wipe_device') invoke_client('wipe_device')
devmgr.unpair(hid_id) devmgr.unpair(device_id)
def slider_moved(): def slider_moved():
mins = timeout_slider.sliderPosition() mins = timeout_slider.sliderPosition()

Loading…
Cancel
Save