diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 6bd784b40..c609e69b6 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1536,13 +1536,14 @@ class Peer(Logger): # TODO fail here if invoice has set PAYMENT_SECRET_REQ payment_secret_from_onion = None - mpp_status = self.lnworker.add_received_htlc(chan.short_channel_id, htlc, total_msat) - if mpp_status is None: - return None, None - if mpp_status is False: - log_fail_reason(f"MPP_TIMEOUT") - raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') - assert mpp_status is True + if total_msat > amt_to_forward: + mpp_status = self.lnworker.add_received_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat) + if mpp_status is None: + return None, None + if mpp_status is False: + log_fail_reason(f"MPP_TIMEOUT") + raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') + assert mpp_status is True # if there is a trampoline_onion, maybe_fulfill_htlc will be called again if processed_onion.trampoline_onion_packet: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index a052ed402..bf899ff64 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1594,28 +1594,33 @@ class LNWallet(LNWorker): self.payments[key] = info.amount_msat, info.direction, info.status self.wallet.save_db() - def add_received_htlc(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: + def add_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: """ return MPP status: True (accepted), False (expired) or None """ payment_hash = htlc.payment_hash - is_accepted = (self.get_payment_status(payment_hash) == PR_PAID) - is_expired, htlc_set = self.received_htlcs.get(payment_hash, (False, set())) + is_expired, is_accepted, htlc_set = self.received_htlcs.get(payment_secret, (False, False, set())) + if self.get_payment_status(payment_hash) == PR_PAID: + # payment_status is persisted + is_accepted = True + is_expired = False key = (short_channel_id, htlc) if key not in htlc_set: htlc_set.add(key) if not is_accepted and not is_expired: total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) - is_expired = time.time() - first_timestamp > MPP_EXPIRY - if not is_expired and total == expected_msat: + if time.time() - first_timestamp > MPP_EXPIRY: + is_expired = True + elif total == expected_msat: is_accepted = True - self.set_payment_status(payment_hash, PR_PAID) - util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) + if self.get_payment_info(payment_hash) is not None: + self.set_payment_status(payment_hash, PR_PAID) + util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) if is_accepted or is_expired: htlc_set.remove(key) if len(htlc_set) > 0: - self.received_htlcs[payment_hash] = is_expired, htlc_set - elif payment_hash in self.received_htlcs: - self.received_htlcs.pop(payment_hash) + self.received_htlcs[payment_secret] = is_expired, is_accepted, htlc_set + elif payment_secret in self.received_htlcs: + self.received_htlcs.pop(payment_secret) return True if is_accepted else (False if is_expired else None) def get_payment_status(self, payment_hash):