From b6b13217b4929c5701e18a1310b20caf01fb61e5 Mon Sep 17 00:00:00 2001
From: ThomasV <thomasv@electrum.org>
Date: Sat, 27 Feb 2021 20:26:58 +0100
Subject: [PATCH] lnworker: keep invoice status INFLIGHT as long as HTLCs are
 inflight

---
 electrum/lnworker.py          | 111 ++++++++++++++++++----------------
 electrum/tests/test_lnpeer.py |   3 +-
 2 files changed, 61 insertions(+), 53 deletions(-)

diff --git a/electrum/lnworker.py b/electrum/lnworker.py
index 1cf72c9e6..e383fab18 100644
--- a/electrum/lnworker.py
+++ b/electrum/lnworker.py
@@ -657,8 +657,8 @@ class LNWallet(LNWorker):
             self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
 
         self.sent_htlcs = defaultdict(asyncio.Queue)  # type: Dict[bytes, asyncio.Queue[HtlcLog]]
+        self.sent_htlcs_routes = dict()               # (RHASH, scid, htlc_id) -> route
         self.received_htlcs = dict()                  # RHASH -> mpp_status, htlc_set
-        self.htlc_routes = dict()
 
         self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
         # detect inflight payments
@@ -939,14 +939,13 @@ class LNWallet(LNWorker):
 
     @log_exceptions
     async def _open_channel_coroutine(
-            self,
-            *,
+            self, *,
             connect_str: str,
             funding_tx: PartialTransaction,
             funding_sat: int,
             push_sat: int,
-            password: Optional[str],
-    ) -> Tuple[Channel, PartialTransaction]:
+            password: Optional[str]) -> Tuple[Channel, PartialTransaction]:
+
         peer = await self.add_peer(connect_str)
         coro = peer.channel_establishment_flow(
             funding_tx=funding_tx,
@@ -1053,7 +1052,6 @@ class LNWallet(LNWorker):
             random.shuffle(self.trampoline2_list)
 
         self.set_invoice_status(key, PR_INFLIGHT)
-        util.trigger_callback('invoice_status', self.wallet, key)
         try:
             await self.pay_to_node(
                 node_pubkey=invoice_pubkey,
@@ -1071,6 +1069,11 @@ class LNWallet(LNWorker):
             self.logger.exception('')
             success = False
             reason = str(e)
+        # keep invoice status INFLIGHT as long as HTLCs are inflight
+        # maybe we could add an extra state for the waiting time.
+        while payment_hash in self.get_payments(status='inflight'):
+            self.logger.info('waiting for inflight HTLCs...')
+            await self.sent_htlcs[payment_hash].get()
         if success:
             self.set_invoice_status(key, PR_PAID)
             util.trigger_callback('payment_succeeded', self.wallet, key)
@@ -1081,8 +1084,7 @@ class LNWallet(LNWorker):
         return success, log
 
     async def pay_to_node(
-            self,
-            *,
+            self, *,
             node_pubkey: bytes,
             payment_hash: bytes,
             payment_secret: Optional[bytes],
@@ -1095,8 +1097,7 @@ class LNWallet(LNWorker):
             full_path: LNPaymentPath = None,
             trampoline_onion=None,
             trampoline_fee=None,
-            trampoline_cltv_delta=None,
-    ) -> None:
+            trampoline_cltv_delta=None) -> None:
 
         if trampoline_onion:
             # todo: compare to the fee of the actual route we found
@@ -1119,7 +1120,7 @@ class LNWallet(LNWorker):
                 # 2. send htlcs
                 for route, amount_msat in routes:
                     await self.pay_to_route(
-                        route,
+                        route=route,
                         amount_msat=amount_msat,
                         total_msat=amount_to_pay,
                         payment_hash=payment_hash,
@@ -1142,16 +1143,15 @@ class LNWallet(LNWorker):
             self.handle_error_code_from_failed_htlc(htlc_log)
 
     async def pay_to_route(
-            self,
+            self, *,
             route: LNPaymentRoute,
-            *,
             amount_msat: int,
             total_msat: int,
             payment_hash: bytes,
             payment_secret: Optional[bytes],
             min_cltv_expiry: int,
-            trampoline_onion: bytes = None,
-    ) -> None:
+            trampoline_onion: bytes = None) -> None:
+
         # send a single htlc
         short_channel_id = route[0].short_channel_id
         chan = self.get_channel_by_short_id(short_channel_id)
@@ -1168,7 +1168,7 @@ class LNWallet(LNWorker):
             min_final_cltv_expiry=min_cltv_expiry,
             payment_secret=payment_secret,
             fwd_trampoline_onion=trampoline_onion)
-        self.htlc_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route
+        self.sent_htlcs_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route
         util.trigger_callback('htlc_added', chan, htlc, SENT)
 
     def handle_error_code_from_failed_htlc(self, htlc_log):
@@ -1729,6 +1729,7 @@ class LNWallet(LNWorker):
             self.inflight_payments.remove(key)
         if status in SAVED_PR_STATUS:
             self.set_payment_status(bfh(key), status)
+        util.trigger_callback('invoice_status', self.wallet, key)
 
     def set_payment_status(self, payment_hash: bytes, status):
         info = self.get_payment_info(payment_hash)
@@ -1739,54 +1740,60 @@ class LNWallet(LNWorker):
         self.save_payment_info(info)
 
     def htlc_fulfilled(self, chan, payment_hash: bytes, htlc_id:int, amount_msat:int):
-        route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id))
-        htlc_log = HtlcLog(
-            success=True,
-            route=route,
-            amount_msat=amount_msat)
-        q = self.sent_htlcs[payment_hash]
-        q.put_nowait(htlc_log)
         util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id)
+        q = self.sent_htlcs.get(payment_hash)
+        if q:
+            route = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
+            htlc_log = HtlcLog(
+                success=True,
+                route=route,
+                amount_msat=amount_msat)
+            q.put_nowait(htlc_log)
+        else:
+            if payment_hash not in self.get_payments(status='inflight'):
+                key = payment_hash.hex()
+                self.set_invoice_status(key, PR_PAID)
+                util.trigger_callback('payment_succeeded', self.wallet, key)
 
     def htlc_failed(
             self,
-            chan,
+            chan: Channel,
             payment_hash: bytes,
             htlc_id: int,
             amount_msat:int,
             error_bytes: Optional[bytes],
             failure_message: Optional['OnionRoutingFailure']):
 
-        route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id))
-        if not route:
-            self.logger.info(f"received unknown htlc_failed, probably from previous session")
-            return
-        if error_bytes:
-            self.logger.info(f" {(error_bytes, route, htlc_id)}")
-            # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
-            try:
-                failure_message, sender_idx = chan.decode_onion_error(error_bytes, route, htlc_id)
-            except Exception as e:
+        util.trigger_callback('htlc_failed', payment_hash, chan.channel_id)
+        q = self.sent_htlcs.get(payment_hash)
+        if q:
+            route = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
+            if error_bytes:
+                # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
+                try:
+                    failure_message, sender_idx = chan.decode_onion_error(error_bytes, route, htlc_id)
+                except Exception as e:
+                    sender_idx = None
+                    failure_message = OnionRoutingFailure(-1, str(e))
+            else:
+                # probably got "update_fail_malformed_htlc". well... who to penalise now?
+                assert failure_message is not None
                 sender_idx = None
-                failure_message = OnionRoutingFailure(-1, str(e))
+            self.logger.info(f"htlc_failed {failure_message}")
+            htlc_log = HtlcLog(
+                success=False,
+                route=route,
+                amount_msat=amount_msat,
+                error_bytes=error_bytes,
+                failure_msg=failure_message,
+                sender_idx=sender_idx)
+            q.put_nowait(htlc_log)
         else:
-            # probably got "update_fail_malformed_htlc". well... who to penalise now?
-            assert failure_message is not None
-            sender_idx = None
-
-        htlc_log = HtlcLog(
-            success=False,
-            route=route,
-            amount_msat=amount_msat,
-            error_bytes=error_bytes,
-            failure_msg=failure_message,
-            sender_idx=sender_idx)
-
-        q = self.sent_htlcs[payment_hash]
-        q.put_nowait(htlc_log)
-        util.trigger_callback('htlc_failed', payment_hash, chan.channel_id)
-
-
+            self.logger.info(f"received unknown htlc_failed, probably from previous session")
+            if payment_hash not in self.get_payments(status='inflight'):
+                key = payment_hash.hex()
+                self.set_invoice_status(key, PR_UNPAID)
+                util.trigger_callback('payment_failed', self.wallet, key, '')
 
     async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
         """calculate routing hints (BOLT-11 'r' field)"""
diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
index 9df8e1e60..642c9ca12 100644
--- a/electrum/tests/test_lnpeer.py
+++ b/electrum/tests/test_lnpeer.py
@@ -165,6 +165,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
 
     inflight_payments = set()
     preimages = {}
+    get_payments = LNWallet.get_payments
     get_payment_info = LNWallet.get_payment_info
     save_payment_info = LNWallet.save_payment_info
     set_invoice_status = LNWallet.set_invoice_status
@@ -776,7 +777,7 @@ class TestPeer(ElectrumTestCase):
             payment_hash = lnaddr.paymenthash
             payment_secret = lnaddr.payment_secret
             pay = w1.pay_to_route(
-                route,
+                route=route,
                 amount_msat=amount_msat,
                 total_msat=amount_msat,
                 payment_hash=payment_hash,