diff --git a/electrum/commands.py b/electrum/commands.py index 139efd55f..526611513 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -980,9 +980,9 @@ class Commands: wallet.sign_payment_request(address, alias, alias_addr, password) @command('w') - async def rmrequest(self, address, wallet: Abstract_Wallet = None): + async def delete_request(self, address, wallet: Abstract_Wallet = None): """Remove a payment request""" - return wallet.remove_payment_request(address) + return wallet.delete_request(address) @command('w') async def clear_requests(self, wallet: Abstract_Wallet = None): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 458154f10..b5e48672b 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -617,7 +617,7 @@ class LNWallet(LNWorker): self.config = wallet.config self.lnwatcher = None self.lnrater: LNRater = None - self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid + self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage # note: this sweep_address is only used as fallback; as it might result in address-reuse self.sweep_address = wallet.get_new_sweep_address_for_channel() @@ -818,7 +818,7 @@ class LNWallet(LNWorker): info = self.get_payment_info(payment_hash) amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist) if info is not None: - label = self.wallet.get_label(key) + label = self.wallet.get_label_for_rhash(key) direction = ('sent' if info.direction == SENT else 'received') if len(plist)==1 else 'self-payment' else: direction = 'forwarding' @@ -1851,15 +1851,15 @@ class LNWallet(LNWorker): """returns None if payment_hash is a payment we are forwarding""" key = payment_hash.hex() with self.lock: - if key in self.payments: - amount_msat, direction, status = self.payments[key] + if key in self.payment_info: + amount_msat, direction, status = self.payment_info[key] return PaymentInfo(payment_hash, amount_msat, direction, status) def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: key = info.payment_hash.hex() assert info.status in SAVED_PR_STATUS with self.lock: - self.payments[key] = info.amount_msat, info.direction, info.status + self.payment_info[key] = info.amount_msat, info.direction, info.status if write_to_disk: self.wallet.save_db() @@ -1916,11 +1916,12 @@ class LNWallet(LNWorker): util.trigger_callback('invoice_status', self.wallet, key) def set_request_status(self, payment_hash: bytes, status: int) -> None: - if self.get_payment_status(payment_hash) != status: - self.set_payment_status(payment_hash, status) - for key, req in self.wallet.receive_requests.items(): - if req.is_lightning() and req.rhash == payment_hash.hex(): - util.trigger_callback('request_status', self.wallet, key, status) + if self.get_payment_status(payment_hash) == status: + return + self.set_payment_status(payment_hash, status) + req = self.wallet.get_request_by_rhash(payment_hash.hex()) + key = self.wallet.get_key_for_receive_request(req) + util.trigger_callback('request_status', self.wallet, key, status) def set_payment_status(self, payment_hash: bytes, status: int) -> None: info = self.get_payment_info(payment_hash) @@ -2060,13 +2061,14 @@ class LNWallet(LNWorker): trampoline_hints.append(('t', (node_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta))) return routing_hints, trampoline_hints - def delete_payment(self, payment_hash_hex: str): - try: - with self.lock: - del self.payments[payment_hash_hex] - except KeyError: - return - self.wallet.save_db() + def delete_payment_info(self, payment_hash_hex: str): + # This method is called when an invoice or request is deleted by the user. + # The GUI only lets the user delete invoices or requests that have not been paid. + # Once an invoice/request has been paid, it is part of the history, + # and get_lightning_history assumes that payment_info is there. + assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID + with self.lock: + self.payment_info.pop(payment_hash_hex, None) def get_balance(self, frozen=False): with self.lock: diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 398ad6f96..5e1c83940 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -105,6 +105,12 @@ class MockWallet: receive_requests = {} adb = MockADB() + def get_request_by_rhash(self, rhash): + pass + + def get_key_for_receive_request(self, x): + pass + def set_label(self, x, y): pass @@ -134,7 +140,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): self.lnwatcher = None self.listen_server = None self._channels = {chan.channel_id: chan for chan in chans} - self.payments = {} + self.payment_info = {} self.logs = defaultdict(list) self.wallet = MockWallet() self.features = LnFeatures(0) diff --git a/electrum/wallet.py b/electrum/wallet.py index 1f7650b0d..e038ef7de 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -316,6 +316,7 @@ class Abstract_Wallet(ABC, Logger): self._freeze_lock = threading.RLock() # for mutating/iterating frozen_{addresses,coins} + self._init_requests_rhash_index() self._prepare_onchain_invoice_paid_detection() self.calc_unused_change_addresses() # save wallet type the first time @@ -967,6 +968,7 @@ class Abstract_Wallet(ABC, Logger): def clear_requests(self): self.receive_requests.clear() + self._requests_rhash_to_key.clear() self.save_db() def get_invoices(self): @@ -1017,6 +1019,12 @@ class Abstract_Wallet(ABC, Logger): assert isinstance(inv, Invoice), f"unexpected type {type(inv)}" return invoices + def _init_requests_rhash_index(self): + self._requests_rhash_to_key = {} + for key, req in self.receive_requests.items(): + if req.is_lightning(): + self._requests_rhash_to_key[req.rhash] = key + def _prepare_onchain_invoice_paid_detection(self): # scriptpubkey -> list(invoice_keys) self._invoices_from_scriptpubkey_map = defaultdict(set) # type: Dict[bytes, Set[str]] @@ -1326,6 +1334,13 @@ class Abstract_Wallet(ABC, Logger): return ', '.join(labels) return '' + def _get_default_label_for_rhash(self, rhash: str) -> str: + req = self.get_request_by_rhash(rhash) + return req.message if req else '' + + def get_label_for_rhash(self, rhash: str) -> str: + return self._labels.get(rhash) or self._get_default_label_for_rhash(rhash) + def get_all_labels(self) -> Dict[str, str]: with self.lock: return copy.copy(self._labels) @@ -2327,10 +2342,14 @@ class Abstract_Wallet(ABC, Logger): return self.check_expired_status(r, status) def get_request(self, key): - return self.receive_requests.get(key) + return self.receive_requests.get(key) or self.get_request_by_rhash(key) + + def get_request_by_rhash(self, rhash): + key = self._requests_rhash_to_key.get(rhash) + return self.receive_requests.get(key) if key else None def get_formatted_request(self, key): - x = self.receive_requests.get(key) + x = self.get_request.get(key) if x: return self.export_request(x) @@ -2473,32 +2492,31 @@ class Abstract_Wallet(ABC, Logger): key = self.get_key_for_receive_request(req, sanity_checks=True) message = req.message self.receive_requests[key] = req - self.set_label(key, message) # should be a default label + if req.is_lightning(): + self._requests_rhash_to_key[req.rhash] = key if write_to_disk: self.save_db() return key def delete_request(self, key): """ lightning or on-chain """ - if key in self.receive_requests: - self.remove_payment_request(key) - elif self.lnworker: - self.lnworker.delete_payment(key) + req = self.receive_requests.pop(key, None) + if req is None: + return + if req.is_lightning(): + self._requests_rhash_to_key.pop(req.rhash) + if req.is_lightning() and self.lnworker: + self.lnworker.delete_payment_info(req.rhash) + self.save_db() def delete_invoice(self, key): """ lightning or on-chain """ - if key in self.invoices: - self.invoices.pop(key) - elif self.lnworker: - self.lnworker.delete_payment(key) - - def remove_payment_request(self, addr) -> bool: - found = False - if addr in self.receive_requests: - found = True - self.receive_requests.pop(addr) - self.save_db() - return found + inv = self.invoices.pop(key, None) + if inv is None: + return + if inv.is_lightning() and self.lnworker: + self.lnworker.delete_payment_info(inv.rhash) + self.save_db() def get_sorted_requests(self) -> List[Invoice]: """ sorted by timestamp """ @@ -2914,7 +2932,7 @@ class Imported_Wallet(Simple_Wallet): for tx_hash in transactions_to_remove: self.adb._remove_transaction(tx_hash) self.set_label(address, None) - self.remove_payment_request(address) + self.delete_request(address) self.set_frozen_state_of_addresses([address], False) pubkey = self.get_public_key(address) self.db.remove_imported_address(address)