@ -55,10 +55,14 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
@attr . s ( slots = True )
class PathEdge :
""" if you travel through short_channel_id, you will reach node_id """
node_id = attr . ib ( type = bytes , kw_only = True , repr = lambda val : val . hex ( ) )
start_node = attr . ib ( type = bytes , kw_only = True , repr = lambda val : val . hex ( ) )
end_ node = attr . ib ( type = bytes , kw_only = True , repr = lambda val : val . hex ( ) )
short_channel_id = attr . ib ( type = ShortChannelID , kw_only = True , repr = lambda val : str ( val ) )
@property
def node_id ( self ) - > bytes :
# legacy compat # TODO rm
return self . end_node
@attr . s
class RouteEdge ( PathEdge ) :
@ -73,12 +77,21 @@ class RouteEdge(PathEdge):
fee_proportional_millionths = self . fee_proportional_millionths )
@classmethod
def from_channel_policy ( cls , channel_policy : ' Policy ' ,
short_channel_id : bytes , end_node : bytes , * ,
node_info : Optional [ NodeInfo ] ) - > ' RouteEdge ' :
def from_channel_policy (
cls ,
* ,
channel_policy : ' Policy ' ,
short_channel_id : bytes ,
start_node : bytes ,
end_node : bytes ,
node_info : Optional [ NodeInfo ] , # for end_node
) - > ' RouteEdge ' :
assert isinstance ( short_channel_id , bytes )
assert type ( start_node ) is bytes
assert type ( end_node ) is bytes
return RouteEdge ( node_id = end_node ,
return RouteEdge (
start_node = start_node ,
end_node = end_node ,
short_channel_id = ShortChannelID . normalize ( short_channel_id ) ,
fee_base_msat = channel_policy . fee_base_msat ,
fee_proportional_millionths = channel_policy . fee_proportional_millionths ,
@ -155,21 +168,37 @@ class LNPathFinder(Logger):
Logger . __init__ ( self )
self . channel_db = channel_db
def _edge_cost ( self , short_channel_id : bytes , start_node : bytes , end_node : bytes ,
payment_amt_msat : int , ignore_costs = False , is_mine = False , * ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ) - > Tuple [ float , int ] :
def _edge_cost (
self ,
* ,
short_channel_id : bytes ,
start_node : bytes ,
end_node : bytes ,
payment_amt_msat : int ,
ignore_costs = False ,
is_mine = False ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ,
private_route_edges : Dict [ ShortChannelID , RouteEdge ] = None ,
) - > Tuple [ float , int ] :
""" Heuristic cost (distance metric) of going through a channel.
Returns ( heuristic_cost , fee_for_edge_msat ) .
"""
channel_info = self . channel_db . get_channel_info ( short_channel_id , my_channels = my_channels )
if private_route_edges is None :
private_route_edges = { }
channel_info = self . channel_db . get_channel_info (
short_channel_id , my_channels = my_channels , private_route_edges = private_route_edges )
if channel_info is None :
return float ( ' inf ' ) , 0
channel_policy = self . channel_db . get_policy_for_node ( short_channel_id , start_node , my_channels = my_channels )
channel_policy = self . channel_db . get_policy_for_node (
short_channel_id , start_node , my_channels = my_channels , private_route_edges = private_route_edges )
if channel_policy is None :
return float ( ' inf ' ) , 0
# channels that did not publish both policies often return temporary channel failure
if self . channel_db . get_policy_for_node ( short_channel_id , end_node , my_channels = my_channels ) is None \
and not is_mine :
channel_policy_backwards = self . channel_db . get_policy_for_node (
short_channel_id , end_node , my_channels = my_channels , private_route_edges = private_route_edges )
if ( channel_policy_backwards is None
and not is_mine
and short_channel_id not in private_route_edges ) :
return float ( ' inf ' ) , 0
if channel_policy . is_disabled ( ) :
return float ( ' inf ' ) , 0
@ -181,8 +210,14 @@ class LNPathFinder(Logger):
if channel_policy . htlc_maximum_msat is not None and \
payment_amt_msat > channel_policy . htlc_maximum_msat :
return float ( ' inf ' ) , 0 # payment amount too large
route_edge = private_route_edges . get ( short_channel_id , None )
if route_edge is None :
node_info = self . channel_db . get_node_info_for_node_id ( node_id = end_node )
route_edge = RouteEdge . from_channel_policy ( channel_policy , short_channel_id , end_node ,
route_edge = RouteEdge . from_channel_policy (
channel_policy = channel_policy ,
short_channel_id = short_channel_id ,
start_node = start_node ,
end_node = end_node ,
node_info = node_info )
if not route_edge . is_sane_to_use ( payment_amt_msat ) :
return float ( ' inf ' ) , 0 # thanks but no thanks
@ -201,9 +236,16 @@ class LNPathFinder(Logger):
overall_cost = base_cost + fee_msat + cltv_cost
return overall_cost , fee_msat
def get_distances ( self , nodeA : bytes , nodeB : bytes , invoice_amount_msat : int , * ,
def get_distances (
self ,
* ,
nodeA : bytes ,
nodeB : bytes ,
invoice_amount_msat : int ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ,
blacklist : Set [ ShortChannelID ] = None ) - > Dict [ bytes , PathEdge ] :
blacklist : Set [ ShortChannelID ] = None ,
private_route_edges : Dict [ ShortChannelID , RouteEdge ] = None ,
) - > Dict [ bytes , PathEdge ] :
# note: we don't lock self.channel_db, so while the path finding runs,
# the underlying graph could potentially change... (not good but maybe ~OK?)
@ -226,11 +268,13 @@ class LNPathFinder(Logger):
# so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self . channel_db . get_channels_for_node ( edge_endnode , my_channels = my_channels ) :
for edge_channel_id in self . channel_db . get_channels_for_node (
edge_endnode , my_channels = my_channels , private_route_edges = private_route_edges ) :
assert isinstance ( edge_channel_id , bytes )
if blacklist and edge_channel_id in blacklist :
continue
channel_info = self . channel_db . get_channel_info ( edge_channel_id , my_channels = my_channels )
channel_info = self . channel_db . get_channel_info (
edge_channel_id , my_channels = my_channels , private_route_edges = private_route_edges )
edge_startnode = channel_info . node2_id if channel_info . node1_id == edge_endnode else channel_info . node1_id
is_mine = edge_channel_id in my_channels
if is_mine :
@ -242,17 +286,20 @@ class LNPathFinder(Logger):
if not my_channels [ edge_channel_id ] . can_receive ( amount_msat , check_frozen = True ) :
continue
edge_cost , fee_for_edge_msat = self . _edge_cost (
edge_channel_id ,
short_channel_id = edge_channel_id ,
start_node = edge_startnode ,
end_node = edge_endnode ,
payment_amt_msat = amount_msat ,
ignore_costs = ( edge_startnode == nodeA ) ,
is_mine = is_mine ,
my_channels = my_channels )
my_channels = my_channels ,
private_route_edges = private_route_edges )
alt_dist_to_neighbour = distance_from_start [ edge_endnode ] + edge_cost
if alt_dist_to_neighbour < distance_from_start [ edge_startnode ] :
distance_from_start [ edge_startnode ] = alt_dist_to_neighbour
prev_node [ edge_startnode ] = PathEdge ( node_id = edge_endnode ,
prev_node [ edge_startnode ] = PathEdge (
start_node = edge_startnode ,
end_node = edge_endnode ,
short_channel_id = ShortChannelID ( edge_channel_id ) )
amount_to_forward_msat = amount_msat + fee_for_edge_msat
nodes_to_explore . put ( ( alt_dist_to_neighbour , amount_to_forward_msat , edge_startnode ) )
@ -260,11 +307,16 @@ class LNPathFinder(Logger):
return prev_node
@profiler
def find_path_for_payment ( self , nodeA : bytes , nodeB : bytes ,
invoice_amount_msat : int , * ,
def find_path_for_payment (
self ,
* ,
nodeA : bytes ,
nodeB : bytes ,
invoice_amount_msat : int ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ,
blacklist : Set [ ShortChannelID ] = None ) \
- > Optional [ LNPaymentPath ] :
blacklist : Set [ ShortChannelID ] = None ,
private_route_edges : Dict [ ShortChannelID , RouteEdge ] = None ,
) - > Optional [ LNPaymentPath ] :
""" Return a path from nodeA to nodeB. """
assert type ( nodeA ) is bytes
assert type ( nodeB ) is bytes
@ -272,7 +324,13 @@ class LNPathFinder(Logger):
if my_channels is None :
my_channels = { }
prev_node = self . get_distances ( nodeA , nodeB , invoice_amount_msat , my_channels = my_channels , blacklist = blacklist )
prev_node = self . get_distances (
nodeA = nodeA ,
nodeB = nodeB ,
invoice_amount_msat = invoice_amount_msat ,
my_channels = my_channels ,
blacklist = blacklist ,
private_route_edges = private_route_edges )
if nodeA not in prev_node :
return None # no path found
@ -287,34 +345,66 @@ class LNPathFinder(Logger):
edge_startnode = edge . node_id
return path
def create_route_from_path ( self , path : Optional [ LNPaymentPath ] , from_node_id : bytes , * ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ) - > LNPaymentRoute :
assert isinstance ( from_node_id , bytes )
def create_route_from_path (
self ,
path : Optional [ LNPaymentPath ] ,
* ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ,
private_route_edges : Dict [ ShortChannelID , RouteEdge ] = None ,
) - > LNPaymentRoute :
if path is None :
raise Exception ( ' cannot create route from None path ' )
if private_route_edges is None :
private_route_edges = { }
route = [ ]
prev_node_id = from_node_id
for edge in path :
node_id = edge . node_id
short_channel_id = edge . short_channel_id
prev_end_node = path [ 0 ] . start_node
for path_edge in path :
short_channel_id = path_edge . short_channel_id
_endnodes = self . channel_db . get_endnodes_for_chan ( short_channel_id , my_channels = my_channels )
if _endnodes and sorted ( _endnodes ) != sorted ( [ prev_node_id , node_id ] ) :
if _endnodes and sorted ( _endnodes ) != sorted ( [ path_edge . start_node , path_edge . end_node ] ) :
raise LNPathInconsistent ( " endpoints of edge inconsistent with short_channel_id " )
if path_edge . start_node != prev_end_node :
raise LNPathInconsistent ( " edges do not chain together " )
channel_policy = self . channel_db . get_policy_for_node ( short_channel_id = short_channel_id ,
node_id = prev_node_id ,
route_edge = private_route_edges . get ( short_channel_id , None )
if route_edge is None :
channel_policy = self . channel_db . get_policy_for_node (
short_channel_id = short_channel_id ,
node_id = path_edge . start_node ,
my_channels = my_channels )
if channel_policy is None :
raise NoChannelPolicy ( short_channel_id )
node_info = self . channel_db . get_node_info_for_node_id ( node_id = node_id )
route . append ( RouteEdge . from_channel_policy ( channel_policy , short_channel_id , node_id ,
node_info = node_info ) )
prev_node_id = node_id
node_info = self . channel_db . get_node_info_for_node_id ( node_id = path_edge . end_node )
route_edge = RouteEdge . from_channel_policy (
channel_policy = channel_policy ,
short_channel_id = short_channel_id ,
start_node = path_edge . start_node ,
end_node = path_edge . end_node ,
node_info = node_info )
route . append ( route_edge )
prev_end_node = path_edge . end_node
return route
def find_route ( self , nodeA : bytes , nodeB : bytes , invoice_amount_msat : int , * ,
path = None , my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ,
blacklist : Set [ ShortChannelID ] = None ) - > Optional [ LNPaymentRoute ] :
def find_route (
self ,
* ,
nodeA : bytes ,
nodeB : bytes ,
invoice_amount_msat : int ,
path = None ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ,
blacklist : Set [ ShortChannelID ] = None ,
private_route_edges : Dict [ ShortChannelID , RouteEdge ] = None ,
) - > Optional [ LNPaymentRoute ] :
route = None
if not path :
path = self . find_path_for_payment ( nodeA , nodeB , invoice_amount_msat , my_channels = my_channels , blacklist = blacklist )
path = self . find_path_for_payment (
nodeA = nodeA ,
nodeB = nodeB ,
invoice_amount_msat = invoice_amount_msat ,
my_channels = my_channels ,
blacklist = blacklist ,
private_route_edges = private_route_edges )
if path :
return self . create_route_from_path ( path , nodeA , my_channels = my_channels )
route = self . create_route_from_path (
path , my_channels = my_channels , private_route_edges = private_route_edges )
return route