|
|
@ -135,24 +135,12 @@ def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool: |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
BLACKLIST_DURATION = 3600 |
|
|
|
|
|
|
|
class LNPathFinder(Logger): |
|
|
|
|
|
|
|
def __init__(self, channel_db: ChannelDB): |
|
|
|
Logger.__init__(self) |
|
|
|
self.channel_db = channel_db |
|
|
|
self.blacklist = dict() # short_chan_id -> timestamp |
|
|
|
|
|
|
|
def add_to_blacklist(self, short_channel_id: ShortChannelID): |
|
|
|
self.logger.info(f'blacklisting channel {short_channel_id}') |
|
|
|
now = int(time.time()) |
|
|
|
self.blacklist[short_channel_id] = now |
|
|
|
|
|
|
|
def is_blacklisted(self, short_channel_id: ShortChannelID) -> bool: |
|
|
|
now = int(time.time()) |
|
|
|
t = self.blacklist.get(short_channel_id, 0) |
|
|
|
return now - t < BLACKLIST_DURATION |
|
|
|
|
|
|
|
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, |
|
|
|
payment_amt_msat: int, ignore_costs=False, is_mine=False, *, |
|
|
@ -200,10 +188,9 @@ class LNPathFinder(Logger): |
|
|
|
overall_cost = base_cost + fee_msat + cltv_cost |
|
|
|
return overall_cost, fee_msat |
|
|
|
|
|
|
|
def get_distances(self, nodeA: bytes, nodeB: bytes, |
|
|
|
invoice_amount_msat: int, *, |
|
|
|
my_channels: Dict[ShortChannelID, 'Channel'] = None |
|
|
|
) -> Dict[bytes, PathEdge]: |
|
|
|
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, |
|
|
|
my_channels: Dict[ShortChannelID, 'Channel'] = None, |
|
|
|
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]: |
|
|
|
# note: we don't lock self.channel_db, so while the path finding runs, |
|
|
|
# the underlying graph could potentially change... (not good but maybe ~OK?) |
|
|
|
|
|
|
@ -216,7 +203,6 @@ class LNPathFinder(Logger): |
|
|
|
nodes_to_explore = queue.PriorityQueue() |
|
|
|
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters! |
|
|
|
|
|
|
|
|
|
|
|
# main loop of search |
|
|
|
while nodes_to_explore.qsize() > 0: |
|
|
|
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get() |
|
|
@ -229,7 +215,7 @@ class LNPathFinder(Logger): |
|
|
|
continue |
|
|
|
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels): |
|
|
|
assert isinstance(edge_channel_id, bytes) |
|
|
|
if self.is_blacklisted(edge_channel_id): |
|
|
|
if blacklist and edge_channel_id in blacklist: |
|
|
|
continue |
|
|
|
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels) |
|
|
|
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id |
|
|
@ -263,7 +249,8 @@ class LNPathFinder(Logger): |
|
|
|
@profiler |
|
|
|
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, |
|
|
|
invoice_amount_msat: int, *, |
|
|
|
my_channels: Dict[ShortChannelID, 'Channel'] = None) \ |
|
|
|
my_channels: Dict[ShortChannelID, 'Channel'] = None, |
|
|
|
blacklist: Set[ShortChannelID] = None) \ |
|
|
|
-> Optional[LNPaymentPath]: |
|
|
|
"""Return a path from nodeA to nodeB.""" |
|
|
|
assert type(nodeA) is bytes |
|
|
@ -272,7 +259,7 @@ class LNPathFinder(Logger): |
|
|
|
if my_channels is None: |
|
|
|
my_channels = {} |
|
|
|
|
|
|
|
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels) |
|
|
|
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) |
|
|
|
|
|
|
|
if nodeA not in prev_node: |
|
|
|
return None # no path found |
|
|
@ -312,8 +299,9 @@ class LNPathFinder(Logger): |
|
|
|
return route |
|
|
|
|
|
|
|
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, |
|
|
|
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[LNPaymentRoute]: |
|
|
|
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None, |
|
|
|
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]: |
|
|
|
if not path: |
|
|
|
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels) |
|
|
|
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) |
|
|
|
if path: |
|
|
|
return self.create_route_from_path(path, nodeA, my_channels=my_channels) |
|
|
|