Browse Source

wallet:

- add new index: requests_rhash_to_key (fixes #7845)
 - when creating a request, do not save its description in labels.
   Instead, return it as default value in wallet.get_label_by_rhash
lnworker:
  - rename 'payments' to 'payment_info'
  - add note to delete_payment_info
commands: rename 'rmrequest' to 'delete_request'
patch-4
ThomasV 3 years ago
parent
commit
a3faf85e3c
  1. 4
      electrum/commands.py
  2. 36
      electrum/lnworker.py
  3. 8
      electrum/tests/test_lnpeer.py
  4. 58
      electrum/wallet.py

4
electrum/commands.py

@ -980,9 +980,9 @@ class Commands:
wallet.sign_payment_request(address, alias, alias_addr, password)
@command('w')
async def rmrequest(self, address, wallet: Abstract_Wallet = None):
async def delete_request(self, address, wallet: Abstract_Wallet = None):
"""Remove a payment request"""
return wallet.remove_payment_request(address)
return wallet.delete_request(address)
@command('w')
async def clear_requests(self, wallet: Abstract_Wallet = None):

36
electrum/lnworker.py

@ -617,7 +617,7 @@ class LNWallet(LNWorker):
self.config = wallet.config
self.lnwatcher = None
self.lnrater: LNRater = None
self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
# note: this sweep_address is only used as fallback; as it might result in address-reuse
self.sweep_address = wallet.get_new_sweep_address_for_channel()
@ -818,7 +818,7 @@ class LNWallet(LNWorker):
info = self.get_payment_info(payment_hash)
amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist)
if info is not None:
label = self.wallet.get_label(key)
label = self.wallet.get_label_for_rhash(key)
direction = ('sent' if info.direction == SENT else 'received') if len(plist)==1 else 'self-payment'
else:
direction = 'forwarding'
@ -1851,15 +1851,15 @@ class LNWallet(LNWorker):
"""returns None if payment_hash is a payment we are forwarding"""
key = payment_hash.hex()
with self.lock:
if key in self.payments:
amount_msat, direction, status = self.payments[key]
if key in self.payment_info:
amount_msat, direction, status = self.payment_info[key]
return PaymentInfo(payment_hash, amount_msat, direction, status)
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
key = info.payment_hash.hex()
assert info.status in SAVED_PR_STATUS
with self.lock:
self.payments[key] = info.amount_msat, info.direction, info.status
self.payment_info[key] = info.amount_msat, info.direction, info.status
if write_to_disk:
self.wallet.save_db()
@ -1916,11 +1916,12 @@ class LNWallet(LNWorker):
util.trigger_callback('invoice_status', self.wallet, key)
def set_request_status(self, payment_hash: bytes, status: int) -> None:
if self.get_payment_status(payment_hash) != status:
self.set_payment_status(payment_hash, status)
for key, req in self.wallet.receive_requests.items():
if req.is_lightning() and req.rhash == payment_hash.hex():
util.trigger_callback('request_status', self.wallet, key, status)
if self.get_payment_status(payment_hash) == status:
return
self.set_payment_status(payment_hash, status)
req = self.wallet.get_request_by_rhash(payment_hash.hex())
key = self.wallet.get_key_for_receive_request(req)
util.trigger_callback('request_status', self.wallet, key, status)
def set_payment_status(self, payment_hash: bytes, status: int) -> None:
info = self.get_payment_info(payment_hash)
@ -2060,13 +2061,14 @@ class LNWallet(LNWorker):
trampoline_hints.append(('t', (node_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta)))
return routing_hints, trampoline_hints
def delete_payment(self, payment_hash_hex: str):
try:
with self.lock:
del self.payments[payment_hash_hex]
except KeyError:
return
self.wallet.save_db()
def delete_payment_info(self, payment_hash_hex: str):
# This method is called when an invoice or request is deleted by the user.
# The GUI only lets the user delete invoices or requests that have not been paid.
# Once an invoice/request has been paid, it is part of the history,
# and get_lightning_history assumes that payment_info is there.
assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID
with self.lock:
self.payment_info.pop(payment_hash_hex, None)
def get_balance(self, frozen=False):
with self.lock:

8
electrum/tests/test_lnpeer.py

@ -105,6 +105,12 @@ class MockWallet:
receive_requests = {}
adb = MockADB()
def get_request_by_rhash(self, rhash):
pass
def get_key_for_receive_request(self, x):
pass
def set_label(self, x, y):
pass
@ -134,7 +140,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.lnwatcher = None
self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans}
self.payments = {}
self.payment_info = {}
self.logs = defaultdict(list)
self.wallet = MockWallet()
self.features = LnFeatures(0)

58
electrum/wallet.py

@ -316,6 +316,7 @@ class Abstract_Wallet(ABC, Logger):
self._freeze_lock = threading.RLock() # for mutating/iterating frozen_{addresses,coins}
self._init_requests_rhash_index()
self._prepare_onchain_invoice_paid_detection()
self.calc_unused_change_addresses()
# save wallet type the first time
@ -967,6 +968,7 @@ class Abstract_Wallet(ABC, Logger):
def clear_requests(self):
self.receive_requests.clear()
self._requests_rhash_to_key.clear()
self.save_db()
def get_invoices(self):
@ -1017,6 +1019,12 @@ class Abstract_Wallet(ABC, Logger):
assert isinstance(inv, Invoice), f"unexpected type {type(inv)}"
return invoices
def _init_requests_rhash_index(self):
self._requests_rhash_to_key = {}
for key, req in self.receive_requests.items():
if req.is_lightning():
self._requests_rhash_to_key[req.rhash] = key
def _prepare_onchain_invoice_paid_detection(self):
# scriptpubkey -> list(invoice_keys)
self._invoices_from_scriptpubkey_map = defaultdict(set) # type: Dict[bytes, Set[str]]
@ -1326,6 +1334,13 @@ class Abstract_Wallet(ABC, Logger):
return ', '.join(labels)
return ''
def _get_default_label_for_rhash(self, rhash: str) -> str:
req = self.get_request_by_rhash(rhash)
return req.message if req else ''
def get_label_for_rhash(self, rhash: str) -> str:
return self._labels.get(rhash) or self._get_default_label_for_rhash(rhash)
def get_all_labels(self) -> Dict[str, str]:
with self.lock:
return copy.copy(self._labels)
@ -2327,10 +2342,14 @@ class Abstract_Wallet(ABC, Logger):
return self.check_expired_status(r, status)
def get_request(self, key):
return self.receive_requests.get(key)
return self.receive_requests.get(key) or self.get_request_by_rhash(key)
def get_request_by_rhash(self, rhash):
key = self._requests_rhash_to_key.get(rhash)
return self.receive_requests.get(key) if key else None
def get_formatted_request(self, key):
x = self.receive_requests.get(key)
x = self.get_request.get(key)
if x:
return self.export_request(x)
@ -2473,32 +2492,31 @@ class Abstract_Wallet(ABC, Logger):
key = self.get_key_for_receive_request(req, sanity_checks=True)
message = req.message
self.receive_requests[key] = req
self.set_label(key, message) # should be a default label
if req.is_lightning():
self._requests_rhash_to_key[req.rhash] = key
if write_to_disk:
self.save_db()
return key
def delete_request(self, key):
""" lightning or on-chain """
if key in self.receive_requests:
self.remove_payment_request(key)
elif self.lnworker:
self.lnworker.delete_payment(key)
req = self.receive_requests.pop(key, None)
if req is None:
return
if req.is_lightning():
self._requests_rhash_to_key.pop(req.rhash)
if req.is_lightning() and self.lnworker:
self.lnworker.delete_payment_info(req.rhash)
self.save_db()
def delete_invoice(self, key):
""" lightning or on-chain """
if key in self.invoices:
self.invoices.pop(key)
elif self.lnworker:
self.lnworker.delete_payment(key)
def remove_payment_request(self, addr) -> bool:
found = False
if addr in self.receive_requests:
found = True
self.receive_requests.pop(addr)
self.save_db()
return found
inv = self.invoices.pop(key, None)
if inv is None:
return
if inv.is_lightning() and self.lnworker:
self.lnworker.delete_payment_info(inv.rhash)
self.save_db()
def get_sorted_requests(self) -> List[Invoice]:
""" sorted by timestamp """
@ -2914,7 +2932,7 @@ class Imported_Wallet(Simple_Wallet):
for tx_hash in transactions_to_remove:
self.adb._remove_transaction(tx_hash)
self.set_label(address, None)
self.remove_payment_request(address)
self.delete_request(address)
self.set_frozen_state_of_addresses([address], False)
pubkey = self.get_public_key(address)
self.db.remove_imported_address(address)

Loading…
Cancel
Save