diff --git a/electrum/channel_db.py b/electrum/channel_db.py index 65809c596..8eeee39a6 100644 --- a/electrum/channel_db.py +++ b/electrum/channel_db.py @@ -826,7 +826,7 @@ class ChannelDB(SqlDB): *, my_channels: Dict[ShortChannelID, 'Channel'] = None, private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, - ) -> Set[bytes]: + ) -> Set[ShortChannelID]: """Returns the set of short channel IDs where node_id is one of the channel participants.""" if not self.data_loaded.is_set(): raise Exception("channelDB data not loaded yet!") diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 22d6beed3..d9e39efed 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -25,13 +25,13 @@ import queue from collections import defaultdict -from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set +from typing import Sequence, Tuple, Optional, Dict, TYPE_CHECKING, Set import time from threading import RLock import attr from math import inf -from .util import bh2u, profiler, with_lock +from .util import profiler, with_lock, bh2u from .logging import Logger from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures, NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE) @@ -511,7 +511,7 @@ class LNPathFinder(Logger): overall_cost = fee_msat + cltv_cost + liquidity_penalty return overall_cost, fee_msat - def get_distances( + def get_shortest_path_hops( self, *, nodeA: bytes, @@ -529,22 +529,34 @@ class LNPathFinder(Logger): blacklist = self.liquidity_hints.get_blacklist() distance_from_start = defaultdict(lambda: float('inf')) distance_from_start[nodeB] = 0 - prev_node = {} # type: Dict[bytes, PathEdge] + previous_hops = {} # type: Dict[bytes, PathEdge] nodes_to_explore = queue.PriorityQueue() nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters! # main loop of search while nodes_to_explore.qsize() > 0: dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get() - if edge_endnode == nodeA: + if edge_endnode == nodeA and previous_hops: # previous_hops check for circular paths + self.logger.info("found a path") break if dist_to_edge_endnode != distance_from_start[edge_endnode]: # queue.PriorityQueue does not implement decrease_priority, # so instead of decreasing priorities, we add items again into the queue. # so there are duplicates in the queue, that we discard now: continue - for edge_channel_id in self.channel_db.get_channels_for_node( - edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges): + + if nodeA == nodeB: # we want circular paths + if not previous_hops: # in the first node exploration step, we only take receiving channels + channels_for_endnode = self.channel_db.get_channels_for_node( + edge_endnode, my_channels={}, private_route_edges=private_route_edges) + else: # in the next steps, we only take sending channels + channels_for_endnode = self.channel_db.get_channels_for_node( + edge_endnode, my_channels=my_channels, private_route_edges={}) + else: + channels_for_endnode = self.channel_db.get_channels_for_node( + edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges) + + for edge_channel_id in channels_for_endnode: assert isinstance(edge_channel_id, bytes) if blacklist and edge_channel_id in blacklist: continue @@ -558,10 +570,6 @@ class LNPathFinder(Logger): if edge_startnode == nodeA: # payment outgoing, on our channel if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True): continue - else: # payment incoming, on our channel. (funny business, cycle weirdness) - assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode)) - if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True): - continue edge_cost, fee_for_edge_msat = self._edge_cost( short_channel_id=edge_channel_id, start_node=edge_startnode, @@ -574,14 +582,17 @@ class LNPathFinder(Logger): alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost if alt_dist_to_neighbour < distance_from_start[edge_startnode]: distance_from_start[edge_startnode] = alt_dist_to_neighbour - prev_node[edge_startnode] = PathEdge( + previous_hops[edge_startnode] = PathEdge( start_node=edge_startnode, end_node=edge_endnode, short_channel_id=ShortChannelID(edge_channel_id)) amount_to_forward_msat = amount_msat + fee_for_edge_msat nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) - - return prev_node + # for circular paths, we already explored the end node, but this + # is also our start node, so set it to unexplored + if edge_endnode == nodeB and nodeA == nodeB: + distance_from_start[edge_endnode] = float('inf') + return previous_hops @profiler def find_path_for_payment( @@ -600,22 +611,22 @@ class LNPathFinder(Logger): if my_channels is None: my_channels = {} - prev_node = self.get_distances( + previous_hops = self.get_shortest_path_hops( nodeA=nodeA, nodeB=nodeB, invoice_amount_msat=invoice_amount_msat, my_channels=my_channels, private_route_edges=private_route_edges) - if nodeA not in prev_node: + if nodeA not in previous_hops: return None # no path found # backtrack from search_end (nodeA) to search_start (nodeB) # FIXME paths cannot be longer than 20 edges (onion packet)... edge_startnode = nodeA path = [] - while edge_startnode != nodeB: - edge = prev_node[edge_startnode] + while edge_startnode != nodeB or not path: # second condition for circular paths + edge = previous_hops[edge_startnode] path += [edge] edge_startnode = edge.node_id return path