Browse Source

lnrouter: run Dijkstra in reverse direction

regtest_lnd
SomberNight 7 years ago
parent
commit
a547d997e9
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 73
      electrum/lnrouter.py
  2. 4
      electrum/lnworker.py
  3. 5
      electrum/tests/test_lnrouter.py

73
electrum/lnrouter.py

@ -565,59 +565,64 @@ class LNPathFinder(PrintError):
self.blacklist = set() self.blacklist = set()
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
payment_amt_msat: int, ignore_cltv=False) -> float: payment_amt_msat: int, ignore_costs=False) -> Tuple[float, int]:
"""Heuristic cost of going through a channel.""" """Heuristic cost of going through a channel.
Returns (heuristic_cost, fee_for_edge_msat).
"""
channel_info = self.channel_db.get_channel_info(short_channel_id) # type: ChannelInfo channel_info = self.channel_db.get_channel_info(short_channel_id) # type: ChannelInfo
if channel_info is None: if channel_info is None:
return float('inf') return float('inf'), 0
channel_policy = channel_info.get_policy_for_node(start_node) channel_policy = channel_info.get_policy_for_node(start_node)
if channel_policy is None: return float('inf') if channel_policy is None: return float('inf'), 0
if channel_policy.disabled: return float('inf') if channel_policy.disabled: return float('inf'), 0
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node) route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
if payment_amt_msat < channel_policy.htlc_minimum_msat: if payment_amt_msat < channel_policy.htlc_minimum_msat:
return float('inf') # payment amount too little return float('inf'), 0 # payment amount too little
if channel_info.capacity_sat is not None and \ if channel_info.capacity_sat is not None and \
payment_amt_msat // 1000 > channel_info.capacity_sat: payment_amt_msat // 1000 > channel_info.capacity_sat:
return float('inf') # payment amount too large return float('inf'), 0 # payment amount too large
if channel_policy.htlc_maximum_msat is not None and \ if channel_policy.htlc_maximum_msat is not None and \
payment_amt_msat > channel_policy.htlc_maximum_msat: payment_amt_msat > channel_policy.htlc_maximum_msat:
return float('inf') # payment amount too large return float('inf'), 0 # payment amount too large
if not route_edge.is_sane_to_use(payment_amt_msat): if not route_edge.is_sane_to_use(payment_amt_msat):
return float('inf') # thanks but no thanks return float('inf'), 0 # thanks but no thanks
fee_msat = route_edge.fee_for_edge(payment_amt_msat) fee_msat = route_edge.fee_for_edge(payment_amt_msat) if not ignore_costs else 0
# TODO revise # TODO revise
# paying 10 more satoshis ~ waiting one more block # paying 10 more satoshis ~ waiting one more block
fee_cost = fee_msat / 1000 / 10 fee_cost = fee_msat / 1000 / 10
cltv_cost = route_edge.cltv_expiry_delta if not ignore_cltv else 0 cltv_cost = route_edge.cltv_expiry_delta if not ignore_costs else 0
return cltv_cost + fee_cost + 1 return cltv_cost + fee_cost + 1, fee_msat
@profiler @profiler
def find_path_for_payment(self, from_node_id: bytes, to_node_id: bytes, def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
amount_msat: int, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]: invoice_amount_msat: int,
"""Return a path between from_node_id and to_node_id. my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
"""Return a path from nodeA to nodeB.
Returns a list of (node_id, short_channel_id) representing a path. Returns a list of (node_id, short_channel_id) representing a path.
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
i.e. an element reads as, "to get to node_id, travel through short_channel_id" i.e. an element reads as, "to get to node_id, travel through short_channel_id"
""" """
assert type(amount_msat) is int assert type(invoice_amount_msat) is int
if my_channels is None: my_channels = [] if my_channels is None: my_channels = []
unable_channels = set(map(lambda x: x.short_channel_id, filter(lambda x: not x.can_pay(amount_msat), my_channels))) unable_channels = set(map(lambda x: x.short_channel_id,
filter(lambda x: not x.can_pay(invoice_amount_msat), my_channels)))
# TODO find multiple paths??
# FIXME paths cannot be longer than 21 edges (onion packet)... # FIXME paths cannot be longer than 21 edges (onion packet)...
# run Dijkstra # run Dijkstra
# The search is run in the REVERSE direction, from nodeB to nodeA,
# to properly calculate compound routing fees.
distance_from_start = defaultdict(lambda: float('inf')) distance_from_start = defaultdict(lambda: float('inf'))
distance_from_start[from_node_id] = 0 distance_from_start[nodeB] = 0
prev_node = {} prev_node = {}
nodes_to_explore = queue.PriorityQueue() nodes_to_explore = queue.PriorityQueue()
nodes_to_explore.put((0, from_node_id)) nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
while nodes_to_explore.qsize() > 0: while nodes_to_explore.qsize() > 0:
dist_to_cur_node, cur_node = nodes_to_explore.get() dist_to_cur_node, amount_msat, cur_node = nodes_to_explore.get()
if cur_node == to_node_id: if cur_node == nodeA:
break break
if dist_to_cur_node != distance_from_start[cur_node]: if dist_to_cur_node != distance_from_start[cur_node]:
# queue.PriorityQueue does not implement decrease_priority, # queue.PriorityQueue does not implement decrease_priority,
@ -628,27 +633,29 @@ class LNPathFinder(PrintError):
if edge_channel_id in self.blacklist or edge_channel_id in unable_channels: if edge_channel_id in self.blacklist or edge_channel_id in unable_channels:
continue continue
channel_info = self.channel_db.get_channel_info(edge_channel_id) channel_info = self.channel_db.get_channel_info(edge_channel_id)
node1, node2 = channel_info.node_id_1, channel_info.node_id_2 neighbour = channel_info.node_id_2 if channel_info.node_id_1 == cur_node else channel_info.node_id_1
neighbour = node2 if node1 == cur_node else node1 ignore_costs = neighbour == nodeA # no fees when using our own channel
ignore_cltv_delta_in_edge_cost = cur_node == from_node_id edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
edge_cost = self._edge_cost(edge_channel_id, cur_node, neighbour, amount_msat, start_node=neighbour,
ignore_cltv=ignore_cltv_delta_in_edge_cost) end_node=cur_node,
payment_amt_msat=amount_msat,
ignore_costs=ignore_costs)
alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost
if alt_dist_to_neighbour < distance_from_start[neighbour]: if alt_dist_to_neighbour < distance_from_start[neighbour]:
distance_from_start[neighbour] = alt_dist_to_neighbour distance_from_start[neighbour] = alt_dist_to_neighbour
prev_node[neighbour] = cur_node, edge_channel_id prev_node[neighbour] = cur_node, edge_channel_id
nodes_to_explore.put((alt_dist_to_neighbour, neighbour)) amount_to_forward_msat = amount_msat + fee_for_edge_msat
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, neighbour))
else: else:
return None # no path found return None # no path found
# backtrack from end to start # backtrack from search_end (nodeA) to search_start (nodeB)
cur_node = to_node_id cur_node = nodeA
path = [] path = []
while cur_node != from_node_id: while cur_node != nodeB:
prev_node_id, edge_taken = prev_node[cur_node] prev_node_id, edge_taken = prev_node[cur_node]
path += [(cur_node, edge_taken)] path += [(prev_node_id, edge_taken)]
cur_node = prev_node_id cur_node = prev_node_id
path.reverse()
return path return path
def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]: def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:

4
electrum/lnworker.py

@ -260,14 +260,14 @@ class LNWorker(PrintError):
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}")) f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat) route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
node_id, short_channel_id = route[0].node_id, route[0].short_channel_id node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
peer = self.peers[node_id]
with self.lock: with self.lock:
channels = list(self.channels.values()) channels = list(self.channels.values())
for chan in channels: for chan in channels:
if chan.short_channel_id == short_channel_id: if chan.short_channel_id == short_channel_id:
break break
else: else:
raise Exception("ChannelDB returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id))) raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
peer = self.peers[node_id]
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry()) coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)

5
electrum/tests/test_lnrouter.py

@ -93,6 +93,11 @@ class Test_LNRouter(TestCaseForTestnet):
cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(99999999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True) cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(99999999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(150), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True) cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(150), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
self.assertNotEqual(None, path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)) self.assertNotEqual(None, path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000))
self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'),
(b'\x02cccccccccccccccccccccccccccccccc', b'\x00\x00\x00\x00\x00\x00\x00\x01'),
(b'\x02dddddddddddddddddddddddddddddddd', b'\x00\x00\x00\x00\x00\x00\x00\x04'),
(b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x05')],
path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000))

Loading…
Cancel
Save