Browse Source

MPP receive: allow payer to retry after mpp timeout

patch-4
ThomasV 4 years ago
parent
commit
7f61f22857
  1. 6
      electrum/lnpeer.py
  2. 44
      electrum/lnworker.py
  3. 4
      electrum/tests/test_lnpeer.py

6
electrum/lnpeer.py

@ -1576,10 +1576,10 @@ class Peer(Logger):
invoice_msat = info.amount_msat invoice_msat = info.amount_msat
if not (invoice_msat is None or invoice_msat <= total_msat <= 2 * invoice_msat): if not (invoice_msat is None or invoice_msat <= total_msat <= 2 * invoice_msat):
raise exc_incorrect_or_unknown_pd raise exc_incorrect_or_unknown_pd
accepted, expired = self.lnworker.htlc_received(chan.short_channel_id, htlc, total_msat) mpp_status = self.lnworker.add_received_htlc(chan.short_channel_id, htlc, total_msat)
if accepted: if mpp_status == True:
return preimage return preimage
elif expired: elif mpp_status == False:
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
else: else:
return None return None

44
electrum/lnworker.py

@ -657,7 +657,7 @@ class LNWallet(LNWorker):
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.received_htlcs = defaultdict(set) # type: Dict[bytes, set] self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set
self.htlc_routes = dict() self.htlc_routes = dict()
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
@ -1682,24 +1682,30 @@ class LNWallet(LNWorker):
self.payments[key] = info.amount_msat, info.direction, info.status self.payments[key] = info.amount_msat, info.direction, info.status
self.wallet.save_db() self.wallet.save_db()
def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int): def add_received_htlc(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
status = self.get_payment_status(htlc.payment_hash) """ return MPP status: True (accepted), False (expired) or None """
if status == PR_PAID: payment_hash = htlc.payment_hash
return True, None mpp_status, htlc_set = self.received_htlcs.get(payment_hash, (None, set()))
s = self.received_htlcs[htlc.payment_hash] key = (short_channel_id, htlc)
if (short_channel_id, htlc) not in s: if key not in htlc_set:
s.add((short_channel_id, htlc)) htlc_set.add(key)
total = sum([htlc.amount_msat for scid, htlc in s]) if mpp_status is None:
first_timestamp = min([htlc.timestamp for scid, htlc in s]) total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
expired = time.time() - first_timestamp > MPP_EXPIRY first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
if total == expected_msat and not expired: expired = time.time() - first_timestamp > MPP_EXPIRY
# status must be persisted if expired:
self.set_payment_status(htlc.payment_hash, PR_PAID) mpp_status = False
util.trigger_callback('request_status', self.wallet, htlc.payment_hash.hex(), PR_PAID) elif total == expected_msat:
return True, None mpp_status = True
if expired: self.set_payment_status(payment_hash, PR_PAID)
return None, True util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID)
return None, None if mpp_status is not None:
htlc_set.remove(key)
if len(htlc_set) > 0:
self.received_htlcs[payment_hash] = mpp_status, htlc_set
elif payment_hash in self.received_htlcs:
self.received_htlcs.pop(payment_hash)
return mpp_status
def get_payment_status(self, payment_hash): def get_payment_status(self, payment_hash):
info = self.get_payment_info(payment_hash) info = self.get_payment_info(payment_hash)

4
electrum/tests/test_lnpeer.py

@ -132,7 +132,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
# used in tests # used in tests
self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle = asyncio.Event()
self.enable_htlc_settle.set() self.enable_htlc_settle.set()
self.received_htlcs = defaultdict(set) self.received_htlcs = dict()
self.sent_htlcs = defaultdict(asyncio.Queue) self.sent_htlcs = defaultdict(asyncio.Queue)
self.htlc_routes = defaultdict(list) self.htlc_routes = defaultdict(list)
@ -170,7 +170,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
set_invoice_status = LNWallet.set_invoice_status set_invoice_status = LNWallet.set_invoice_status
set_payment_status = LNWallet.set_payment_status set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status get_payment_status = LNWallet.get_payment_status
htlc_received = LNWallet.htlc_received add_received_htlc = LNWallet.add_received_htlc
htlc_fulfilled = LNWallet.htlc_fulfilled htlc_fulfilled = LNWallet.htlc_fulfilled
htlc_failed = LNWallet.htlc_failed htlc_failed = LNWallet.htlc_failed
save_preimage = LNWallet.save_preimage save_preimage = LNWallet.save_preimage

Loading…
Cancel
Save