Browse Source

lnrouter: fix self-payments

patch-4
bitromortac 4 years ago
parent
commit
e6ccbcf7b7
No known key found for this signature in database GPG Key ID: 1965063FC13BEBE2
  1. 2
      electrum/channel_db.py
  2. 47
      electrum/lnrouter.py

2
electrum/channel_db.py

@ -826,7 +826,7 @@ class ChannelDB(SqlDB):
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Set[bytes]:
) -> Set[ShortChannelID]:
"""Returns the set of short channel IDs where node_id is one of the channel participants."""
if not self.data_loaded.is_set():
raise Exception("channelDB data not loaded yet!")

47
electrum/lnrouter.py

@ -25,13 +25,13 @@
import queue
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
from typing import Sequence, Tuple, Optional, Dict, TYPE_CHECKING, Set
import time
from threading import RLock
import attr
from math import inf
from .util import bh2u, profiler, with_lock
from .util import profiler, with_lock, bh2u
from .logging import Logger
from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE)
@ -511,7 +511,7 @@ class LNPathFinder(Logger):
overall_cost = fee_msat + cltv_cost + liquidity_penalty
return overall_cost, fee_msat
def get_distances(
def get_shortest_path_hops(
self,
*,
nodeA: bytes,
@ -529,22 +529,34 @@ class LNPathFinder(Logger):
blacklist = self.liquidity_hints.get_blacklist()
distance_from_start = defaultdict(lambda: float('inf'))
distance_from_start[nodeB] = 0
prev_node = {} # type: Dict[bytes, PathEdge]
previous_hops = {} # type: Dict[bytes, PathEdge]
nodes_to_explore = queue.PriorityQueue()
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
# main loop of search
while nodes_to_explore.qsize() > 0:
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
if edge_endnode == nodeA:
if edge_endnode == nodeA and previous_hops: # previous_hops check for circular paths
self.logger.info("found a path")
break
if dist_to_edge_endnode != distance_from_start[edge_endnode]:
# queue.PriorityQueue does not implement decrease_priority,
# so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
if nodeA == nodeB: # we want circular paths
if not previous_hops: # in the first node exploration step, we only take receiving channels
channels_for_endnode = self.channel_db.get_channels_for_node(
edge_endnode, my_channels={}, private_route_edges=private_route_edges)
else: # in the next steps, we only take sending channels
channels_for_endnode = self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges={})
else:
channels_for_endnode = self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges)
for edge_channel_id in channels_for_endnode:
assert isinstance(edge_channel_id, bytes)
if blacklist and edge_channel_id in blacklist:
continue
@ -558,10 +570,6 @@ class LNPathFinder(Logger):
if edge_startnode == nodeA: # payment outgoing, on our channel
if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True):
continue
else: # payment incoming, on our channel. (funny business, cycle weirdness)
assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
continue
edge_cost, fee_for_edge_msat = self._edge_cost(
short_channel_id=edge_channel_id,
start_node=edge_startnode,
@ -574,14 +582,17 @@ class LNPathFinder(Logger):
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour
prev_node[edge_startnode] = PathEdge(
previous_hops[edge_startnode] = PathEdge(
start_node=edge_startnode,
end_node=edge_endnode,
short_channel_id=ShortChannelID(edge_channel_id))
amount_to_forward_msat = amount_msat + fee_for_edge_msat
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
return prev_node
# for circular paths, we already explored the end node, but this
# is also our start node, so set it to unexplored
if edge_endnode == nodeB and nodeA == nodeB:
distance_from_start[edge_endnode] = float('inf')
return previous_hops
@profiler
def find_path_for_payment(
@ -600,22 +611,22 @@ class LNPathFinder(Logger):
if my_channels is None:
my_channels = {}
prev_node = self.get_distances(
previous_hops = self.get_shortest_path_hops(
nodeA=nodeA,
nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels,
private_route_edges=private_route_edges)
if nodeA not in prev_node:
if nodeA not in previous_hops:
return None # no path found
# backtrack from search_end (nodeA) to search_start (nodeB)
# FIXME paths cannot be longer than 20 edges (onion packet)...
edge_startnode = nodeA
path = []
while edge_startnode != nodeB:
edge = prev_node[edge_startnode]
while edge_startnode != nodeB or not path: # second condition for circular paths
edge = previous_hops[edge_startnode]
path += [edge]
edge_startnode = edge.node_id
return path

Loading…
Cancel
Save