Browse Source

Merge branch '202211_wallet_payreq'

better handle collisions in _requests_addr_to_key
patch-4
SomberNight 2 years ago
parent
commit
28f724edc9
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 37
      electrum/wallet.py

37
electrum/wallet.py

@ -1058,10 +1058,10 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def _init_requests_rhash_index(self): def _init_requests_rhash_index(self):
# self._requests_addr_to_key may contain addresses that can be reused # self._requests_addr_to_key may contain addresses that can be reused
# this is checked in get_request_by_address # this is checked in get_request_by_address
self._requests_addr_to_key = {} self._requests_addr_to_key = defaultdict(set) # type: Dict[str, Set[str]]
for req in self._receive_requests.values(): for req in self._receive_requests.values():
if req.is_lightning() and (addr:=req.get_address()): if addr := req.get_address():
self._requests_addr_to_key[addr] = req.get_id() self._requests_addr_to_key[addr].add(req.get_id())
def _prepare_onchain_invoice_paid_detection(self): def _prepare_onchain_invoice_paid_detection(self):
self._invoices_from_txid_map = defaultdict(set) # type: Dict[str, Set[str]] self._invoices_from_txid_map = defaultdict(set) # type: Dict[str, Set[str]]
@ -2363,18 +2363,27 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return self.check_expired_status(invoice, status) return self.check_expired_status(invoice, status)
def get_request_by_addr(self, addr: str) -> Optional[Invoice]: def get_request_by_addr(self, addr: str) -> Optional[Invoice]:
""" """Returns a relevant request for address, from an on-chain PoV.
Called in get_label_for_address and update_invoices_and_req_touched_by_tx (One that has been paid on-chain or is pending)
Called in get_label_for_address and update_invoices_and_reqs_touched_by_tx
Returns None if the address can be reused (i.e. was paid by lightning or has expired) Returns None if the address can be reused (i.e. was paid by lightning or has expired)
""" """
key = self._requests_addr_to_key.get(addr) keys = self._requests_addr_to_key.get(addr) or []
req = self._receive_requests.get(key) reqs = [self._receive_requests.get(key) for key in keys]
if req is None: reqs = [req for req in reqs if req] # filter None
if not reqs:
return return
status = self.get_invoice_status(req) # filter out expired
if (status == PR_PAID and not self.adb.is_used(addr)) or (status == PR_EXPIRED): reqs = [req for req in reqs if self.get_invoice_status(req) != PR_EXPIRED]
# filter out paid-with-lightning
if self.lnworker:
reqs = [req for req in reqs
if not req.is_lightning() or self.lnworker.get_invoice_status(req) == PR_UNPAID]
if not reqs:
return None return None
return req # note: there typically should not be more than one relevant request for an address
return reqs[0]
def get_request(self, request_id: str) -> Optional[Invoice]: def get_request(self, request_id: str) -> Optional[Invoice]:
return self._receive_requests.get(request_id) return self._receive_requests.get(request_id)
@ -2506,8 +2515,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
request_id = req.get_id() request_id = req.get_id()
self._receive_requests[request_id] = req self._receive_requests[request_id] = req
if addr:=req.get_address(): if addr:=req.get_address():
# may overwrite expired or ln-paid request self._requests_addr_to_key[addr].add(request_id)
self._requests_addr_to_key[addr] = request_id
if write_to_disk: if write_to_disk:
self.save_db() self.save_db()
return request_id return request_id
@ -2519,8 +2527,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return return
self._receive_requests.pop(request_id, None) self._receive_requests.pop(request_id, None)
if addr:=req.get_address(): if addr:=req.get_address():
if self._requests_addr_to_key.get(addr) == request_id: self._requests_addr_to_key[addr].discard(request_id)
self._requests_addr_to_key.pop(addr)
if req.is_lightning() and self.lnworker: if req.is_lightning() and self.lnworker:
self.lnworker.delete_payment_info(req.rhash) self.lnworker.delete_payment_info(req.rhash)
if write_to_disk: if write_to_disk:

Loading…
Cancel
Save