Browse Source

lnworker: make calc_routing_hints_for_invoice and create_invoice non-async

patch-4
ThomasV 3 years ago
parent
commit
cb39bbbd94
  1. 2
      electrum/commands.py
  2. 19
      electrum/lnworker.py
  3. 41
      electrum/tests/test_lnpeer.py

2
electrum/commands.py

@ -920,7 +920,7 @@ class Commands:
@command('wnl')
async def add_lightning_request(self, amount, memo='', expiration=3600, wallet: Abstract_Wallet = None):
amount_sat = int(satoshis(amount))
key = await wallet.lnworker._add_request_coro(amount_sat, memo, expiration)
key = wallet.lnworker.add_request(amount_sat, memo, expiration)
return wallet.get_formatted_request(key)
@command('w')

19
electrum/lnworker.py

@ -1731,16 +1731,7 @@ class LNWallet(LNWorker):
route[-1].node_features |= invoice_features
return route
def add_request(self, amount_sat, message, expiry) -> str:
coro = self._add_request_coro(amount_sat, message, expiry)
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
try:
return fut.result(timeout=5)
except concurrent.futures.TimeoutError:
raise Exception(_("add invoice timed out"))
@log_exceptions
async def create_invoice(
def create_invoice(
self, *,
amount_msat: Optional[int],
message: str,
@ -1749,7 +1740,7 @@ class LNWallet(LNWorker):
) -> Tuple[LnAddr, str]:
timestamp = int(time.time())
routing_hints = await self._calc_routing_hints_for_invoice(amount_msat)
routing_hints = self.calc_routing_hints_for_invoice(amount_msat)
if not routing_hints:
self.logger.info(
"Warning. No routing hints added to invoice. "
@ -1786,9 +1777,9 @@ class LNWallet(LNWorker):
self.wallet.save_db()
return lnaddr, invoice
async def _add_request_coro(self, amount_sat: Optional[int], message, expiry: int) -> str:
def add_request(self, amount_sat: Optional[int], message, expiry: int) -> str:
amount_msat = amount_sat * 1000 if amount_sat is not None else None
lnaddr, invoice = await self.create_invoice(
lnaddr, invoice = self.create_invoice(
amount_msat=amount_msat,
message=message,
expiry=expiry,
@ -1978,7 +1969,7 @@ class LNWallet(LNWorker):
self.set_invoice_status(key, PR_UNPAID)
util.trigger_callback('payment_failed', self.wallet, key, '')
async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)"""
routing_hints = []
channels = list(self.channels.values())

41
electrum/tests/test_lnpeer.py

@ -236,7 +236,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
on_peer_successfully_established = LNWallet.on_peer_successfully_established
get_channel_by_id = LNWallet.get_channel_by_id
channels_for_peer = LNWallet.channels_for_peer
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
calc_routing_hints_for_invoice = LNWallet.calc_routing_hints_for_invoice
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
@ -469,7 +469,7 @@ class TestPeer(TestCaseForTestnet):
return graph
@staticmethod
async def prepare_invoice(
def prepare_invoice(
w2: MockLNWallet, # receiver
*,
amount_msat=100_000_000,
@ -482,7 +482,7 @@ class TestPeer(TestCaseForTestnet):
w2.save_preimage(RHASH, payment_preimage)
w2.save_payment_info(info)
if include_routing_hints:
routing_hints = await w2._calc_routing_hints_for_invoice(amount_msat)
routing_hints = w2.calc_routing_hints_for_invoice(amount_msat)
else:
routing_hints = []
trampoline_hints = []
@ -532,7 +532,7 @@ class TestPeer(TestCaseForTestnet):
alice_channel, bob_channel = create_test_channels(random_seed=random_seed)
alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
lnaddr, pay_req = run(self.prepare_invoice(w2))
lnaddr, pay_req = self.prepare_invoice(w2)
async def pay():
result, log = await w1.pay_invoice(pay_req)
self.assertEqual(result, True)
@ -575,7 +575,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01)
lnaddr, pay_req = await self.prepare_invoice(w2)
lnaddr, pay_req = self.prepare_invoice(w2)
invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req))
@ -597,8 +597,8 @@ class TestPeer(TestCaseForTestnet):
# prep
_maybe_send_commitment1 = p1.maybe_send_commitment
_maybe_send_commitment2 = p2.maybe_send_commitment
lnaddr2, pay_req2 = await self.prepare_invoice(w2)
lnaddr1, pay_req1 = await self.prepare_invoice(w1)
lnaddr2, pay_req2 = self.prepare_invoice(w2)
lnaddr1, pay_req1 = self.prepare_invoice(w1)
# create the htlc queues now (side-effecting defaultdict)
q1 = w1.sent_htlcs[lnaddr2.paymenthash]
q2 = w2.sent_htlcs[lnaddr1.paymenthash]
@ -670,11 +670,8 @@ class TestPeer(TestCaseForTestnet):
await w1.pay_invoice(pay_req)
async def many_payments():
async with OldTaskGroup() as group:
pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat))
for i in range(num_payments)]
async with OldTaskGroup() as group:
for pay_req_task in pay_reqs_tasks:
lnaddr, pay_req = pay_req_task.result()
for i in range(num_payments):
lnaddr, pay_req = self.prepare_invoice(w2, amount_msat=payment_value_msat)
await group.spawn(single_payment(pay_req))
gath.cancel()
gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
@ -703,7 +700,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone):
run(f())
@ -747,7 +744,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(pay_req))
with self.assertRaises(PaymentDone):
run(f())
@ -771,7 +768,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone):
run(f())
@ -806,7 +803,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req))
@ -869,7 +866,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], amount_msat=amount_to_pay, include_routing_hints=True)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], amount_msat=amount_to_pay, include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone):
run(f())
@ -900,7 +897,7 @@ class TestPeer(TestCaseForTestnet):
graph.workers['alice'].network.channel_db = None
else:
assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts)
if not bob_forwarding:
@ -960,7 +957,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req))
for is_legacy in (True, False):
@ -1035,7 +1032,7 @@ class TestPeer(TestCaseForTestnet):
graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT
graph.workers['bob'].enable_htlc_forwarding = False # Bob will hold forwarded HTLCs
assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
try:
async with timeout_after(0.5):
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=1)
@ -1104,7 +1101,7 @@ class TestPeer(TestCaseForTestnet):
else:
w2.network.config.set_key('test_shutdown_legacy', True)
w2.enable_htlc_settle = False
lnaddr, pay_req = run(self.prepare_invoice(w2))
lnaddr, pay_req = self.prepare_invoice(w2)
async def pay():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
@ -1234,7 +1231,7 @@ class TestPeer(TestCaseForTestnet):
def test_channel_usage_after_closing(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
lnaddr, pay_req = run(self.prepare_invoice(w2))
lnaddr, pay_req = self.prepare_invoice(w2)
lnaddr = w1._check_invoice(pay_req)
route, amount_msat = run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]

Loading…
Cancel
Save