|
@ -25,13 +25,13 @@ |
|
|
|
|
|
|
|
|
import queue |
|
|
import queue |
|
|
from collections import defaultdict |
|
|
from collections import defaultdict |
|
|
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set |
|
|
from typing import Sequence, Tuple, Optional, Dict, TYPE_CHECKING, Set |
|
|
import time |
|
|
import time |
|
|
from threading import RLock |
|
|
from threading import RLock |
|
|
import attr |
|
|
import attr |
|
|
from math import inf |
|
|
from math import inf |
|
|
|
|
|
|
|
|
from .util import bh2u, profiler, with_lock |
|
|
from .util import profiler, with_lock, bh2u |
|
|
from .logging import Logger |
|
|
from .logging import Logger |
|
|
from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures, |
|
|
from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures, |
|
|
NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE) |
|
|
NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE) |
|
@ -511,7 +511,7 @@ class LNPathFinder(Logger): |
|
|
overall_cost = fee_msat + cltv_cost + liquidity_penalty |
|
|
overall_cost = fee_msat + cltv_cost + liquidity_penalty |
|
|
return overall_cost, fee_msat |
|
|
return overall_cost, fee_msat |
|
|
|
|
|
|
|
|
def get_distances( |
|
|
def get_shortest_path_hops( |
|
|
self, |
|
|
self, |
|
|
*, |
|
|
*, |
|
|
nodeA: bytes, |
|
|
nodeA: bytes, |
|
@ -529,22 +529,34 @@ class LNPathFinder(Logger): |
|
|
blacklist = self.liquidity_hints.get_blacklist() |
|
|
blacklist = self.liquidity_hints.get_blacklist() |
|
|
distance_from_start = defaultdict(lambda: float('inf')) |
|
|
distance_from_start = defaultdict(lambda: float('inf')) |
|
|
distance_from_start[nodeB] = 0 |
|
|
distance_from_start[nodeB] = 0 |
|
|
prev_node = {} # type: Dict[bytes, PathEdge] |
|
|
previous_hops = {} # type: Dict[bytes, PathEdge] |
|
|
nodes_to_explore = queue.PriorityQueue() |
|
|
nodes_to_explore = queue.PriorityQueue() |
|
|
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters! |
|
|
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters! |
|
|
|
|
|
|
|
|
# main loop of search |
|
|
# main loop of search |
|
|
while nodes_to_explore.qsize() > 0: |
|
|
while nodes_to_explore.qsize() > 0: |
|
|
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get() |
|
|
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get() |
|
|
if edge_endnode == nodeA: |
|
|
if edge_endnode == nodeA and previous_hops: # previous_hops check for circular paths |
|
|
|
|
|
self.logger.info("found a path") |
|
|
break |
|
|
break |
|
|
if dist_to_edge_endnode != distance_from_start[edge_endnode]: |
|
|
if dist_to_edge_endnode != distance_from_start[edge_endnode]: |
|
|
# queue.PriorityQueue does not implement decrease_priority, |
|
|
# queue.PriorityQueue does not implement decrease_priority, |
|
|
# so instead of decreasing priorities, we add items again into the queue. |
|
|
# so instead of decreasing priorities, we add items again into the queue. |
|
|
# so there are duplicates in the queue, that we discard now: |
|
|
# so there are duplicates in the queue, that we discard now: |
|
|
continue |
|
|
continue |
|
|
for edge_channel_id in self.channel_db.get_channels_for_node( |
|
|
|
|
|
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges): |
|
|
if nodeA == nodeB: # we want circular paths |
|
|
|
|
|
if not previous_hops: # in the first node exploration step, we only take receiving channels |
|
|
|
|
|
channels_for_endnode = self.channel_db.get_channels_for_node( |
|
|
|
|
|
edge_endnode, my_channels={}, private_route_edges=private_route_edges) |
|
|
|
|
|
else: # in the next steps, we only take sending channels |
|
|
|
|
|
channels_for_endnode = self.channel_db.get_channels_for_node( |
|
|
|
|
|
edge_endnode, my_channels=my_channels, private_route_edges={}) |
|
|
|
|
|
else: |
|
|
|
|
|
channels_for_endnode = self.channel_db.get_channels_for_node( |
|
|
|
|
|
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges) |
|
|
|
|
|
|
|
|
|
|
|
for edge_channel_id in channels_for_endnode: |
|
|
assert isinstance(edge_channel_id, bytes) |
|
|
assert isinstance(edge_channel_id, bytes) |
|
|
if blacklist and edge_channel_id in blacklist: |
|
|
if blacklist and edge_channel_id in blacklist: |
|
|
continue |
|
|
continue |
|
@ -558,10 +570,6 @@ class LNPathFinder(Logger): |
|
|
if edge_startnode == nodeA: # payment outgoing, on our channel |
|
|
if edge_startnode == nodeA: # payment outgoing, on our channel |
|
|
if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True): |
|
|
if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True): |
|
|
continue |
|
|
continue |
|
|
else: # payment incoming, on our channel. (funny business, cycle weirdness) |
|
|
|
|
|
assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode)) |
|
|
|
|
|
if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True): |
|
|
|
|
|
continue |
|
|
|
|
|
edge_cost, fee_for_edge_msat = self._edge_cost( |
|
|
edge_cost, fee_for_edge_msat = self._edge_cost( |
|
|
short_channel_id=edge_channel_id, |
|
|
short_channel_id=edge_channel_id, |
|
|
start_node=edge_startnode, |
|
|
start_node=edge_startnode, |
|
@ -574,14 +582,17 @@ class LNPathFinder(Logger): |
|
|
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost |
|
|
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost |
|
|
if alt_dist_to_neighbour < distance_from_start[edge_startnode]: |
|
|
if alt_dist_to_neighbour < distance_from_start[edge_startnode]: |
|
|
distance_from_start[edge_startnode] = alt_dist_to_neighbour |
|
|
distance_from_start[edge_startnode] = alt_dist_to_neighbour |
|
|
prev_node[edge_startnode] = PathEdge( |
|
|
previous_hops[edge_startnode] = PathEdge( |
|
|
start_node=edge_startnode, |
|
|
start_node=edge_startnode, |
|
|
end_node=edge_endnode, |
|
|
end_node=edge_endnode, |
|
|
short_channel_id=ShortChannelID(edge_channel_id)) |
|
|
short_channel_id=ShortChannelID(edge_channel_id)) |
|
|
amount_to_forward_msat = amount_msat + fee_for_edge_msat |
|
|
amount_to_forward_msat = amount_msat + fee_for_edge_msat |
|
|
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) |
|
|
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) |
|
|
|
|
|
# for circular paths, we already explored the end node, but this |
|
|
return prev_node |
|
|
# is also our start node, so set it to unexplored |
|
|
|
|
|
if edge_endnode == nodeB and nodeA == nodeB: |
|
|
|
|
|
distance_from_start[edge_endnode] = float('inf') |
|
|
|
|
|
return previous_hops |
|
|
|
|
|
|
|
|
@profiler |
|
|
@profiler |
|
|
def find_path_for_payment( |
|
|
def find_path_for_payment( |
|
@ -600,22 +611,22 @@ class LNPathFinder(Logger): |
|
|
if my_channels is None: |
|
|
if my_channels is None: |
|
|
my_channels = {} |
|
|
my_channels = {} |
|
|
|
|
|
|
|
|
prev_node = self.get_distances( |
|
|
previous_hops = self.get_shortest_path_hops( |
|
|
nodeA=nodeA, |
|
|
nodeA=nodeA, |
|
|
nodeB=nodeB, |
|
|
nodeB=nodeB, |
|
|
invoice_amount_msat=invoice_amount_msat, |
|
|
invoice_amount_msat=invoice_amount_msat, |
|
|
my_channels=my_channels, |
|
|
my_channels=my_channels, |
|
|
private_route_edges=private_route_edges) |
|
|
private_route_edges=private_route_edges) |
|
|
|
|
|
|
|
|
if nodeA not in prev_node: |
|
|
if nodeA not in previous_hops: |
|
|
return None # no path found |
|
|
return None # no path found |
|
|
|
|
|
|
|
|
# backtrack from search_end (nodeA) to search_start (nodeB) |
|
|
# backtrack from search_end (nodeA) to search_start (nodeB) |
|
|
# FIXME paths cannot be longer than 20 edges (onion packet)... |
|
|
# FIXME paths cannot be longer than 20 edges (onion packet)... |
|
|
edge_startnode = nodeA |
|
|
edge_startnode = nodeA |
|
|
path = [] |
|
|
path = [] |
|
|
while edge_startnode != nodeB: |
|
|
while edge_startnode != nodeB or not path: # second condition for circular paths |
|
|
edge = prev_node[edge_startnode] |
|
|
edge = previous_hops[edge_startnode] |
|
|
path += [edge] |
|
|
path += [edge] |
|
|
edge_startnode = edge.node_id |
|
|
edge_startnode = edge.node_id |
|
|
return path |
|
|
return path |
|
|