Browse Source

Merge pull request #7292 from bitromortac/2105-inflight-htlcs

lnrouter: add inflight htlcs to liquidity hints
patch-4
ghost43 4 years ago
committed by GitHub
parent
commit
8abbcbff5a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 53
      electrum/lnrouter.py
  2. 102
      electrum/lnworker.py
  3. 39
      electrum/tests/test_lnpeer.py
  4. 58
      electrum/tests/test_lnrouter.py

53
electrum/lnrouter.py

@ -183,6 +183,8 @@ class LiquidityHint:
self._cannot_send_backward = None
self.blacklist_timestamp = 0
self.hint_timestamp = 0
self._inflight_htlcs_forward = 0
self._inflight_htlcs_backward = 0
def is_hint_invalid(self) -> bool:
now = int(time.time())
@ -273,10 +275,28 @@ class LiquidityHint:
else:
self.cannot_send_backward = amount
def num_inflight_htlcs(self, is_forward_direction: bool) -> int:
if is_forward_direction:
return self._inflight_htlcs_forward
else:
return self._inflight_htlcs_backward
def add_htlc(self, is_forward_direction: bool):
if is_forward_direction:
self._inflight_htlcs_forward += 1
else:
self._inflight_htlcs_backward += 1
def remove_htlc(self, is_forward_direction: bool):
if is_forward_direction:
self._inflight_htlcs_forward = max(0, self._inflight_htlcs_forward - 1)
else:
self._inflight_htlcs_backward = max(0, self._inflight_htlcs_forward - 1)
def __repr__(self):
is_blacklisted = False if not self.blacklist_timestamp else int(time.time()) - self.blacklist_timestamp < BLACKLIST_DURATION
return f"forward: can send: {self._can_send_forward} msat, cannot send: {self._cannot_send_forward} msat, \n" \
f"backward: can send: {self._can_send_backward} msat, cannot send: {self._cannot_send_backward} msat, \n" \
return f"forward: can send: {self._can_send_forward} msat, cannot send: {self._cannot_send_forward} msat, htlcs: {self._inflight_htlcs_forward}\n" \
f"backward: can send: {self._can_send_backward} msat, cannot send: {self._cannot_send_backward} msat, htlcs: {self._inflight_htlcs_backward}\n" \
f"blacklisted: {is_blacklisted}"
@ -288,15 +308,13 @@ class LiquidityHintMgr:
algorithm that favors channels which can route payments and penalizes
channels that cannot.
"""
# TODO: incorporate in-flight htlcs
# TODO: use timestamps for can/not_send to make them None after some time?
# TODO: hints based on node pairs only (shadow channels, non-strict forwarding)?
def __init__(self):
self.lock = RLock()
self._liquidity_hints: Dict[ShortChannelID, LiquidityHint] = {}
@with_lock
def get_hint(self, channel_id: ShortChannelID):
def get_hint(self, channel_id: ShortChannelID) -> LiquidityHint:
hint = self._liquidity_hints.get(channel_id)
if not hint:
hint = LiquidityHint()
@ -313,6 +331,16 @@ class LiquidityHintMgr:
hint = self.get_hint(channel_id)
hint.update_cannot_send(node_from < node_to, amount)
@with_lock
def add_htlc(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID):
hint = self.get_hint(channel_id)
hint.add_htlc(node_from < node_to)
@with_lock
def remove_htlc(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID):
hint = self.get_hint(channel_id)
hint.remove_htlc(node_from < node_to)
def penalty(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID, amount: int) -> float:
"""Gives a penalty when sending from node1 to node2 over channel_id with an
amount in units of millisatoshi.
@ -337,16 +365,19 @@ class LiquidityHintMgr:
# we only evaluate hints here, so use dict get (to not create many hints with self.get_hint)
hint = self._liquidity_hints.get(channel_id)
if not hint:
can_send, cannot_send = None, None
can_send, cannot_send, num_inflight_htlcs = None, None, 0
else:
can_send = hint.can_send(node_from < node_to)
cannot_send = hint.cannot_send(node_from < node_to)
num_inflight_htlcs = hint.num_inflight_htlcs(node_from < node_to)
if cannot_send is not None and amount >= cannot_send:
return inf
if can_send is not None and amount <= can_send:
return 0
return fee_for_edge_msat(amount, DEFAULT_PENALTY_BASE_MSAT, DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH)
success_fee = fee_for_edge_msat(amount, DEFAULT_PENALTY_BASE_MSAT, DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH)
inflight_htlc_fee = num_inflight_htlcs * success_fee
return success_fee + inflight_htlc_fee
@with_lock
def add_to_blacklist(self, channel_id: ShortChannelID):
@ -403,6 +434,14 @@ class LNPathFinder(Logger):
self.liquidity_hints.update_cannot_send(r.start_node, r.end_node, r.short_channel_id, amount_msat)
break
def update_inflight_htlcs(self, route: LNPaymentRoute, add_htlcs: bool):
self.logger.info(f"{'Adding' if add_htlcs else 'Removing'} inflight htlcs to graph (liquidity hints).")
for r in route:
if add_htlcs:
self.liquidity_hints.add_htlc(r.start_node, r.end_node, r.short_channel_id)
else:
self.liquidity_hints.remove_htlc(r.start_node, r.end_node, r.short_channel_id)
def _edge_cost(
self,
*,

102
electrum/lnworker.py

@ -8,7 +8,7 @@ from decimal import Decimal
import random
import time
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
NamedTuple, Union, Mapping, Any, Iterable)
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator)
import threading
import socket
import aiohttp
@ -1073,20 +1073,6 @@ class LNWallet(LNWorker):
if chan.short_channel_id == short_channel_id:
return chan
def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
return self.create_routes_for_payment(
amount_msat=amount_msat,
final_total_msat=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
r_tags=decoded_invoice.get_routing_info('r'),
invoice_features=decoded_invoice.get_features(),
trampoline_fee_level=0,
use_two_trampolines=False,
payment_hash=decoded_invoice.paymenthash,
payment_secret=decoded_invoice.payment_secret,
full_path=full_path)
@log_exceptions
async def pay_invoice(
self, invoice: str, *,
@ -1173,8 +1159,7 @@ class LNWallet(LNWorker):
# 1. create a set of routes for remaining amount.
# note: path-finding runs in a separate thread so that we don't block the asyncio loop
# graph updates might occur during the computation
routes = await run_in_thread(partial(
self.create_routes_for_payment,
routes = self.create_routes_for_payment(
amount_msat=amount_to_send,
final_total_msat=amount_to_pay,
invoice_pubkey=node_pubkey,
@ -1186,9 +1171,10 @@ class LNWallet(LNWorker):
payment_secret=payment_secret,
trampoline_fee_level=trampoline_fee_level,
use_two_trampolines=use_two_trampolines,
fwd_trampoline_onion=fwd_trampoline_onion))
fwd_trampoline_onion=fwd_trampoline_onion
)
# 2. send htlcs
for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion in routes:
async for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion in routes:
amount_inflight += amount_receiver_msat
if amount_inflight > amount_to_pay: # safety belts
raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}")
@ -1210,12 +1196,14 @@ class LNWallet(LNWorker):
raise Exception(f"amount_inflight={amount_inflight} < 0")
log.append(htlc_log)
if htlc_log.success:
# TODO: report every route to liquidity hints for mpp
# even in the case of success, we report channels of the
# route as being able to send the same amount in the future,
# as we assume to not know the capacity
if self.network.path_finder:
# TODO: report every route to liquidity hints for mpp
# in the case of success, we report channels of the
# route as being able to send the same amount in the future,
# as we assume to not know the capacity
self.network.path_finder.update_liquidity_hints(htlc_log.route, htlc_log.amount_msat)
# remove inflight htlcs from liquidity hints
self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
return
# htlc failed
if len(log) >= attempts:
@ -1282,6 +1270,9 @@ class LNWallet(LNWorker):
amount_sent, amount_failed = self.sent_buckets[payment_secret]
amount_sent += amount_receiver_msat
self.sent_buckets[payment_secret] = amount_sent, amount_failed
if self.network.path_finder:
# add inflight htlcs to liquidity hints
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True)
util.trigger_callback('htlc_added', chan, htlc, SENT)
def handle_error_code_from_failed_htlc(
@ -1291,6 +1282,13 @@ class LNWallet(LNWorker):
sender_idx: int,
failure_msg: OnionRoutingFailure,
amount: int) -> None:
assert self.channel_db # cannot be in trampoline mode
assert self.network.path_finder
# remove inflight htlcs from liquidity hints
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=False)
code, data = failure_msg.code, failure_msg.data
# TODO can we use lnmsg.OnionWireSerializer here?
# TODO update onion_wire.csv
@ -1432,8 +1430,7 @@ class LNWallet(LNWorker):
else:
return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
@profiler
def create_routes_for_payment(
async def create_routes_for_payment(
self, *,
amount_msat: int, # part of payment amount we want routes for now
final_total_msat: int, # total payment amount final receiver will get
@ -1446,7 +1443,7 @@ class LNWallet(LNWorker):
trampoline_fee_level: int,
use_two_trampolines: bool,
fwd_trampoline_onion = None,
full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]:
full_path: LNPaymentPath = None) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]:
"""Creates multiple routes for splitting a payment over the available
private channels.
@ -1502,20 +1499,24 @@ class LNWallet(LNWorker):
cltv_expiry_delta=0,
node_features=trampoline_features)
]
routes = [(route, amount_with_fees, trampoline_total_msat, amount_msat, cltv_delta, trampoline_payment_secret, trampoline_onion)]
yield route, amount_with_fees, trampoline_total_msat, amount_msat, cltv_delta, trampoline_payment_secret, trampoline_onion
break
else:
raise NoPathFound()
else:
route = self.create_route_for_payment(
amount_msat=amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
invoice_features=invoice_features,
channels=active_channels,
full_path=full_path)
routes = [(route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion)]
route = await run_in_thread(
partial(
self.create_route_for_payment,
amount_msat=amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
invoice_features=invoice_features,
channels=active_channels,
full_path=full_path
)
)
yield route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion
except NoPathFound:
if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
raise
@ -1532,7 +1533,6 @@ class LNWallet(LNWorker):
for s in split_configurations:
self.logger.info(f"trying split configuration: {s[0].values()} rating: {s[1]}")
routes = []
try:
if not self.channel_db:
buckets = defaultdict(list)
@ -1577,7 +1577,7 @@ class LNWallet(LNWorker):
node_features=trampoline_features)
]
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
routes.append((route, part_amount_msat_with_fees, bucket_amount_with_fees, part_amount_msat, bucket_cltv_delta, bucket_payment_secret, trampoline_onion))
yield route, part_amount_msat_with_fees, bucket_amount_with_fees, part_amount_msat, bucket_cltv_delta, bucket_payment_secret, trampoline_onion
if bucket_fees != 0:
self.logger.info('not enough margin to pay trampoline fee')
raise NoPathFound()
@ -1585,23 +1585,27 @@ class LNWallet(LNWorker):
for (chan_id, _), part_amount_msat in s[0].items():
if part_amount_msat:
channel = self.channels[chan_id]
route = self.create_route_for_payment(
amount_msat=part_amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
invoice_features=invoice_features,
channels=[channel],
full_path=None)
routes.append((route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion))
route = await run_in_thread(
partial(
self.create_route_for_payment,
amount_msat=part_amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
invoice_features=invoice_features,
channels=[channel],
full_path=None
)
)
yield route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion
self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}")
break
except NoPathFound:
continue
else:
raise NoPathFound()
return routes
@profiler
def create_route_for_payment(
self, *,
amount_msat: int,
@ -1610,7 +1614,7 @@ class LNWallet(LNWorker):
r_tags,
invoice_features: int,
channels: List[Channel],
full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
full_path: Optional[LNPaymentPath]) -> LNPaymentRoute:
scid_to_my_channels = {
chan.short_channel_id: chan for chan in channels

39
electrum/tests/test_lnpeer.py

@ -192,6 +192,20 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.channel_db.stop()
await self.channel_db.stopped_event.wait()
async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
return [r async for r in self.create_routes_for_payment(
amount_msat=amount_msat,
final_total_msat=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
r_tags=decoded_invoice.get_routing_info('r'),
invoice_features=decoded_invoice.get_features(),
trampoline_fee_level=0,
use_two_trampolines=False,
payment_hash=decoded_invoice.paymenthash,
payment_secret=decoded_invoice.payment_secret,
full_path=full_path)]
get_payments = LNWallet.get_payments
get_payment_info = LNWallet.get_payment_info
save_payment_info = LNWallet.save_payment_info
@ -206,7 +220,6 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
get_preimage = LNWallet.get_preimage
create_route_for_payment = LNWallet.create_route_for_payment
create_routes_for_payment = LNWallet.create_routes_for_payment
create_routes_from_invoice = LNWallet.create_routes_from_invoice
_check_invoice = staticmethod(LNWallet._check_invoice)
pay_to_route = LNWallet.pay_to_route
pay_to_node = LNWallet.pay_to_node
@ -598,7 +611,7 @@ class TestPeer(TestCaseForTestnet):
q2 = w2.sent_htlcs[lnaddr1.paymenthash]
# alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None
route1 = w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2)[0][0]
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0]
amount_msat = lnaddr2.get_amount_msat()
await w1.pay_to_route(
route=route1,
@ -612,7 +625,7 @@ class TestPeer(TestCaseForTestnet):
p1.maybe_send_commitment = _maybe_send_commitment1
# bob sends htlc BUT NOT COMMITMENT_SIGNED
p2.maybe_send_commitment = lambda x: None
route2 = w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1)[0][0]
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0]
amount_msat = lnaddr1.get_amount_msat()
await w2.pay_to_route(
route=route2,
@ -982,14 +995,14 @@ class TestPeer(TestCaseForTestnet):
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
# alice sends htlc
route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2]
htlc = p1.pay(route=route,
chan=alice_channel,
amount_msat=lnaddr.get_amount_msat(),
total_msat=lnaddr.get_amount_msat(),
payment_hash=lnaddr.paymenthash,
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
payment_secret=lnaddr.payment_secret)
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
p1.pay(route=route,
chan=alice_channel,
amount_msat=lnaddr.get_amount_msat(),
total_msat=lnaddr.get_amount_msat(),
payment_hash=lnaddr.paymenthash,
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
payment_secret=lnaddr.payment_secret)
# alice closes
await p1.close_channel(alice_channel.channel_id)
gath.cancel()
@ -1078,7 +1091,7 @@ class TestPeer(TestCaseForTestnet):
lnaddr, pay_req = run(self.prepare_invoice(w2))
lnaddr = w1._check_invoice(pay_req)
route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2]
route, amount_msat = run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
assert amount_msat == lnaddr.get_amount_msat()
run(w1.force_close_channel(alice_channel.channel_id))
@ -1086,7 +1099,7 @@ class TestPeer(TestCaseForTestnet):
assert q1.qsize() == 1
with self.assertRaises(NoPathFound) as e:
w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)
run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))
peer = w1.peers[route[0].node_id]
# AssertionError is ok since we shouldn't use old routes, and the

58
electrum/tests/test_lnrouter.py

@ -28,12 +28,18 @@ def node(character: str) -> bytes:
class Test_LNRouter(TestCaseForTestnet):
cdb = None
def setUp(self):
super().setUp()
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
self.config = SimpleConfig({'electrum_path': self.electrum_path})
def tearDown(self):
# if the test called prepare_graph(), channeldb needs to be cleaned up
if self.cdb:
self.cdb.stop()
asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result()
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
self._loop_thread.join(timeout=1)
super().tearDown()
@ -151,10 +157,7 @@ class Test_LNRouter(TestCaseForTestnet):
self.assertEqual(node('b'), route[0].node_id)
self.assertEqual(channel(3), route[0].short_channel_id)
self.cdb.stop()
asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result()
def test_find_path_liquidity_hints_failure(self):
def test_find_path_liquidity_hints(self):
self.prepare_graph()
amount_to_send = 100000
@ -197,7 +200,7 @@ class Test_LNRouter(TestCaseForTestnet):
assume success over channel 4, D -> C
A -3-> B |-2-> E
A -6-> D |-5-> E
A -6-> D -4-> C -7-> E <= chosen path
A -6-> D -4-> C -7-> E <= smaller penalty: chosen path
A -3-> B -1-> C -7-> E
A -6-> D -4-> C -1-> B |-2-> E
A -3-> B -1-> C -4-> D |-5-> E
@ -211,8 +214,43 @@ class Test_LNRouter(TestCaseForTestnet):
self.assertEqual(channel(4), path[1].short_channel_id)
self.assertEqual(channel(7), path[2].short_channel_id)
self.cdb.stop()
asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result()
def test_find_path_liquidity_hints_inflight_htlcs(self):
self.prepare_graph()
amount_to_send = 100000
"""
add inflight htlc to channel 2, B -> E
A -3-> B -2(1)-> E
A -6-> D -5-> E <= chosen path
A -6-> D -4-> C -7-> E
A -3-> B -1-> C -7-> E
A -6-> D -4-> C -1-> B -2-> E
A -3-> B -1-> C -4-> D -5-> E
"""
self.path_finder.liquidity_hints.add_htlc(node('b'), node('e'), channel(2))
path = self.path_finder.find_path_for_payment(
nodeA=node('a'),
nodeB=node('e'),
invoice_amount_msat=amount_to_send)
self.assertEqual(channel(6), path[0].short_channel_id)
self.assertEqual(channel(5), path[1].short_channel_id)
"""
remove inflight htlc from channel 2, B -> E
A -3-> B -2(0)-> E <= chosen path
A -6-> D -5-> E
A -6-> D -4-> C -7-> E
A -3-> B -1-> C -7-> E
A -6-> D -4-> C -1-> B -2-> E
A -3-> B -1-> C -4-> D -5-> E
"""
self.path_finder.liquidity_hints.remove_htlc(node('b'), node('e'), channel(2))
path = self.path_finder.find_path_for_payment(
nodeA=node('a'),
nodeB=node('e'),
invoice_amount_msat=amount_to_send)
self.assertEqual(channel(3), path[0].short_channel_id)
self.assertEqual(channel(2), path[1].short_channel_id)
def test_liquidity_hints(self):
liquidity_hints = LiquidityHintMgr()
@ -251,6 +289,12 @@ class Test_LNRouter(TestCaseForTestnet):
self.assertEqual(3_000_000, hint.can_send(node_from < node_to))
self.assertEqual(None, hint.cannot_send(node_from < node_to))
# test inflight htlc
liquidity_hints.reset_liquidity_hints()
liquidity_hints.add_htlc(node_from, node_to, channel_id)
liquidity_hints.get_hint(channel_id)
# we have got 600 (attempt) + 600 (inflight) penalty
self.assertEqual(1200, liquidity_hints.penalty(node_from, node_to, channel_id, 1_000_000))
@needs_test_with_all_chacha20_implementations
def test_new_onion_packet_legacy(self):

Loading…
Cancel
Save