Browse Source

lnworker: make calc_routing_hints_for_invoice and create_invoice non-async

patch-4
ThomasV 3 years ago
parent
commit
cb39bbbd94
  1. 2
      electrum/commands.py
  2. 19
      electrum/lnworker.py
  3. 41
      electrum/tests/test_lnpeer.py

2
electrum/commands.py

@ -920,7 +920,7 @@ class Commands:
@command('wnl') @command('wnl')
async def add_lightning_request(self, amount, memo='', expiration=3600, wallet: Abstract_Wallet = None): async def add_lightning_request(self, amount, memo='', expiration=3600, wallet: Abstract_Wallet = None):
amount_sat = int(satoshis(amount)) amount_sat = int(satoshis(amount))
key = await wallet.lnworker._add_request_coro(amount_sat, memo, expiration) key = wallet.lnworker.add_request(amount_sat, memo, expiration)
return wallet.get_formatted_request(key) return wallet.get_formatted_request(key)
@command('w') @command('w')

19
electrum/lnworker.py

@ -1731,16 +1731,7 @@ class LNWallet(LNWorker):
route[-1].node_features |= invoice_features route[-1].node_features |= invoice_features
return route return route
def add_request(self, amount_sat, message, expiry) -> str: def create_invoice(
coro = self._add_request_coro(amount_sat, message, expiry)
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
try:
return fut.result(timeout=5)
except concurrent.futures.TimeoutError:
raise Exception(_("add invoice timed out"))
@log_exceptions
async def create_invoice(
self, *, self, *,
amount_msat: Optional[int], amount_msat: Optional[int],
message: str, message: str,
@ -1749,7 +1740,7 @@ class LNWallet(LNWorker):
) -> Tuple[LnAddr, str]: ) -> Tuple[LnAddr, str]:
timestamp = int(time.time()) timestamp = int(time.time())
routing_hints = await self._calc_routing_hints_for_invoice(amount_msat) routing_hints = self.calc_routing_hints_for_invoice(amount_msat)
if not routing_hints: if not routing_hints:
self.logger.info( self.logger.info(
"Warning. No routing hints added to invoice. " "Warning. No routing hints added to invoice. "
@ -1786,9 +1777,9 @@ class LNWallet(LNWorker):
self.wallet.save_db() self.wallet.save_db()
return lnaddr, invoice return lnaddr, invoice
async def _add_request_coro(self, amount_sat: Optional[int], message, expiry: int) -> str: def add_request(self, amount_sat: Optional[int], message, expiry: int) -> str:
amount_msat = amount_sat * 1000 if amount_sat is not None else None amount_msat = amount_sat * 1000 if amount_sat is not None else None
lnaddr, invoice = await self.create_invoice( lnaddr, invoice = self.create_invoice(
amount_msat=amount_msat, amount_msat=amount_msat,
message=message, message=message,
expiry=expiry, expiry=expiry,
@ -1978,7 +1969,7 @@ class LNWallet(LNWorker):
self.set_invoice_status(key, PR_UNPAID) self.set_invoice_status(key, PR_UNPAID)
util.trigger_callback('payment_failed', self.wallet, key, '') util.trigger_callback('payment_failed', self.wallet, key, '')
async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]): def calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)""" """calculate routing hints (BOLT-11 'r' field)"""
routing_hints = [] routing_hints = []
channels = list(self.channels.values()) channels = list(self.channels.values())

41
electrum/tests/test_lnpeer.py

@ -236,7 +236,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
on_peer_successfully_established = LNWallet.on_peer_successfully_established on_peer_successfully_established = LNWallet.on_peer_successfully_established
get_channel_by_id = LNWallet.get_channel_by_id get_channel_by_id = LNWallet.get_channel_by_id
channels_for_peer = LNWallet.channels_for_peer channels_for_peer = LNWallet.channels_for_peer
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice calc_routing_hints_for_invoice = LNWallet.calc_routing_hints_for_invoice
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
@ -469,7 +469,7 @@ class TestPeer(TestCaseForTestnet):
return graph return graph
@staticmethod @staticmethod
async def prepare_invoice( def prepare_invoice(
w2: MockLNWallet, # receiver w2: MockLNWallet, # receiver
*, *,
amount_msat=100_000_000, amount_msat=100_000_000,
@ -482,7 +482,7 @@ class TestPeer(TestCaseForTestnet):
w2.save_preimage(RHASH, payment_preimage) w2.save_preimage(RHASH, payment_preimage)
w2.save_payment_info(info) w2.save_payment_info(info)
if include_routing_hints: if include_routing_hints:
routing_hints = await w2._calc_routing_hints_for_invoice(amount_msat) routing_hints = w2.calc_routing_hints_for_invoice(amount_msat)
else: else:
routing_hints = [] routing_hints = []
trampoline_hints = [] trampoline_hints = []
@ -532,7 +532,7 @@ class TestPeer(TestCaseForTestnet):
alice_channel, bob_channel = create_test_channels(random_seed=random_seed) alice_channel, bob_channel = create_test_channels(random_seed=random_seed)
alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
lnaddr, pay_req = run(self.prepare_invoice(w2)) lnaddr, pay_req = self.prepare_invoice(w2)
async def pay(): async def pay():
result, log = await w1.pay_invoice(pay_req) result, log = await w1.pay_invoice(pay_req)
self.assertEqual(result, True) self.assertEqual(result, True)
@ -575,7 +575,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(p2._message_loop()) await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch()) await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
lnaddr, pay_req = await self.prepare_invoice(w2) lnaddr, pay_req = self.prepare_invoice(w2)
invoice_features = lnaddr.get_features() invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
@ -597,8 +597,8 @@ class TestPeer(TestCaseForTestnet):
# prep # prep
_maybe_send_commitment1 = p1.maybe_send_commitment _maybe_send_commitment1 = p1.maybe_send_commitment
_maybe_send_commitment2 = p2.maybe_send_commitment _maybe_send_commitment2 = p2.maybe_send_commitment
lnaddr2, pay_req2 = await self.prepare_invoice(w2) lnaddr2, pay_req2 = self.prepare_invoice(w2)
lnaddr1, pay_req1 = await self.prepare_invoice(w1) lnaddr1, pay_req1 = self.prepare_invoice(w1)
# create the htlc queues now (side-effecting defaultdict) # create the htlc queues now (side-effecting defaultdict)
q1 = w1.sent_htlcs[lnaddr2.paymenthash] q1 = w1.sent_htlcs[lnaddr2.paymenthash]
q2 = w2.sent_htlcs[lnaddr1.paymenthash] q2 = w2.sent_htlcs[lnaddr1.paymenthash]
@ -670,11 +670,8 @@ class TestPeer(TestCaseForTestnet):
await w1.pay_invoice(pay_req) await w1.pay_invoice(pay_req)
async def many_payments(): async def many_payments():
async with OldTaskGroup() as group: async with OldTaskGroup() as group:
pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat)) for i in range(num_payments):
for i in range(num_payments)] lnaddr, pay_req = self.prepare_invoice(w2, amount_msat=payment_value_msat)
async with OldTaskGroup() as group:
for pay_req_task in pay_reqs_tasks:
lnaddr, pay_req = pay_req_task.result()
await group.spawn(single_payment(pay_req)) await group.spawn(single_payment(pay_req))
gath.cancel() gath.cancel()
gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
@ -703,7 +700,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -747,7 +744,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(pay_req)) await group.spawn(pay(pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -771,7 +768,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -806,7 +803,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
invoice_features = lnaddr.get_features() invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
@ -869,7 +866,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], amount_msat=amount_to_pay, include_routing_hints=True) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], amount_msat=amount_to_pay, include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@ -900,7 +897,7 @@ class TestPeer(TestCaseForTestnet):
graph.workers['alice'].network.channel_db = None graph.workers['alice'].network.channel_db = None
else: else:
assert graph.workers['alice'].network.channel_db is not None assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts)
if not bob_forwarding: if not bob_forwarding:
@ -960,7 +957,7 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(peer._message_loop()) await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
for is_legacy in (True, False): for is_legacy in (True, False):
@ -1035,7 +1032,7 @@ class TestPeer(TestCaseForTestnet):
graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT
graph.workers['bob'].enable_htlc_forwarding = False # Bob will hold forwarded HTLCs graph.workers['bob'].enable_htlc_forwarding = False # Bob will hold forwarded HTLCs
assert graph.workers['alice'].network.channel_db is not None assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = await self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
try: try:
async with timeout_after(0.5): async with timeout_after(0.5):
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=1) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=1)
@ -1104,7 +1101,7 @@ class TestPeer(TestCaseForTestnet):
else: else:
w2.network.config.set_key('test_shutdown_legacy', True) w2.network.config.set_key('test_shutdown_legacy', True)
w2.enable_htlc_settle = False w2.enable_htlc_settle = False
lnaddr, pay_req = run(self.prepare_invoice(w2)) lnaddr, pay_req = self.prepare_invoice(w2)
async def pay(): async def pay():
await asyncio.wait_for(p1.initialized, 1) await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1) await asyncio.wait_for(p2.initialized, 1)
@ -1234,7 +1231,7 @@ class TestPeer(TestCaseForTestnet):
def test_channel_usage_after_closing(self): def test_channel_usage_after_closing(self):
alice_channel, bob_channel = create_test_channels() alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
lnaddr, pay_req = run(self.prepare_invoice(w2)) lnaddr, pay_req = self.prepare_invoice(w2)
lnaddr = w1._check_invoice(pay_req) lnaddr = w1._check_invoice(pay_req)
route, amount_msat = run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] route, amount_msat = run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]

Loading…
Cancel
Save