diff --git a/electrum/gui/qt/invoice_list.py b/electrum/gui/qt/invoice_list.py index 451e09662..82278666e 100644 --- a/electrum/gui/qt/invoice_list.py +++ b/electrum/gui/qt/invoice_list.py @@ -31,7 +31,7 @@ from PyQt5.QtWidgets import QHeaderView, QMenu from electrum.i18n import _ from electrum.util import format_time, pr_tooltips, PR_UNPAID -from electrum.lnutil import lndecode +from electrum.lnutil import lndecode, RECEIVED from electrum.bitcoin import COIN from electrum import constants @@ -92,8 +92,8 @@ class InvoiceList(MyTreeView): self.model().insertRow(idx, items) lnworker = self.parent.wallet.lnworker - for key, (invoice, is_received) in lnworker.invoices.items(): - if is_received: + for key, (invoice, direction, is_paid) in lnworker.invoices.items(): + if direction == RECEIVED: continue status = lnworker.get_invoice_status(key) lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) diff --git a/electrum/gui/qt/request_list.py b/electrum/gui/qt/request_list.py index 86f732e28..b0a09844d 100644 --- a/electrum/gui/qt/request_list.py +++ b/electrum/gui/qt/request_list.py @@ -94,7 +94,7 @@ class RequestList(MyTreeView): return req = self.parent.get_request_URI(key) elif request_type == REQUEST_TYPE_LN: - req, is_received = self.wallet.lnworker.invoices.get(key) or (None, None) + req, direction, is_paid = self.wallet.lnworker.invoices.get(key) or (None, None) if req is None: self.update() return @@ -136,8 +136,8 @@ class RequestList(MyTreeView): self.filter() # lightning lnworker = self.wallet.lnworker - for key, (invoice, is_received) in lnworker.invoices.items(): - if not is_received: + for key, (invoice, direction, is_paid) in lnworker.invoices.items(): + if direction == SENT: continue status = lnworker.get_invoice_status(key) lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 8312a791f..e27ec3eb2 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -613,9 +613,8 @@ class Channel(PrintError): assert htlc.payment_hash == sha256(preimage) assert htlc_id not in log['settles'] self.hm.send_settle(htlc_id) - # save timestamp in LNWorker.preimages if self.lnworker: - self.lnworker.save_preimage(htlc.payment_hash, preimage, timestamp=int(time.time())) + self.lnworker.set_paid(htlc.payment_hash) def receive_htlc_settle(self, preimage, htlc_id): self.print_error("receive_htlc_settle") @@ -625,7 +624,8 @@ class Channel(PrintError): assert htlc_id not in log['settles'] self.hm.recv_settle(htlc_id) if self.lnworker: - self.lnworker.save_preimage(htlc.payment_hash, preimage, timestamp=int(time.time())) + self.lnworker.save_preimage(htlc.payment_hash, preimage) + self.lnworker.set_paid(htlc.payment_hash) def fail_htlc(self, htlc_id): self.print_error("fail_htlc") diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index ec20c0db9..24e4e8199 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -864,7 +864,7 @@ class Peer(PrintError): secret_key = os.urandom(32) onion = new_onion_packet([x.node_id for x in route], secret_key, hops_data, associated_data=payment_hash) # create htlc - htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv) + htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv, timestamp=int(time.time())) htlc = chan.add_htlc(htlc) remote_ctn = chan.get_current_ctn(REMOTE) chan.onion_keys[htlc.htlc_id] = secret_key @@ -943,7 +943,7 @@ class Peer(PrintError): if cltv_expiry >= 500_000_000: pass # TODO fail the channel # add htlc - htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc, payment_hash=payment_hash, cltv_expiry=cltv_expiry) + htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc, payment_hash=payment_hash, cltv_expiry=cltv_expiry, timestamp=int(time.time())) htlc = chan.receive_htlc(htlc) local_ctn = chan.get_current_ctn(LOCAL) remote_ctn = chan.get_current_ctn(REMOTE) @@ -980,7 +980,7 @@ class Peer(PrintError): self.print_error('forwarding htlc to', next_chan.node_id) next_cltv_expiry = int.from_bytes(dph.outgoing_cltv_value, 'big') next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big') - next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry) + next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry, timestamp=int(time.time())) next_htlc = next_chan.add_htlc(next_htlc) next_remote_ctn = next_chan.get_current_ctn(REMOTE) next_peer.send_message( diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 704287ea2..f9f175b09 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -666,7 +666,7 @@ def format_short_channel_id(short_channel_id: Optional[bytes]): + 'x' + str(int.from_bytes(short_channel_id[6:], 'big')) -class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])): +class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id', 'timestamp'])): # note: typing.NamedTuple cannot be used because we are overriding __new__ __slots__ = () diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 4e0fff88c..38d2b1d16 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -71,8 +71,8 @@ class LNWorker(PrintError): def __init__(self, wallet: 'Abstract_Wallet'): self.wallet = wallet self.storage = wallet.storage - self.invoices = self.storage.get('lightning_invoices', {}) # RHASH -> (invoice, is_received) - self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> (preimage, timestamp) + self.invoices = self.storage.get('lightning_invoices', {}) # RHASH -> (invoice, direction, is_paid) + self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage self.sweep_address = wallet.get_receiving_address() self.lock = threading.RLock() self.ln_keystore = self._read_ln_keystore() @@ -133,13 +133,11 @@ class LNWorker(PrintError): timestamp = int(time.time()) self.network.trigger_callback('ln_payment_completed', timestamp, direction, htlc, preimage, chan_id) - def get_invoice_status(self, payment_hash): - if payment_hash not in self.preimages: + def get_invoice_status(self, key): + if key not in self.invoices: return PR_UNKNOWN - preimage, timestamp = self.preimages.get(payment_hash) - if timestamp is None: - return PR_UNPAID - return PR_PAID + invoice, direction, is_paid = self.invoices[key] + return PR_PAID if is_paid else PR_UNPAID def get_payments(self): # return one item per payment_hash @@ -159,14 +157,15 @@ class LNWorker(PrintError): direction = 'sent' if _direction == SENT else 'received' amount_msat= int(_direction) * htlc.amount_msat label = '' + timestamp = htlc.timestamp else: # assume forwarding direction = 'forwarding' amount_msat = sum([int(_direction) * htlc.amount_msat for chan_id, htlc, _direction, status in plist]) status = '' label = _('Forwarding') + timestamp = min([htlc.timestamp for chan_id, htlc, _direction, status in plist]) - timestamp = self.preimages[payment_hash][1] if payment_hash in self.preimages else None item = { 'type': 'payment', 'label': label, @@ -490,7 +489,7 @@ class LNWorker(PrintError): if not chan: raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id))) peer = self.peers[route[0].node_id] - self.save_invoice(addr.paymenthash, pay_req, SENT) + self.save_invoice(addr.paymenthash, pay_req, SENT, is_paid=False) htlc = await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry()) self.network.trigger_callback('htlc_added', htlc, addr, SENT) @@ -574,40 +573,42 @@ class LNWorker(PrintError): ('c', MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE)] + routing_hints), self.node_keypair.privkey) - self.save_invoice(payment_hash, invoice, RECEIVED) - self.save_preimage(payment_hash, payment_preimage, timestamp=None) + self.save_invoice(payment_hash, invoice, RECEIVED, is_paid=False) + self.save_preimage(payment_hash, payment_preimage) return invoice - def save_preimage(self, payment_hash: bytes, preimage: bytes, *, timestamp: Optional[int]): + def save_preimage(self, payment_hash: bytes, preimage: bytes): assert sha256(preimage) == payment_hash - if timestamp is not None: - timestamp = int(timestamp) key = bh2u(payment_hash) - self.preimages[key] = bh2u(preimage), timestamp + self.preimages[key] = bh2u(preimage) self.storage.put('lightning_preimages', self.preimages) self.storage.write() - def get_preimage_and_timestamp(self, payment_hash: bytes) -> Tuple[bytes, int]: + def get_preimage(self, payment_hash: bytes) -> bytes: try: - preimage_hex, timestamp = self.preimages[bh2u(payment_hash)] - preimage = bfh(preimage_hex) + preimage = bfh(self.preimages[bh2u(payment_hash)]) assert sha256(preimage) == payment_hash - return preimage, timestamp + return preimage except KeyError as e: raise UnknownPaymentHash(payment_hash) from e - def get_preimage(self, payment_hash: bytes) -> bytes: - return self.get_preimage_and_timestamp(payment_hash)[0] - - def save_invoice(self, payment_hash:bytes, invoice, direction): + def save_invoice(self, payment_hash:bytes, invoice, direction, *, is_paid=False): key = bh2u(payment_hash) - self.invoices[key] = invoice, direction==RECEIVED + self.invoices[key] = invoice, direction, is_paid self.storage.put('lightning_invoices', self.invoices) self.storage.write() + def set_paid(self, payment_hash): + key = bh2u(payment_hash) + if key not in self.invoices: + # if we are forwarding + return + invoice, direction, _ = self.invoices[key] + self.save_invoice(payment_hash, invoice, direction, is_paid=True) + def get_invoice(self, payment_hash: bytes) -> LnAddr: try: - invoice, is_received = self.invoices[bh2u(payment_hash)] + invoice, direction, is_paid = self.invoices[bh2u(payment_hash)] return lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) except KeyError as e: raise UnknownPaymentHash(payment_hash) from e