Browse Source

wallet: make "invoices" and "receive_requests" private

Other modules should use getters such as "get_request(key)" or "get_unpaid_requests()",
direct access is error-prone.
patch-4
SomberNight 3 years ago
parent
commit
d067e0e314
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 4
      electrum/gui/qt/invoice_list.py
  2. 43
      electrum/wallet.py

4
electrum/gui/qt/invoice_list.py

@ -82,7 +82,7 @@ class InvoiceList(MyTreeView):
def refresh_row(self, key, row): def refresh_row(self, key, row):
assert row is not None assert row is not None
invoice = self.wallet.invoices.get(key) invoice = self.wallet.get_invoice(key)
if invoice is None: if invoice is None:
return return
model = self.std_model model = self.std_model
@ -141,7 +141,7 @@ class InvoiceList(MyTreeView):
items = self.selected_in_column(0) items = self.selected_in_column(0)
if len(items)>1: if len(items)>1:
keys = [item.data(ROLE_REQUEST_ID) for item in items] keys = [item.data(ROLE_REQUEST_ID) for item in items]
invoices = [wallet.invoices.get(key) for key in keys] invoices = [wallet.get_invoice(key) for key in keys]
can_batch_pay = all([not i.is_lightning() and wallet.get_invoice_status(i) == PR_UNPAID for i in invoices]) can_batch_pay = all([not i.is_lightning() and wallet.get_invoice_status(i) == PR_UNPAID for i in invoices])
menu = QMenu(self) menu = QMenu(self)
if can_batch_pay: if can_batch_pay:

43
electrum/wallet.py

@ -312,8 +312,8 @@ class Abstract_Wallet(ABC, Logger, EventListener):
self._frozen_addresses = set(db.get('frozen_addresses', [])) self._frozen_addresses = set(db.get('frozen_addresses', []))
self._frozen_coins = db.get_dict('frozen_coins') # type: Dict[str, bool] self._frozen_coins = db.get_dict('frozen_coins') # type: Dict[str, bool]
self.fiat_value = db.get_dict('fiat_value') self.fiat_value = db.get_dict('fiat_value')
self.receive_requests = db.get_dict('payment_requests') # type: Dict[str, Invoice] self._receive_requests = db.get_dict('payment_requests') # type: Dict[str, Invoice]
self.invoices = db.get_dict('invoices') # type: Dict[str, Invoice] self._invoices = db.get_dict('invoices') # type: Dict[str, Invoice]
self._reserved_addresses = set(db.get('reserved_addresses', [])) self._reserved_addresses = set(db.get('reserved_addresses', []))
self._freeze_lock = threading.RLock() # for mutating/iterating frozen_{addresses,coins} self._freeze_lock = threading.RLock() # for mutating/iterating frozen_{addresses,coins}
@ -971,20 +971,20 @@ class Abstract_Wallet(ABC, Logger, EventListener):
with self.transaction_lock: with self.transaction_lock:
for txout in invoice.get_outputs(): for txout in invoice.get_outputs():
self._invoices_from_scriptpubkey_map[txout.scriptpubkey].add(key) self._invoices_from_scriptpubkey_map[txout.scriptpubkey].add(key)
self.invoices[key] = invoice self._invoices[key] = invoice
self.save_db() self.save_db()
def clear_invoices(self): def clear_invoices(self):
self.invoices.clear() self._invoices.clear()
self.save_db() self.save_db()
def clear_requests(self): def clear_requests(self):
self.receive_requests.clear() self._receive_requests.clear()
self._requests_addr_to_rhash.clear() self._requests_addr_to_rhash.clear()
self.save_db() self.save_db()
def get_invoices(self): def get_invoices(self):
out = list(self.invoices.values()) out = list(self._invoices.values())
out.sort(key=lambda x:x.time) out.sort(key=lambda x:x.time)
return out return out
@ -993,7 +993,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return [x for x in invoices if self.get_invoice_status(x) != PR_PAID] return [x for x in invoices if self.get_invoice_status(x) != PR_PAID]
def get_invoice(self, key): def get_invoice(self, key):
return self.invoices.get(key) return self._invoices.get(key)
def import_requests(self, path): def import_requests(self, path):
data = read_json_file(path) data = read_json_file(path)
@ -1002,7 +1002,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
self.add_payment_request(req) self.add_payment_request(req)
def export_requests(self, path): def export_requests(self, path):
write_json_file(path, list(self.receive_requests.values())) write_json_file(path, list(self._receive_requests.values()))
def import_invoices(self, path): def import_invoices(self, path):
data = read_json_file(path) data = read_json_file(path)
@ -1011,7 +1011,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
self.save_invoice(invoice) self.save_invoice(invoice)
def export_invoices(self, path): def export_invoices(self, path):
write_json_file(path, list(self.invoices.values())) write_json_file(path, list(self._invoices.values()))
def _get_relevant_invoice_keys_for_tx(self, tx: Transaction) -> Set[str]: def _get_relevant_invoice_keys_for_tx(self, tx: Transaction) -> Set[str]:
relevant_invoice_keys = set() relevant_invoice_keys = set()
@ -1019,7 +1019,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
for txout in tx.outputs(): for txout in tx.outputs():
for invoice_key in self._invoices_from_scriptpubkey_map.get(txout.scriptpubkey, set()): for invoice_key in self._invoices_from_scriptpubkey_map.get(txout.scriptpubkey, set()):
# note: the invoice might have been deleted since, so check now: # note: the invoice might have been deleted since, so check now:
if invoice_key in self.invoices: if invoice_key in self._invoices:
relevant_invoice_keys.add(invoice_key) relevant_invoice_keys.add(invoice_key)
return relevant_invoice_keys return relevant_invoice_keys
@ -1033,14 +1033,14 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def _init_requests_rhash_index(self): def _init_requests_rhash_index(self):
self._requests_addr_to_rhash = {} self._requests_addr_to_rhash = {}
for key, req in self.receive_requests.items(): for key, req in self._receive_requests.items():
if req.is_lightning() and (addr:=req.get_address()): if req.is_lightning() and (addr:=req.get_address()):
self._requests_addr_to_rhash[addr] = req.rhash self._requests_addr_to_rhash[addr] = req.rhash
def _prepare_onchain_invoice_paid_detection(self): def _prepare_onchain_invoice_paid_detection(self):
# scriptpubkey -> list(invoice_keys) # scriptpubkey -> list(invoice_keys)
self._invoices_from_scriptpubkey_map = defaultdict(set) # type: Dict[bytes, Set[str]] self._invoices_from_scriptpubkey_map = defaultdict(set) # type: Dict[bytes, Set[str]]
for invoice_key, invoice in self.invoices.items(): for invoice_key, invoice in self._invoices.items():
if invoice.is_lightning() and not invoice.get_address(): if invoice.is_lightning() and not invoice.get_address():
continue continue
for txout in invoice.get_outputs(): for txout in invoice.get_outputs():
@ -2343,12 +2343,12 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return self.check_expired_status(r, status) return self.check_expired_status(r, status)
def get_request(self, key): def get_request(self, key):
return self.receive_requests.get(key) or self.get_request_by_address(key) return self._receive_requests.get(key) or self.get_request_by_address(key)
def get_request_by_address(self, addr): def get_request_by_address(self, addr):
rhash = self._requests_addr_to_rhash.get(addr) rhash = self._requests_addr_to_rhash.get(addr)
if rhash: if rhash:
return self.receive_requests.get(rhash) return self._receive_requests.get(rhash)
def get_formatted_request(self, key): def get_formatted_request(self, key):
x = self.get_request(key) x = self.get_request(key)
@ -2469,7 +2469,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def sign_payment_request(self, key, alias, alias_addr, password): # FIXME this is broken def sign_payment_request(self, key, alias, alias_addr, password): # FIXME this is broken
raise raise
req = self.receive_requests.get(key) req = self._receive_requests.get(key)
assert not req.is_lightning() assert not req.is_lightning()
alias_privkey = self.export_private_key(alias_addr, password) alias_privkey = self.export_private_key(alias_addr, password)
pr = paymentrequest.make_unsigned_request(req) pr = paymentrequest.make_unsigned_request(req)
@ -2477,7 +2477,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
req.bip70 = pr.raw.hex() req.bip70 = pr.raw.hex()
req['name'] = pr.pki_data req['name'] = pr.pki_data
req['sig'] = bh2u(pr.signature) req['sig'] = bh2u(pr.signature)
self.receive_requests[key] = req self._receive_requests[key] = req
@classmethod @classmethod
def get_key_for_outgoing_invoice(cls, invoice: Invoice) -> str: def get_key_for_outgoing_invoice(cls, invoice: Invoice) -> str:
@ -2501,8 +2501,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def add_payment_request(self, req: Invoice, *, write_to_disk: bool = True): def add_payment_request(self, req: Invoice, *, write_to_disk: bool = True):
key = self.get_key_for_receive_request(req, sanity_checks=True) key = self.get_key_for_receive_request(req, sanity_checks=True)
message = req.message self._receive_requests[key] = req
self.receive_requests[key] = req
if req.is_lightning() and (addr:=req.get_address()): if req.is_lightning() and (addr:=req.get_address()):
self._requests_addr_to_rhash[addr] = req.rhash self._requests_addr_to_rhash[addr] = req.rhash
if write_to_disk: if write_to_disk:
@ -2511,7 +2510,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def delete_request(self, key): def delete_request(self, key):
""" lightning or on-chain """ """ lightning or on-chain """
req = self.receive_requests.pop(key, None) req = self._receive_requests.pop(key, None)
if req is None: if req is None:
return return
if req.is_lightning() and (addr:=req.get_address()): if req.is_lightning() and (addr:=req.get_address()):
@ -2522,7 +2521,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def delete_invoice(self, key): def delete_invoice(self, key):
""" lightning or on-chain """ """ lightning or on-chain """
inv = self.invoices.pop(key, None) inv = self._invoices.pop(key, None)
if inv is None: if inv is None:
return return
if inv.is_lightning() and self.lnworker: if inv.is_lightning() and self.lnworker:
@ -2531,13 +2530,13 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def get_sorted_requests(self) -> List[Invoice]: def get_sorted_requests(self) -> List[Invoice]:
""" sorted by timestamp """ """ sorted by timestamp """
out = [self.get_request(x) for x in self.receive_requests.keys()] out = [self.get_request(x) for x in self._receive_requests.keys()]
out = [x for x in out if x is not None] out = [x for x in out if x is not None]
out.sort(key=lambda x: x.time) out.sort(key=lambda x: x.time)
return out return out
def get_unpaid_requests(self): def get_unpaid_requests(self):
out = [self.get_request(x) for x in self.receive_requests.keys() if self.get_request_status(x) != PR_PAID] out = [self.get_request(x) for x in self._receive_requests.keys() if self.get_request_status(x) != PR_PAID]
out = [x for x in out if x is not None] out = [x for x in out if x is not None]
out.sort(key=lambda x: x.time) out.sort(key=lambda x: x.time)
return out return out

Loading…
Cancel
Save