Browse Source

add TODO, rename check_received_mpp_htlc

patch-4
ThomasV 4 years ago
parent
commit
533d796a41
  1. 3
      electrum/lnpeer.py
  2. 12
      electrum/lnworker.py
  3. 4
      electrum/tests/test_lnpeer.py

3
electrum/lnpeer.py

@ -1538,7 +1538,7 @@ class Peer(Logger):
payment_secret_from_onion = None payment_secret_from_onion = None
if total_msat > amt_to_forward: if total_msat > amt_to_forward:
mpp_status = self.lnworker.add_received_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat) mpp_status = self.lnworker.check_received_mpp_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
if mpp_status is None: if mpp_status is None:
return None, None return None, None
if mpp_status is False: if mpp_status is False:
@ -1548,6 +1548,7 @@ class Peer(Logger):
# if there is a trampoline_onion, maybe_fulfill_htlc will be called again # if there is a trampoline_onion, maybe_fulfill_htlc will be called again
if processed_onion.trampoline_onion_packet: if processed_onion.trampoline_onion_packet:
# TODO: we should check that all trampoline_onions are the same
return None, processed_onion.trampoline_onion_packet return None, processed_onion.trampoline_onion_packet
info = self.lnworker.get_payment_info(htlc.payment_hash) info = self.lnworker.get_payment_info(htlc.payment_hash)

12
electrum/lnworker.py

@ -609,7 +609,7 @@ class LNWallet(LNWorker):
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat
self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed) self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed)
self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set self.received_mpp_htlcs = dict() # RHASH -> mpp_status, htlc_set
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
# detect inflight payments # detect inflight payments
@ -1621,10 +1621,10 @@ class LNWallet(LNWorker):
self.payments[key] = info.amount_msat, info.direction, info.status self.payments[key] = info.amount_msat, info.direction, info.status
self.wallet.save_db() self.wallet.save_db()
def add_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: def check_received_mpp_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
""" return MPP status: True (accepted), False (expired) or None """ """ return MPP status: True (accepted), False (expired) or None """
payment_hash = htlc.payment_hash payment_hash = htlc.payment_hash
is_expired, is_accepted, htlc_set = self.received_htlcs.get(payment_secret, (False, False, set())) is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
if self.get_payment_status(payment_hash) == PR_PAID: if self.get_payment_status(payment_hash) == PR_PAID:
# payment_status is persisted # payment_status is persisted
is_accepted = True is_accepted = True
@ -1642,9 +1642,9 @@ class LNWallet(LNWorker):
if is_accepted or is_expired: if is_accepted or is_expired:
htlc_set.remove(key) htlc_set.remove(key)
if len(htlc_set) > 0: if len(htlc_set) > 0:
self.received_htlcs[payment_secret] = is_expired, is_accepted, htlc_set self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, htlc_set
elif payment_secret in self.received_htlcs: elif payment_secret in self.received_mpp_htlcs:
self.received_htlcs.pop(payment_secret) self.received_mpp_htlcs.pop(payment_secret)
return True if is_accepted else (False if is_expired else None) return True if is_accepted else (False if is_expired else None)
def get_payment_status(self, payment_hash: bytes) -> int: def get_payment_status(self, payment_hash: bytes) -> int:

4
electrum/tests/test_lnpeer.py

@ -140,7 +140,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.enable_htlc_settle.set() self.enable_htlc_settle.set()
self.enable_htlc_forwarding = asyncio.Event() self.enable_htlc_forwarding = asyncio.Event()
self.enable_htlc_forwarding.set() self.enable_htlc_forwarding.set()
self.received_htlcs = dict() self.received_mpp_htlcs = dict()
self.sent_htlcs = defaultdict(asyncio.Queue) self.sent_htlcs = defaultdict(asyncio.Queue)
self.sent_htlcs_routes = dict() self.sent_htlcs_routes = dict()
self.sent_buckets = defaultdict(set) self.sent_buckets = defaultdict(set)
@ -194,7 +194,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
set_request_status = LNWallet.set_request_status set_request_status = LNWallet.set_request_status
set_payment_status = LNWallet.set_payment_status set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status get_payment_status = LNWallet.get_payment_status
add_received_htlc = LNWallet.add_received_htlc check_received_mpp_htlc = LNWallet.check_received_mpp_htlc
htlc_fulfilled = LNWallet.htlc_fulfilled htlc_fulfilled = LNWallet.htlc_fulfilled
htlc_failed = LNWallet.htlc_failed htlc_failed = LNWallet.htlc_failed
save_preimage = LNWallet.save_preimage save_preimage = LNWallet.save_preimage

Loading…
Cancel
Save