diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 6e62d1e7d..81e498c85 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -135,24 +135,12 @@ def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool: return False -BLACKLIST_DURATION = 3600 class LNPathFinder(Logger): def __init__(self, channel_db: ChannelDB): Logger.__init__(self) self.channel_db = channel_db - self.blacklist = dict() # short_chan_id -> timestamp - - def add_to_blacklist(self, short_channel_id: ShortChannelID): - self.logger.info(f'blacklisting channel {short_channel_id}') - now = int(time.time()) - self.blacklist[short_channel_id] = now - - def is_blacklisted(self, short_channel_id: ShortChannelID) -> bool: - now = int(time.time()) - t = self.blacklist.get(short_channel_id, 0) - return now - t < BLACKLIST_DURATION def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, payment_amt_msat: int, ignore_costs=False, is_mine=False, *, @@ -200,10 +188,9 @@ class LNPathFinder(Logger): overall_cost = base_cost + fee_msat + cltv_cost return overall_cost, fee_msat - def get_distances(self, nodeA: bytes, nodeB: bytes, - invoice_amount_msat: int, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None - ) -> Dict[bytes, PathEdge]: + def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]: # note: we don't lock self.channel_db, so while the path finding runs, # the underlying graph could potentially change... (not good but maybe ~OK?) @@ -216,7 +203,6 @@ class LNPathFinder(Logger): nodes_to_explore = queue.PriorityQueue() nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters! - # main loop of search while nodes_to_explore.qsize() > 0: dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get() @@ -229,7 +215,7 @@ class LNPathFinder(Logger): continue for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels): assert isinstance(edge_channel_id, bytes) - if self.is_blacklisted(edge_channel_id): + if blacklist and edge_channel_id in blacklist: continue channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels) edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id @@ -263,7 +249,8 @@ class LNPathFinder(Logger): @profiler def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) \ + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None) \ -> Optional[LNPaymentPath]: """Return a path from nodeA to nodeB.""" assert type(nodeA) is bytes @@ -272,7 +259,7 @@ class LNPathFinder(Logger): if my_channels is None: my_channels = {} - prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels) + prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) if nodeA not in prev_node: return None # no path found @@ -312,8 +299,9 @@ class LNPathFinder(Logger): return route def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, - path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[LNPaymentRoute]: + path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]: if not path: - path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels) + path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) if path: return self.create_route_from_path(path, nodeA, my_channels=my_channels) diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 499bdd81d..0b4a09ceb 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -8,7 +8,7 @@ import json from collections import namedtuple, defaultdict from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence import re - +import time import attr from aiorpcx import NetAddress @@ -1313,3 +1313,17 @@ class OnionFailureCodeMetaFlag(IntFlag): NODE = 0x2000 UPDATE = 0x1000 + +class ChannelBlackList: + + def __init__(self): + self.blacklist = dict() # short_chan_id -> timestamp + + def add(self, short_channel_id: ShortChannelID): + now = int(time.time()) + self.blacklist[short_channel_id] = now + + def get_current_list(self) -> Set[ShortChannelID]: + BLACKLIST_DURATION = 3600 + now = int(time.time()) + return set(k for k, t in self.blacklist.items() if now - t < BLACKLIST_DURATION) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 08b27f98c..a8c3c6375 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -7,7 +7,7 @@ import os from decimal import Decimal import random import time -from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any +from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any import threading import socket import aiohttp @@ -540,6 +540,7 @@ class LNGossip(LNWorker): if categorized_chan_upds.good: self.logger.debug(f'on_channel_update: {len(categorized_chan_upds.good)}/{len(chan_upds_chunk)}') + class LNWallet(LNWorker): lnwatcher: Optional['LNWalletWatcher'] @@ -1014,7 +1015,8 @@ class LNWallet(LNWorker): except IndexError: self.logger.info("payment destination reported error") else: - self.network.path_finder.add_to_blacklist(short_chan_id) + self.logger.info(f'blacklisting channel {short_channel_id}') + self.network.channel_blacklist.add(short_chan_id) else: # probably got "update_fail_malformed_htlc". well... who to penalise now? assert payment_attempt.failure_message is not None @@ -1127,6 +1129,7 @@ class LNWallet(LNWorker): channels = list(self.channels.values()) scid_to_my_channels = {chan.short_channel_id: chan for chan in channels if chan.short_channel_id is not None} + blacklist = self.network.channel_blacklist.get_current_list() for private_route in r_tags: if len(private_route) == 0: continue @@ -1144,7 +1147,7 @@ class LNWallet(LNWorker): try: route = self.network.path_finder.find_route( self.node_keypair.pubkey, border_node_pubkey, amount_msat, - path=path, my_channels=scid_to_my_channels) + path=path, my_channels=scid_to_my_channels, blacklist=blacklist) except NoChannelPolicy: continue if not route: @@ -1186,7 +1189,7 @@ class LNWallet(LNWorker): if route is None: route = self.network.path_finder.find_route( self.node_keypair.pubkey, invoice_pubkey, amount_msat, - path=full_path, my_channels=scid_to_my_channels) + path=full_path, my_channels=scid_to_my_channels, blacklist=blacklist) if not route: raise NoPathFound() if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): diff --git a/electrum/network.py b/electrum/network.py index 328f074e0..d331011a5 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -45,7 +45,6 @@ from . import util from .util import (log_exceptions, ignore_exceptions, bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter, is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager) - from .bitcoin import COIN from . import constants from . import blockchain @@ -60,6 +59,7 @@ from .version import PROTOCOL_VERSION from .simple_config import SimpleConfig from .i18n import _ from .logging import get_logger, Logger +from .lnutil import ChannelBlackList if TYPE_CHECKING: from .channel_db import ChannelDB @@ -335,6 +335,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): self._has_ever_managed_to_connect_to_server = False # lightning network + self.channel_blacklist = ChannelBlackList() self.channel_db = None # type: Optional[ChannelDB] self.lngossip = None # type: Optional[LNGossip] self.local_watchtower = None # type: Optional[WatchTower] diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 291bacc49..a057f1865 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -32,6 +32,7 @@ from electrum.lnmsg import encode_msg, decode_msg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID from electrum.lnonion import OnionFailureCode +from electrum.lnutil import ChannelBlackList from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -62,6 +63,7 @@ class MockNetwork: self.path_finder = LNPathFinder(self.channel_db) self.tx_queue = tx_queue self._blockchain = MockBlockchain() + self.channel_blacklist = ChannelBlackList() @property def callback_lock(self):