Browse Source

wallet: change _requests_addr_to_key map to multi-map

I find this easier to reason about than occasionally overwriting the items.
get_request_by_addr still only returns a single invoice for simplicity,
but now all logic regarding how to handle collisions is inside that method.
patch-4
SomberNight 2 years ago
parent
commit
30f3d27baa
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 41
      electrum/wallet.py

41
electrum/wallet.py

@ -1058,14 +1058,10 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def _init_requests_rhash_index(self):
# self._requests_addr_to_key may contain addresses that can be reused
# this is checked in get_request_by_address
self._requests_addr_to_key = {}
self._requests_addr_to_key = defaultdict(set) # type: Dict[str, Set[str]]
for req in self._receive_requests.values():
if req.is_lightning() and (addr:=req.get_address()):
# give priority to not-yet-expired requests, to postpone reusing the address
# FIXME maybe self._receive_requests should be a multi-map instead
if req.has_expired() and addr in self._requests_addr_to_key:
continue
self._requests_addr_to_key[addr] = req.get_id()
if addr := req.get_address():
self._requests_addr_to_key[addr].add(req.get_id())
def _prepare_onchain_invoice_paid_detection(self):
self._invoices_from_txid_map = defaultdict(set) # type: Dict[str, Set[str]]
@ -2367,18 +2363,27 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return self.check_expired_status(invoice, status)
def get_request_by_addr(self, addr: str) -> Optional[Invoice]:
"""
Called in get_label_for_address and update_invoices_and_req_touched_by_tx
"""Returns a relevant request for address, from an on-chain PoV.
(One that has been paid on-chain or is pending)
Called in get_label_for_address and update_invoices_and_reqs_touched_by_tx
Returns None if the address can be reused (i.e. was paid by lightning or has expired)
"""
key = self._requests_addr_to_key.get(addr)
req = self._receive_requests.get(key)
if req is None:
keys = self._requests_addr_to_key.get(addr) or []
reqs = [self._receive_requests.get(key) for key in keys]
reqs = [req for req in reqs if req] # filter None
if not reqs:
return
status = self.get_invoice_status(req)
if (status == PR_PAID and not self.adb.is_used(addr)) or (status == PR_EXPIRED):
# filter out expired
reqs = [req for req in reqs if self.get_invoice_status(req) != PR_EXPIRED]
# filter out paid-with-lightning
if self.lnworker:
reqs = [req for req in reqs
if not req.is_lightning() or self.lnworker.get_invoice_status(req) == PR_UNPAID]
if not reqs:
return None
return req
# note: there typically should not be more than one relevant request for an address
return reqs[0]
def get_request(self, request_id: str) -> Optional[Invoice]:
return self._receive_requests.get(request_id)
@ -2510,8 +2515,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
request_id = req.get_id()
self._receive_requests[request_id] = req
if addr:=req.get_address():
# may overwrite expired or ln-paid request
self._requests_addr_to_key[addr] = request_id
self._requests_addr_to_key[addr].add(request_id)
if write_to_disk:
self.save_db()
return request_id
@ -2523,8 +2527,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return
self._receive_requests.pop(request_id, None)
if addr:=req.get_address():
if self._requests_addr_to_key.get(addr) == request_id:
self._requests_addr_to_key.pop(addr)
self._requests_addr_to_key[addr].discard(request_id)
if req.is_lightning() and self.lnworker:
self.lnworker.delete_payment_info(req.rhash)
if write_to_disk:

Loading…
Cancel
Save