Browse Source

fix trampoline forwarding: add_received_htlc must be indexed by payment_secret

patch-4
ThomasV 4 years ago
parent
commit
5207c40cc3
  1. 15
      electrum/lnpeer.py
  2. 25
      electrum/lnworker.py

15
electrum/lnpeer.py

@ -1536,13 +1536,14 @@ class Peer(Logger):
# TODO fail here if invoice has set PAYMENT_SECRET_REQ # TODO fail here if invoice has set PAYMENT_SECRET_REQ
payment_secret_from_onion = None payment_secret_from_onion = None
mpp_status = self.lnworker.add_received_htlc(chan.short_channel_id, htlc, total_msat) if total_msat > amt_to_forward:
if mpp_status is None: mpp_status = self.lnworker.add_received_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
return None, None if mpp_status is None:
if mpp_status is False: return None, None
log_fail_reason(f"MPP_TIMEOUT") if mpp_status is False:
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') log_fail_reason(f"MPP_TIMEOUT")
assert mpp_status is True raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
assert mpp_status is True
# if there is a trampoline_onion, maybe_fulfill_htlc will be called again # if there is a trampoline_onion, maybe_fulfill_htlc will be called again
if processed_onion.trampoline_onion_packet: if processed_onion.trampoline_onion_packet:

25
electrum/lnworker.py

@ -1594,28 +1594,33 @@ class LNWallet(LNWorker):
self.payments[key] = info.amount_msat, info.direction, info.status self.payments[key] = info.amount_msat, info.direction, info.status
self.wallet.save_db() self.wallet.save_db()
def add_received_htlc(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: def add_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
""" return MPP status: True (accepted), False (expired) or None """ """ return MPP status: True (accepted), False (expired) or None """
payment_hash = htlc.payment_hash payment_hash = htlc.payment_hash
is_accepted = (self.get_payment_status(payment_hash) == PR_PAID) is_expired, is_accepted, htlc_set = self.received_htlcs.get(payment_secret, (False, False, set()))
is_expired, htlc_set = self.received_htlcs.get(payment_hash, (False, set())) if self.get_payment_status(payment_hash) == PR_PAID:
# payment_status is persisted
is_accepted = True
is_expired = False
key = (short_channel_id, htlc) key = (short_channel_id, htlc)
if key not in htlc_set: if key not in htlc_set:
htlc_set.add(key) htlc_set.add(key)
if not is_accepted and not is_expired: if not is_accepted and not is_expired:
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
is_expired = time.time() - first_timestamp > MPP_EXPIRY if time.time() - first_timestamp > MPP_EXPIRY:
if not is_expired and total == expected_msat: is_expired = True
elif total == expected_msat:
is_accepted = True is_accepted = True
self.set_payment_status(payment_hash, PR_PAID) if self.get_payment_info(payment_hash) is not None:
util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) self.set_payment_status(payment_hash, PR_PAID)
util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID)
if is_accepted or is_expired: if is_accepted or is_expired:
htlc_set.remove(key) htlc_set.remove(key)
if len(htlc_set) > 0: if len(htlc_set) > 0:
self.received_htlcs[payment_hash] = is_expired, htlc_set self.received_htlcs[payment_secret] = is_expired, is_accepted, htlc_set
elif payment_hash in self.received_htlcs: elif payment_secret in self.received_htlcs:
self.received_htlcs.pop(payment_hash) self.received_htlcs.pop(payment_secret)
return True if is_accepted else (False if is_expired else None) return True if is_accepted else (False if is_expired else None)
def get_payment_status(self, payment_hash): def get_payment_status(self, payment_hash):

Loading…
Cancel
Save