From 1a8cc68f53cf59b91fcec09d375092730bd6ce72 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Fri, 18 Nov 2022 16:59:47 +0000 Subject: [PATCH 1/2] wallet: _requests_addr_to_key map to prefer unexpired reqs if collision --- electrum/wallet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/electrum/wallet.py b/electrum/wallet.py index 9a8cbe2fa..d5fd202da 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -1061,6 +1061,10 @@ class Abstract_Wallet(ABC, Logger, EventListener): self._requests_addr_to_key = {} for req in self._receive_requests.values(): if req.is_lightning() and (addr:=req.get_address()): + # give priority to not-yet-expired requests, to postpone reusing the address + # FIXME maybe self._receive_requests should be a multi-map instead + if req.has_expired() and addr in self._requests_addr_to_key: + continue self._requests_addr_to_key[addr] = req.get_id() def _prepare_onchain_invoice_paid_detection(self): From 30f3d27baa0c783d3439571766aa3cd3dd4a1089 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Fri, 18 Nov 2022 18:03:09 +0000 Subject: [PATCH 2/2] wallet: change _requests_addr_to_key map to multi-map I find this easier to reason about than occasionally overwriting the items. get_request_by_addr still only returns a single invoice for simplicity, but now all logic regarding how to handle collisions is inside that method. --- electrum/wallet.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/electrum/wallet.py b/electrum/wallet.py index d5fd202da..0afc3f5ae 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -1058,14 +1058,10 @@ class Abstract_Wallet(ABC, Logger, EventListener): def _init_requests_rhash_index(self): # self._requests_addr_to_key may contain addresses that can be reused # this is checked in get_request_by_address - self._requests_addr_to_key = {} + self._requests_addr_to_key = defaultdict(set) # type: Dict[str, Set[str]] for req in self._receive_requests.values(): - if req.is_lightning() and (addr:=req.get_address()): - # give priority to not-yet-expired requests, to postpone reusing the address - # FIXME maybe self._receive_requests should be a multi-map instead - if req.has_expired() and addr in self._requests_addr_to_key: - continue - self._requests_addr_to_key[addr] = req.get_id() + if addr := req.get_address(): + self._requests_addr_to_key[addr].add(req.get_id()) def _prepare_onchain_invoice_paid_detection(self): self._invoices_from_txid_map = defaultdict(set) # type: Dict[str, Set[str]] @@ -2367,18 +2363,27 @@ class Abstract_Wallet(ABC, Logger, EventListener): return self.check_expired_status(invoice, status) def get_request_by_addr(self, addr: str) -> Optional[Invoice]: - """ - Called in get_label_for_address and update_invoices_and_req_touched_by_tx + """Returns a relevant request for address, from an on-chain PoV. + (One that has been paid on-chain or is pending) + + Called in get_label_for_address and update_invoices_and_reqs_touched_by_tx Returns None if the address can be reused (i.e. was paid by lightning or has expired) """ - key = self._requests_addr_to_key.get(addr) - req = self._receive_requests.get(key) - if req is None: + keys = self._requests_addr_to_key.get(addr) or [] + reqs = [self._receive_requests.get(key) for key in keys] + reqs = [req for req in reqs if req] # filter None + if not reqs: return - status = self.get_invoice_status(req) - if (status == PR_PAID and not self.adb.is_used(addr)) or (status == PR_EXPIRED): + # filter out expired + reqs = [req for req in reqs if self.get_invoice_status(req) != PR_EXPIRED] + # filter out paid-with-lightning + if self.lnworker: + reqs = [req for req in reqs + if not req.is_lightning() or self.lnworker.get_invoice_status(req) == PR_UNPAID] + if not reqs: return None - return req + # note: there typically should not be more than one relevant request for an address + return reqs[0] def get_request(self, request_id: str) -> Optional[Invoice]: return self._receive_requests.get(request_id) @@ -2510,8 +2515,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): request_id = req.get_id() self._receive_requests[request_id] = req if addr:=req.get_address(): - # may overwrite expired or ln-paid request - self._requests_addr_to_key[addr] = request_id + self._requests_addr_to_key[addr].add(request_id) if write_to_disk: self.save_db() return request_id @@ -2523,8 +2527,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): return self._receive_requests.pop(request_id, None) if addr:=req.get_address(): - if self._requests_addr_to_key.get(addr) == request_id: - self._requests_addr_to_key.pop(addr) + self._requests_addr_to_key[addr].discard(request_id) if req.is_lightning() and self.lnworker: self.lnworker.delete_payment_info(req.rhash) if write_to_disk: