|
|
@ -657,7 +657,7 @@ class LNWallet(LNWorker): |
|
|
|
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) |
|
|
|
|
|
|
|
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] |
|
|
|
self.received_htlcs = defaultdict(set) # type: Dict[bytes, set] |
|
|
|
self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set |
|
|
|
self.htlc_routes = dict() |
|
|
|
|
|
|
|
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) |
|
|
@ -1682,24 +1682,30 @@ class LNWallet(LNWorker): |
|
|
|
self.payments[key] = info.amount_msat, info.direction, info.status |
|
|
|
self.wallet.save_db() |
|
|
|
|
|
|
|
def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int): |
|
|
|
status = self.get_payment_status(htlc.payment_hash) |
|
|
|
if status == PR_PAID: |
|
|
|
return True, None |
|
|
|
s = self.received_htlcs[htlc.payment_hash] |
|
|
|
if (short_channel_id, htlc) not in s: |
|
|
|
s.add((short_channel_id, htlc)) |
|
|
|
total = sum([htlc.amount_msat for scid, htlc in s]) |
|
|
|
first_timestamp = min([htlc.timestamp for scid, htlc in s]) |
|
|
|
expired = time.time() - first_timestamp > MPP_EXPIRY |
|
|
|
if total == expected_msat and not expired: |
|
|
|
# status must be persisted |
|
|
|
self.set_payment_status(htlc.payment_hash, PR_PAID) |
|
|
|
util.trigger_callback('request_status', self.wallet, htlc.payment_hash.hex(), PR_PAID) |
|
|
|
return True, None |
|
|
|
if expired: |
|
|
|
return None, True |
|
|
|
return None, None |
|
|
|
def add_received_htlc(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: |
|
|
|
""" return MPP status: True (accepted), False (expired) or None """ |
|
|
|
payment_hash = htlc.payment_hash |
|
|
|
mpp_status, htlc_set = self.received_htlcs.get(payment_hash, (None, set())) |
|
|
|
key = (short_channel_id, htlc) |
|
|
|
if key not in htlc_set: |
|
|
|
htlc_set.add(key) |
|
|
|
if mpp_status is None: |
|
|
|
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) |
|
|
|
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) |
|
|
|
expired = time.time() - first_timestamp > MPP_EXPIRY |
|
|
|
if expired: |
|
|
|
mpp_status = False |
|
|
|
elif total == expected_msat: |
|
|
|
mpp_status = True |
|
|
|
self.set_payment_status(payment_hash, PR_PAID) |
|
|
|
util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) |
|
|
|
if mpp_status is not None: |
|
|
|
htlc_set.remove(key) |
|
|
|
if len(htlc_set) > 0: |
|
|
|
self.received_htlcs[payment_hash] = mpp_status, htlc_set |
|
|
|
elif payment_hash in self.received_htlcs: |
|
|
|
self.received_htlcs.pop(payment_hash) |
|
|
|
return mpp_status |
|
|
|
|
|
|
|
def get_payment_status(self, payment_hash): |
|
|
|
info = self.get_payment_info(payment_hash) |
|
|
|