Browse Source

cleanup, follow-up f28a2aae73

patch-4
ThomasV 4 years ago
parent
commit
34734bd229
  1. 21
      electrum/lnworker.py
  2. 8
      electrum/tests/test_lnpeer.py

21
electrum/lnworker.py

@ -580,11 +580,9 @@ class LNWallet(LNWorker):
for channel_id, c in random_shuffled_copy(channels.items()): for channel_id, c in random_shuffled_copy(channels.items()):
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[HtlcLog]] self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.pending_sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Future[HtlcLog]] self.received_htlcs = defaultdict(set) # type: Dict[bytes, set]
self.htlc_routes = dict()
self.pending_htlcs = defaultdict(set) # type: Dict[bytes, set]
self.htlc_routes = defaultdict(list)
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
# detect inflight payments # detect inflight payments
@ -955,19 +953,18 @@ class LNWallet(LNWorker):
lnaddr = self._check_invoice(invoice, amount_msat=amount_msat) lnaddr = self._check_invoice(invoice, amount_msat=amount_msat)
payment_hash = lnaddr.paymenthash payment_hash = lnaddr.paymenthash
key = payment_hash.hex() key = payment_hash.hex()
amount_msat = lnaddr.get_amount_msat() amount_to_pay = lnaddr.get_amount_msat()
status = self.get_payment_status(payment_hash) status = self.get_payment_status(payment_hash)
if status == PR_PAID: if status == PR_PAID:
raise PaymentFailure(_("This invoice has been paid already")) raise PaymentFailure(_("This invoice has been paid already"))
if status == PR_INFLIGHT: if status == PR_INFLIGHT:
raise PaymentFailure(_("A payment was already initiated for this invoice")) raise PaymentFailure(_("A payment was already initiated for this invoice"))
info = PaymentInfo(payment_hash, amount_msat, SENT, PR_UNPAID) info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
self.save_payment_info(info) self.save_payment_info(info)
self.wallet.set_label(key, lnaddr.get_description()) self.wallet.set_label(key, lnaddr.get_description())
self.logs[key] = log = [] self.logs[key] = log = []
success = False success = False
reason = '' reason = ''
amount_to_pay = lnaddr.get_amount_msat()
amount_inflight = 0 # what we sent in htlcs amount_inflight = 0 # what we sent in htlcs
self.set_invoice_status(key, PR_INFLIGHT) self.set_invoice_status(key, PR_INFLIGHT)
@ -990,7 +987,7 @@ class LNWallet(LNWorker):
amount_inflight += amount_msat amount_inflight += amount_msat
util.trigger_callback('invoice_status', self.wallet, key) util.trigger_callback('invoice_status', self.wallet, key)
# 3. await a queue # 3. await a queue
htlc_log = await self.pending_sent_htlcs[payment_hash].get() htlc_log = await self.sent_htlcs[payment_hash].get()
amount_inflight -= htlc_log.amount_msat amount_inflight -= htlc_log.amount_msat
log.append(htlc_log) log.append(htlc_log)
if htlc_log.success: if htlc_log.success:
@ -1318,7 +1315,7 @@ class LNWallet(LNWorker):
status = self.get_payment_status(htlc.payment_hash) status = self.get_payment_status(htlc.payment_hash)
if status == PR_PAID: if status == PR_PAID:
return True, None return True, None
s = self.pending_htlcs[htlc.payment_hash] s = self.received_htlcs[htlc.payment_hash]
if (short_channel_id, htlc) not in s: if (short_channel_id, htlc) not in s:
s.add((short_channel_id, htlc)) s.add((short_channel_id, htlc))
total = sum([htlc.amount_msat for scid, htlc in s]) total = sum([htlc.amount_msat for scid, htlc in s])
@ -1370,7 +1367,7 @@ class LNWallet(LNWorker):
success=True, success=True,
route=route, route=route,
amount_msat=amount_msat) amount_msat=amount_msat)
q = self.pending_sent_htlcs[payment_hash] q = self.sent_htlcs[payment_hash]
q.put_nowait(htlc_log) q.put_nowait(htlc_log)
util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id) util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id)
@ -1405,7 +1402,7 @@ class LNWallet(LNWorker):
failure_msg=failure_message, failure_msg=failure_message,
sender_idx=sender_idx) sender_idx=sender_idx)
q = self.pending_sent_htlcs[payment_hash] q = self.sent_htlcs[payment_hash]
q.put_nowait(htlc_log) q.put_nowait(htlc_log)
util.trigger_callback('htlc_failed', payment_hash, chan.channel_id) util.trigger_callback('htlc_failed', payment_hash, chan.channel_id)

8
electrum/tests/test_lnpeer.py

@ -132,8 +132,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
# used in tests # used in tests
self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle = asyncio.Event()
self.enable_htlc_settle.set() self.enable_htlc_settle.set()
self.pending_htlcs = defaultdict(set) self.received_htlcs = defaultdict(set)
self.pending_sent_htlcs = defaultdict(asyncio.Queue) self.sent_htlcs = defaultdict(asyncio.Queue)
self.htlc_routes = defaultdict(list) self.htlc_routes = defaultdict(list)
def get_invoice_status(self, key): def get_invoice_status(self, key):
@ -518,9 +518,9 @@ class TestPeer(ElectrumTestCase):
p1.maybe_send_commitment(alice_channel) p1.maybe_send_commitment(alice_channel)
p2.maybe_send_commitment(bob_channel) p2.maybe_send_commitment(bob_channel)
htlc_log1 = await w1.pending_sent_htlcs[lnaddr2.paymenthash].get() htlc_log1 = await w1.sent_htlcs[lnaddr2.paymenthash].get()
assert htlc_log1.success assert htlc_log1.success
htlc_log2 = await w2.pending_sent_htlcs[lnaddr1.paymenthash].get() htlc_log2 = await w2.sent_htlcs[lnaddr1.paymenthash].get()
assert htlc_log2.success assert htlc_log2.success
raise PaymentDone() raise PaymentDone()

Loading…
Cancel
Save