|
|
@ -29,9 +29,10 @@ import time |
|
|
|
import threading |
|
|
|
import sys |
|
|
|
from typing import (NamedTuple, Any, Union, TYPE_CHECKING, Optional, Tuple, |
|
|
|
Dict, Iterable, List, Sequence) |
|
|
|
Dict, Iterable, List, Sequence, Callable, TypeVar) |
|
|
|
import concurrent |
|
|
|
from concurrent import futures |
|
|
|
from functools import wraps, partial |
|
|
|
|
|
|
|
from .i18n import _ |
|
|
|
from .util import (profiler, DaemonThread, UserCancelled, ThreadJob, UserFacingException) |
|
|
@ -334,11 +335,37 @@ PLACEHOLDER_HW_CLIENT_LABELS = {None, "", " "} |
|
|
|
# https://github.com/signal11/hidapi/pull/414#issuecomment-445164238 |
|
|
|
# It is not entirely clear to me, exactly what is safe and what isn't, when |
|
|
|
# using multiple threads... |
|
|
|
# For now, we use a dedicated thread to enumerate devices (_hid_executor), |
|
|
|
# and we synchronize all device opens/closes/enumeration (_hid_lock). |
|
|
|
# FIXME there are still probably threading issues with how we use hidapi... |
|
|
|
_hid_executor = None # type: Optional[concurrent.futures.Executor] |
|
|
|
_hid_lock = threading.Lock() |
|
|
|
# Hence, we use a single thread for all device communications, including |
|
|
|
# enumeration. Everything that uses hidapi, libusb, etc, MUST run on |
|
|
|
# the following thread: |
|
|
|
_hwd_comms_executor = concurrent.futures.ThreadPoolExecutor( |
|
|
|
max_workers=1, |
|
|
|
thread_name_prefix='hwd_comms_thread' |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
|
|
|
|
|
|
|
|
def run_in_hwd_thread(func: Callable[[], T]) -> T: |
|
|
|
if threading.current_thread().name.startswith("hwd_comms_thread"): |
|
|
|
return func() |
|
|
|
else: |
|
|
|
fut = _hwd_comms_executor.submit(func) |
|
|
|
return fut.result() |
|
|
|
#except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e: |
|
|
|
|
|
|
|
|
|
|
|
def runs_in_hwd_thread(func): |
|
|
|
@wraps(func) |
|
|
|
def wrapper(*args, **kwargs): |
|
|
|
return run_in_hwd_thread(partial(func, *args, **kwargs)) |
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
def assert_runs_in_hwd_thread(): |
|
|
|
if not threading.current_thread().name.startswith("hwd_comms_thread"): |
|
|
|
raise Exception("must only be called from HWD communication thread") |
|
|
|
|
|
|
|
|
|
|
|
class DeviceMgr(ThreadJob): |
|
|
@ -384,24 +411,11 @@ class DeviceMgr(ThreadJob): |
|
|
|
self._recognised_hardware = {} # type: Dict[Tuple[int, int], HW_PluginBase] |
|
|
|
# Custom enumerate functions for devices we don't know about. |
|
|
|
self._enumerate_func = set() # Needs self.lock. |
|
|
|
# locks: if you need to take multiple ones, acquire them in the order they are defined here! |
|
|
|
self._scan_lock = threading.RLock() |
|
|
|
|
|
|
|
self.lock = threading.RLock() |
|
|
|
self.hid_lock = _hid_lock |
|
|
|
|
|
|
|
self.config = config |
|
|
|
|
|
|
|
global _hid_executor |
|
|
|
if _hid_executor is None: |
|
|
|
_hid_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1, |
|
|
|
thread_name_prefix='hid_enumerate_thread') |
|
|
|
|
|
|
|
def with_scan_lock(func): |
|
|
|
def func_wrapper(self: 'DeviceMgr', *args, **kwargs): |
|
|
|
with self._scan_lock: |
|
|
|
return func(self, *args, **kwargs) |
|
|
|
return func_wrapper |
|
|
|
|
|
|
|
def thread_jobs(self): |
|
|
|
# Thread job to handle device timeouts |
|
|
|
return [self] |
|
|
@ -423,6 +437,7 @@ class DeviceMgr(ThreadJob): |
|
|
|
with self.lock: |
|
|
|
self._enumerate_func.add(func) |
|
|
|
|
|
|
|
@runs_in_hwd_thread |
|
|
|
def create_client(self, device: 'Device', handler: Optional['HardwareHandlerBase'], |
|
|
|
plugin: 'HW_PluginBase') -> Optional['HardwareClientBase']: |
|
|
|
# Get from cache first |
|
|
@ -462,6 +477,7 @@ class DeviceMgr(ThreadJob): |
|
|
|
self._close_client(id_) |
|
|
|
|
|
|
|
def _close_client(self, id_): |
|
|
|
with self.lock: |
|
|
|
client = self._client_by_id(id_) |
|
|
|
self.clients.pop(client, None) |
|
|
|
if client: |
|
|
@ -486,7 +502,7 @@ class DeviceMgr(ThreadJob): |
|
|
|
self.scan_devices() |
|
|
|
return self._client_by_id(id_) |
|
|
|
|
|
|
|
@with_scan_lock |
|
|
|
@runs_in_hwd_thread |
|
|
|
def client_for_keystore(self, plugin: 'HW_PluginBase', handler: Optional['HardwareHandlerBase'], |
|
|
|
keystore: 'Hardware_KeyStore', |
|
|
|
force_pair: bool, *, |
|
|
@ -655,25 +671,15 @@ class DeviceMgr(ThreadJob): |
|
|
|
# note: updated label/soft_device_id will be saved after pairing succeeds |
|
|
|
return info |
|
|
|
|
|
|
|
@with_scan_lock |
|
|
|
@runs_in_hwd_thread |
|
|
|
def _scan_devices_with_hid(self) -> List['Device']: |
|
|
|
try: |
|
|
|
import hid |
|
|
|
except ImportError: |
|
|
|
return [] |
|
|
|
|
|
|
|
def hid_enumerate(): |
|
|
|
with self.hid_lock: |
|
|
|
return hid.enumerate(0, 0) |
|
|
|
|
|
|
|
hid_list_fut = _hid_executor.submit(hid_enumerate) |
|
|
|
try: |
|
|
|
hid_list = hid_list_fut.result() |
|
|
|
except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e: |
|
|
|
return [] |
|
|
|
|
|
|
|
devices = [] |
|
|
|
for d in hid_list: |
|
|
|
for d in hid.enumerate(0, 0): |
|
|
|
product_key = (d['vendor_id'], d['product_id']) |
|
|
|
if product_key in self._recognised_hardware: |
|
|
|
plugin = self._recognised_hardware[product_key] |
|
|
@ -681,7 +687,7 @@ class DeviceMgr(ThreadJob): |
|
|
|
devices.append(device) |
|
|
|
return devices |
|
|
|
|
|
|
|
@with_scan_lock |
|
|
|
@runs_in_hwd_thread |
|
|
|
@profiler |
|
|
|
def scan_devices(self) -> Sequence['Device']: |
|
|
|
self.logger.info("scanning devices...") |
|
|
@ -693,10 +699,8 @@ class DeviceMgr(ThreadJob): |
|
|
|
with self.lock: |
|
|
|
enumerate_funcs = list(self._enumerate_func) |
|
|
|
for f in enumerate_funcs: |
|
|
|
# custom enumerate functions might use hidapi, so use hid thread to be safe |
|
|
|
new_devices_fut = _hid_executor.submit(f) |
|
|
|
try: |
|
|
|
new_devices = new_devices_fut.result() |
|
|
|
new_devices = f() |
|
|
|
except BaseException as e: |
|
|
|
self.logger.error('custom device enum failed. func {}, error {}' |
|
|
|
.format(str(f), repr(e))) |
|
|
|