Browse Source

wallet: change _requests_addr_to_key map to multi-map

I find this easier to reason about than occasionally overwriting the items.
get_request_by_addr still only returns a single invoice for simplicity,
but now all logic regarding how to handle collisions is inside that method.
patch-4
SomberNight 2 years ago
parent
commit
30f3d27baa
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 41
      electrum/wallet.py

41
electrum/wallet.py

@ -1058,14 +1058,10 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def _init_requests_rhash_index(self): def _init_requests_rhash_index(self):
# self._requests_addr_to_key may contain addresses that can be reused # self._requests_addr_to_key may contain addresses that can be reused
# this is checked in get_request_by_address # this is checked in get_request_by_address
self._requests_addr_to_key = {} self._requests_addr_to_key = defaultdict(set) # type: Dict[str, Set[str]]
for req in self._receive_requests.values(): for req in self._receive_requests.values():
if req.is_lightning() and (addr:=req.get_address()): if addr := req.get_address():
# give priority to not-yet-expired requests, to postpone reusing the address self._requests_addr_to_key[addr].add(req.get_id())
# FIXME maybe self._receive_requests should be a multi-map instead
if req.has_expired() and addr in self._requests_addr_to_key:
continue
self._requests_addr_to_key[addr] = req.get_id()
def _prepare_onchain_invoice_paid_detection(self): def _prepare_onchain_invoice_paid_detection(self):
self._invoices_from_txid_map = defaultdict(set) # type: Dict[str, Set[str]] self._invoices_from_txid_map = defaultdict(set) # type: Dict[str, Set[str]]
@ -2367,18 +2363,27 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return self.check_expired_status(invoice, status) return self.check_expired_status(invoice, status)
def get_request_by_addr(self, addr: str) -> Optional[Invoice]: def get_request_by_addr(self, addr: str) -> Optional[Invoice]:
""" """Returns a relevant request for address, from an on-chain PoV.
Called in get_label_for_address and update_invoices_and_req_touched_by_tx (One that has been paid on-chain or is pending)
Called in get_label_for_address and update_invoices_and_reqs_touched_by_tx
Returns None if the address can be reused (i.e. was paid by lightning or has expired) Returns None if the address can be reused (i.e. was paid by lightning or has expired)
""" """
key = self._requests_addr_to_key.get(addr) keys = self._requests_addr_to_key.get(addr) or []
req = self._receive_requests.get(key) reqs = [self._receive_requests.get(key) for key in keys]
if req is None: reqs = [req for req in reqs if req] # filter None
if not reqs:
return return
status = self.get_invoice_status(req) # filter out expired
if (status == PR_PAID and not self.adb.is_used(addr)) or (status == PR_EXPIRED): reqs = [req for req in reqs if self.get_invoice_status(req) != PR_EXPIRED]
# filter out paid-with-lightning
if self.lnworker:
reqs = [req for req in reqs
if not req.is_lightning() or self.lnworker.get_invoice_status(req) == PR_UNPAID]
if not reqs:
return None return None
return req # note: there typically should not be more than one relevant request for an address
return reqs[0]
def get_request(self, request_id: str) -> Optional[Invoice]: def get_request(self, request_id: str) -> Optional[Invoice]:
return self._receive_requests.get(request_id) return self._receive_requests.get(request_id)
@ -2510,8 +2515,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
request_id = req.get_id() request_id = req.get_id()
self._receive_requests[request_id] = req self._receive_requests[request_id] = req
if addr:=req.get_address(): if addr:=req.get_address():
# may overwrite expired or ln-paid request self._requests_addr_to_key[addr].add(request_id)
self._requests_addr_to_key[addr] = request_id
if write_to_disk: if write_to_disk:
self.save_db() self.save_db()
return request_id return request_id
@ -2523,8 +2527,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return return
self._receive_requests.pop(request_id, None) self._receive_requests.pop(request_id, None)
if addr:=req.get_address(): if addr:=req.get_address():
if self._requests_addr_to_key.get(addr) == request_id: self._requests_addr_to_key[addr].discard(request_id)
self._requests_addr_to_key.pop(addr)
if req.is_lightning() and self.lnworker: if req.is_lightning() and self.lnworker:
self.lnworker.delete_payment_info(req.rhash) self.lnworker.delete_payment_info(req.rhash)
if write_to_disk: if write_to_disk:

Loading…
Cancel
Save