Browse Source

lnrouter: fix self-payments

patch-4
bitromortac 4 years ago
parent
commit
e6ccbcf7b7
No known key found for this signature in database GPG Key ID: 1965063FC13BEBE2
  1. 2
      electrum/channel_db.py
  2. 47
      electrum/lnrouter.py

2
electrum/channel_db.py

@ -826,7 +826,7 @@ class ChannelDB(SqlDB):
*, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Set[bytes]: ) -> Set[ShortChannelID]:
"""Returns the set of short channel IDs where node_id is one of the channel participants.""" """Returns the set of short channel IDs where node_id is one of the channel participants."""
if not self.data_loaded.is_set(): if not self.data_loaded.is_set():
raise Exception("channelDB data not loaded yet!") raise Exception("channelDB data not loaded yet!")

47
electrum/lnrouter.py

@ -25,13 +25,13 @@
import queue import queue
from collections import defaultdict from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set from typing import Sequence, Tuple, Optional, Dict, TYPE_CHECKING, Set
import time import time
from threading import RLock from threading import RLock
import attr import attr
from math import inf from math import inf
from .util import bh2u, profiler, with_lock from .util import profiler, with_lock, bh2u
from .logging import Logger from .logging import Logger
from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures, from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE) NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE)
@ -511,7 +511,7 @@ class LNPathFinder(Logger):
overall_cost = fee_msat + cltv_cost + liquidity_penalty overall_cost = fee_msat + cltv_cost + liquidity_penalty
return overall_cost, fee_msat return overall_cost, fee_msat
def get_distances( def get_shortest_path_hops(
self, self,
*, *,
nodeA: bytes, nodeA: bytes,
@ -529,22 +529,34 @@ class LNPathFinder(Logger):
blacklist = self.liquidity_hints.get_blacklist() blacklist = self.liquidity_hints.get_blacklist()
distance_from_start = defaultdict(lambda: float('inf')) distance_from_start = defaultdict(lambda: float('inf'))
distance_from_start[nodeB] = 0 distance_from_start[nodeB] = 0
prev_node = {} # type: Dict[bytes, PathEdge] previous_hops = {} # type: Dict[bytes, PathEdge]
nodes_to_explore = queue.PriorityQueue() nodes_to_explore = queue.PriorityQueue()
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters! nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
# main loop of search # main loop of search
while nodes_to_explore.qsize() > 0: while nodes_to_explore.qsize() > 0:
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get() dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
if edge_endnode == nodeA: if edge_endnode == nodeA and previous_hops: # previous_hops check for circular paths
self.logger.info("found a path")
break break
if dist_to_edge_endnode != distance_from_start[edge_endnode]: if dist_to_edge_endnode != distance_from_start[edge_endnode]:
# queue.PriorityQueue does not implement decrease_priority, # queue.PriorityQueue does not implement decrease_priority,
# so instead of decreasing priorities, we add items again into the queue. # so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now: # so there are duplicates in the queue, that we discard now:
continue continue
for edge_channel_id in self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges): if nodeA == nodeB: # we want circular paths
if not previous_hops: # in the first node exploration step, we only take receiving channels
channels_for_endnode = self.channel_db.get_channels_for_node(
edge_endnode, my_channels={}, private_route_edges=private_route_edges)
else: # in the next steps, we only take sending channels
channels_for_endnode = self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges={})
else:
channels_for_endnode = self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges)
for edge_channel_id in channels_for_endnode:
assert isinstance(edge_channel_id, bytes) assert isinstance(edge_channel_id, bytes)
if blacklist and edge_channel_id in blacklist: if blacklist and edge_channel_id in blacklist:
continue continue
@ -558,10 +570,6 @@ class LNPathFinder(Logger):
if edge_startnode == nodeA: # payment outgoing, on our channel if edge_startnode == nodeA: # payment outgoing, on our channel
if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True): if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True):
continue continue
else: # payment incoming, on our channel. (funny business, cycle weirdness)
assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
continue
edge_cost, fee_for_edge_msat = self._edge_cost( edge_cost, fee_for_edge_msat = self._edge_cost(
short_channel_id=edge_channel_id, short_channel_id=edge_channel_id,
start_node=edge_startnode, start_node=edge_startnode,
@ -574,14 +582,17 @@ class LNPathFinder(Logger):
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]: if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour distance_from_start[edge_startnode] = alt_dist_to_neighbour
prev_node[edge_startnode] = PathEdge( previous_hops[edge_startnode] = PathEdge(
start_node=edge_startnode, start_node=edge_startnode,
end_node=edge_endnode, end_node=edge_endnode,
short_channel_id=ShortChannelID(edge_channel_id)) short_channel_id=ShortChannelID(edge_channel_id))
amount_to_forward_msat = amount_msat + fee_for_edge_msat amount_to_forward_msat = amount_msat + fee_for_edge_msat
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
# for circular paths, we already explored the end node, but this
return prev_node # is also our start node, so set it to unexplored
if edge_endnode == nodeB and nodeA == nodeB:
distance_from_start[edge_endnode] = float('inf')
return previous_hops
@profiler @profiler
def find_path_for_payment( def find_path_for_payment(
@ -600,22 +611,22 @@ class LNPathFinder(Logger):
if my_channels is None: if my_channels is None:
my_channels = {} my_channels = {}
prev_node = self.get_distances( previous_hops = self.get_shortest_path_hops(
nodeA=nodeA, nodeA=nodeA,
nodeB=nodeB, nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat, invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels, my_channels=my_channels,
private_route_edges=private_route_edges) private_route_edges=private_route_edges)
if nodeA not in prev_node: if nodeA not in previous_hops:
return None # no path found return None # no path found
# backtrack from search_end (nodeA) to search_start (nodeB) # backtrack from search_end (nodeA) to search_start (nodeB)
# FIXME paths cannot be longer than 20 edges (onion packet)... # FIXME paths cannot be longer than 20 edges (onion packet)...
edge_startnode = nodeA edge_startnode = nodeA
path = [] path = []
while edge_startnode != nodeB: while edge_startnode != nodeB or not path: # second condition for circular paths
edge = prev_node[edge_startnode] edge = previous_hops[edge_startnode]
path += [edge] path += [edge]
edge_startnode = edge.node_id edge_startnode = edge.node_id
return path return path

Loading…
Cancel
Save