diff --git a/electrum/channel_db.py b/electrum/channel_db.py index aa78f14c2..65b4cedf5 100644 --- a/electrum/channel_db.py +++ b/electrum/channel_db.py @@ -48,6 +48,7 @@ from .lnmsg import decode_msg if TYPE_CHECKING: from .network import Network from .lnchannel import Channel + from .lnrouter import RouteEdge FLAG_DISABLE = 1 << 1 @@ -81,6 +82,16 @@ class ChannelInfo(NamedTuple): payload_dict = decode_msg(raw)[1] return ChannelInfo.from_msg(payload_dict) + @staticmethod + def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo': + node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node]) + return ChannelInfo( + short_channel_id=route_edge.short_channel_id, + node1_id=node1_id, + node2_id=node2_id, + capacity_sat=None, + ) + class Policy(NamedTuple): key: bytes @@ -113,6 +124,20 @@ class Policy(NamedTuple): payload['start_node'] = key[8:] return Policy.from_msg(payload) + @staticmethod + def from_route_edge(route_edge: 'RouteEdge') -> 'Policy': + return Policy( + key=route_edge.short_channel_id + route_edge.start_node, + cltv_expiry_delta=route_edge.cltv_expiry_delta, + htlc_minimum_msat=0, + htlc_maximum_msat=None, + fee_base_msat=route_edge.fee_base_msat, + fee_proportional_millionths=route_edge.fee_proportional_millionths, + channel_flags=0, + message_flags=0, + timestamp=0, + ) + def is_disabled(self): return self.channel_flags & FLAG_DISABLE @@ -216,6 +241,8 @@ class CategorizedChannelUpdates(NamedTuple): def get_mychannel_info(short_channel_id: ShortChannelID, my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]: chan = my_channels.get(short_channel_id) + if not chan: + return ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs()) return ci._replace(capacity_sat=chan.constraints.capacity) @@ -724,8 +751,14 @@ class ChannelDB(SqlDB): nchans_with_2p = len(self._chans_with_2_policies) return nchans_with_0p, nchans_with_1p, nchans_with_2p - def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: + def get_policy_for_node( + self, + short_channel_id: bytes, + node_id: bytes, + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, + ) -> Optional['Policy']: channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: # publicly announced channel policy = self._policies.get((node_id, short_channel_id)) @@ -737,28 +770,56 @@ class ChannelDB(SqlDB): return Policy.from_msg(chan_upd_dict) # check if it's one of our own channels if my_channels: - return get_mychannel_policy(short_channel_id, node_id, my_channels) - - def get_channel_info(self, short_channel_id: ShortChannelID, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]: + policy = get_mychannel_policy(short_channel_id, node_id, my_channels) + if policy: + return policy + if private_route_edges: + route_edge = private_route_edges.get(short_channel_id, None) + if route_edge: + return Policy.from_route_edge(route_edge) + + def get_channel_info( + self, + short_channel_id: ShortChannelID, + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, + ) -> Optional[ChannelInfo]: ret = self._channels.get(short_channel_id) if ret: return ret # check if it's one of our own channels if my_channels: - return get_mychannel_info(short_channel_id, my_channels) - - def get_channels_for_node(self, node_id: bytes, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]: + channel_info = get_mychannel_info(short_channel_id, my_channels) + if channel_info: + return channel_info + if private_route_edges: + route_edge = private_route_edges.get(short_channel_id) + if route_edge: + return ChannelInfo.from_route_edge(route_edge) + + def get_channels_for_node( + self, + node_id: bytes, + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, + ) -> Set[bytes]: """Returns the set of short channel IDs where node_id is one of the channel participants.""" if not self.data_loaded.is_set(): raise Exception("channelDB data not loaded yet!") relevant_channels = self._channels_for_node.get(node_id) or set() relevant_channels = set(relevant_channels) # copy # add our own channels # TODO maybe slow? - for chan in (my_channels.values() or []): - if node_id in (chan.node_id, chan.get_local_pubkey()): - relevant_channels.add(chan.short_channel_id) + if my_channels: + for chan in my_channels.values(): + if node_id in (chan.node_id, chan.get_local_pubkey()): + relevant_channels.add(chan.short_channel_id) + # add private channels # TODO maybe slow? + if private_route_edges: + for route_edge in private_route_edges.values(): + if node_id in (route_edge.start_node, route_edge.end_node): + relevant_channels.add(route_edge.short_channel_id) return relevant_channels def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *, diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 59a674f7d..809e981b4 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -55,10 +55,14 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor @attr.s(slots=True) class PathEdge: - """if you travel through short_channel_id, you will reach node_id""" - node_id = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex()) + start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex()) + end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex()) short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val)) + @property + def node_id(self) -> bytes: + # legacy compat # TODO rm + return self.end_node @attr.s class RouteEdge(PathEdge): @@ -73,17 +77,26 @@ class RouteEdge(PathEdge): fee_proportional_millionths=self.fee_proportional_millionths) @classmethod - def from_channel_policy(cls, channel_policy: 'Policy', - short_channel_id: bytes, end_node: bytes, *, - node_info: Optional[NodeInfo]) -> 'RouteEdge': + def from_channel_policy( + cls, + *, + channel_policy: 'Policy', + short_channel_id: bytes, + start_node: bytes, + end_node: bytes, + node_info: Optional[NodeInfo], # for end_node + ) -> 'RouteEdge': assert isinstance(short_channel_id, bytes) + assert type(start_node) is bytes assert type(end_node) is bytes - return RouteEdge(node_id=end_node, - short_channel_id=ShortChannelID.normalize(short_channel_id), - fee_base_msat=channel_policy.fee_base_msat, - fee_proportional_millionths=channel_policy.fee_proportional_millionths, - cltv_expiry_delta=channel_policy.cltv_expiry_delta, - node_features=node_info.features if node_info else 0) + return RouteEdge( + start_node=start_node, + end_node=end_node, + short_channel_id=ShortChannelID.normalize(short_channel_id), + fee_base_msat=channel_policy.fee_base_msat, + fee_proportional_millionths=channel_policy.fee_proportional_millionths, + cltv_expiry_delta=channel_policy.cltv_expiry_delta, + node_features=node_info.features if node_info else 0) def is_sane_to_use(self, amount_msat: int) -> bool: # TODO revise ad-hoc heuristics @@ -155,21 +168,37 @@ class LNPathFinder(Logger): Logger.__init__(self) self.channel_db = channel_db - def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, - payment_amt_msat: int, ignore_costs=False, is_mine=False, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]: + def _edge_cost( + self, + *, + short_channel_id: bytes, + start_node: bytes, + end_node: bytes, + payment_amt_msat: int, + ignore_costs=False, + is_mine=False, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Tuple[float, int]: """Heuristic cost (distance metric) of going through a channel. Returns (heuristic_cost, fee_for_edge_msat). """ - channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels) + if private_route_edges is None: + private_route_edges = {} + channel_info = self.channel_db.get_channel_info( + short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges) if channel_info is None: return float('inf'), 0 - channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels) + channel_policy = self.channel_db.get_policy_for_node( + short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges) if channel_policy is None: return float('inf'), 0 # channels that did not publish both policies often return temporary channel failure - if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \ - and not is_mine: + channel_policy_backwards = self.channel_db.get_policy_for_node( + short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges) + if (channel_policy_backwards is None + and not is_mine + and short_channel_id not in private_route_edges): return float('inf'), 0 if channel_policy.is_disabled(): return float('inf'), 0 @@ -181,9 +210,15 @@ class LNPathFinder(Logger): if channel_policy.htlc_maximum_msat is not None and \ payment_amt_msat > channel_policy.htlc_maximum_msat: return float('inf'), 0 # payment amount too large - node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) - route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node, - node_info=node_info) + route_edge = private_route_edges.get(short_channel_id, None) + if route_edge is None: + node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) + route_edge = RouteEdge.from_channel_policy( + channel_policy=channel_policy, + short_channel_id=short_channel_id, + start_node=start_node, + end_node=end_node, + node_info=node_info) if not route_edge.is_sane_to_use(payment_amt_msat): return float('inf'), 0 # thanks but no thanks @@ -201,9 +236,16 @@ class LNPathFinder(Logger): overall_cost = base_cost + fee_msat + cltv_cost return overall_cost, fee_msat - def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None, - blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]: + def get_distances( + self, + *, + nodeA: bytes, + nodeB: bytes, + invoice_amount_msat: int, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Dict[bytes, PathEdge]: # note: we don't lock self.channel_db, so while the path finding runs, # the underlying graph could potentially change... (not good but maybe ~OK?) @@ -226,11 +268,13 @@ class LNPathFinder(Logger): # so instead of decreasing priorities, we add items again into the queue. # so there are duplicates in the queue, that we discard now: continue - for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels): + for edge_channel_id in self.channel_db.get_channels_for_node( + edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges): assert isinstance(edge_channel_id, bytes) if blacklist and edge_channel_id in blacklist: continue - channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels) + channel_info = self.channel_db.get_channel_info( + edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges) edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id is_mine = edge_channel_id in my_channels if is_mine: @@ -242,29 +286,37 @@ class LNPathFinder(Logger): if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True): continue edge_cost, fee_for_edge_msat = self._edge_cost( - edge_channel_id, + short_channel_id=edge_channel_id, start_node=edge_startnode, end_node=edge_endnode, payment_amt_msat=amount_msat, ignore_costs=(edge_startnode == nodeA), is_mine=is_mine, - my_channels=my_channels) + my_channels=my_channels, + private_route_edges=private_route_edges) alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost if alt_dist_to_neighbour < distance_from_start[edge_startnode]: distance_from_start[edge_startnode] = alt_dist_to_neighbour - prev_node[edge_startnode] = PathEdge(node_id=edge_endnode, - short_channel_id=ShortChannelID(edge_channel_id)) + prev_node[edge_startnode] = PathEdge( + start_node=edge_startnode, + end_node=edge_endnode, + short_channel_id=ShortChannelID(edge_channel_id)) amount_to_forward_msat = amount_msat + fee_for_edge_msat nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) return prev_node @profiler - def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, - invoice_amount_msat: int, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None, - blacklist: Set[ShortChannelID] = None) \ - -> Optional[LNPaymentPath]: + def find_path_for_payment( + self, + *, + nodeA: bytes, + nodeB: bytes, + invoice_amount_msat: int, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Optional[LNPaymentPath]: """Return a path from nodeA to nodeB.""" assert type(nodeA) is bytes assert type(nodeB) is bytes @@ -272,7 +324,13 @@ class LNPathFinder(Logger): if my_channels is None: my_channels = {} - prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) + prev_node = self.get_distances( + nodeA=nodeA, + nodeB=nodeB, + invoice_amount_msat=invoice_amount_msat, + my_channels=my_channels, + blacklist=blacklist, + private_route_edges=private_route_edges) if nodeA not in prev_node: return None # no path found @@ -287,34 +345,66 @@ class LNPathFinder(Logger): edge_startnode = edge.node_id return path - def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute: - assert isinstance(from_node_id, bytes) + def create_route_from_path( + self, + path: Optional[LNPaymentPath], + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> LNPaymentRoute: if path is None: raise Exception('cannot create route from None path') + if private_route_edges is None: + private_route_edges = {} route = [] - prev_node_id = from_node_id - for edge in path: - node_id = edge.node_id - short_channel_id = edge.short_channel_id + prev_end_node = path[0].start_node + for path_edge in path: + short_channel_id = path_edge.short_channel_id _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels) - if _endnodes and sorted(_endnodes) != sorted([prev_node_id, node_id]): + if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]): + raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id") + if path_edge.start_node != prev_end_node: raise LNPathInconsistent("edges do not chain together") - channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id, - node_id=prev_node_id, - my_channels=my_channels) - if channel_policy is None: - raise NoChannelPolicy(short_channel_id) - node_info = self.channel_db.get_node_info_for_node_id(node_id=node_id) - route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id, - node_info=node_info)) - prev_node_id = node_id + route_edge = private_route_edges.get(short_channel_id, None) + if route_edge is None: + channel_policy = self.channel_db.get_policy_for_node( + short_channel_id=short_channel_id, + node_id=path_edge.start_node, + my_channels=my_channels) + if channel_policy is None: + raise NoChannelPolicy(short_channel_id) + node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node) + route_edge = RouteEdge.from_channel_policy( + channel_policy=channel_policy, + short_channel_id=short_channel_id, + start_node=path_edge.start_node, + end_node=path_edge.end_node, + node_info=node_info) + route.append(route_edge) + prev_end_node = path_edge.end_node return route - def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, - path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None, - blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]: + def find_route( + self, + *, + nodeA: bytes, + nodeB: bytes, + invoice_amount_msat: int, + path = None, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Optional[LNPaymentRoute]: + route = None if not path: - path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) + path = self.find_path_for_payment( + nodeA=nodeA, + nodeB=nodeB, + invoice_amount_msat=invoice_amount_msat, + my_channels=my_channels, + blacklist=blacklist, + private_route_edges=private_route_edges) if path: - return self.create_route_from_path(path, nodeA, my_channels=my_channels) + route = self.create_route_from_path( + path, my_channels=my_channels, private_route_edges=private_route_edges) + return route diff --git a/electrum/lnworker.py b/electrum/lnworker.py index bdd91478a..9a95c331d 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1320,6 +1320,7 @@ class LNWallet(LNWorker): amount_msat=amount_msat, bucket_amount_msat=amount_msat, min_cltv_expiry=min_cltv_expiry, + my_pubkey=self.node_keypair.pubkey, invoice_pubkey=invoice_pubkey, invoice_features=invoice_features, node_id=chan.node_id, @@ -1336,7 +1337,8 @@ class LNWallet(LNWorker): continue route = [ RouteEdge( - node_id=chan.node_id, + start_node=self.node_keypair.pubkey, + end_node=chan.node_id, short_channel_id=chan.short_channel_id, fee_base_msat=0, fee_proportional_millionths=0, @@ -1383,6 +1385,7 @@ class LNWallet(LNWorker): amount_msat=amount_msat, bucket_amount_msat=bucket_amount_msat, min_cltv_expiry=min_cltv_expiry, + my_pubkey=self.node_keypair.pubkey, invoice_pubkey=invoice_pubkey, invoice_features=invoice_features, node_id=node_id, @@ -1404,7 +1407,8 @@ class LNWallet(LNWorker): trampoline_fee -= delta_fee route = [ RouteEdge( - node_id=node_id, + start_node=self.node_keypair.pubkey, + end_node=node_id, short_channel_id=chan.short_channel_id, fee_base_msat=0, fee_proportional_millionths=0, @@ -1447,77 +1451,65 @@ class LNWallet(LNWorker): full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]: channels = [outgoing_channel] if outgoing_channel else list(self.channels.values()) - route = None scid_to_my_channels = { chan.short_channel_id: chan for chan in channels if chan.short_channel_id is not None } blacklist = self.network.channel_blacklist.get_current_list() - # first try with routing hints, then without - for private_path in r_tags + [[]]: - private_route = [] - amount_for_node = amount_msat - path = full_path - if len(private_path) > NUM_MAX_EDGES_IN_PAYMENT_PATH: - continue - if len(private_path) == 0: - border_node_pubkey = invoice_pubkey - else: - border_node_pubkey = private_path[0][0] - # we need to shift the node pubkey by one towards the destination: - private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey] - private_path_rest = [edge[1:] for edge in private_path] - prev_node_id = border_node_pubkey - for node_pubkey, edge_rest in zip(private_path_nodes, private_path_rest): - short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest - short_channel_id = ShortChannelID(short_channel_id) - # if we have a routing policy for this edge in the db, that takes precedence, - # as it is likely from a previous failure - channel_policy = self.channel_db.get_policy_for_node( + # Collect all private edges from route hints. + # Note: if some route hints are multiple edges long, and these paths cross each other, + # we allow our path finding to cross the paths; i.e. the route hints are not isolated. + private_route_edges = {} # type: Dict[ShortChannelID, RouteEdge] + for private_path in r_tags: + # we need to shift the node pubkey by one towards the destination: + private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey] + private_path_rest = [edge[1:] for edge in private_path] + start_node = private_path[0][0] + for end_node, edge_rest in zip(private_path_nodes, private_path_rest): + short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest + short_channel_id = ShortChannelID(short_channel_id) + # if we have a routing policy for this edge in the db, that takes precedence, + # as it is likely from a previous failure + channel_policy = self.channel_db.get_policy_for_node( + short_channel_id=short_channel_id, + node_id=start_node, + my_channels=scid_to_my_channels) + if channel_policy: + fee_base_msat = channel_policy.fee_base_msat + fee_proportional_millionths = channel_policy.fee_proportional_millionths + cltv_expiry_delta = channel_policy.cltv_expiry_delta + node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) + route_edge = RouteEdge( + start_node=start_node, + end_node=end_node, short_channel_id=short_channel_id, - node_id=prev_node_id, - my_channels=scid_to_my_channels) - if channel_policy: - fee_base_msat = channel_policy.fee_base_msat - fee_proportional_millionths = channel_policy.fee_proportional_millionths - cltv_expiry_delta = channel_policy.cltv_expiry_delta - node_info = self.channel_db.get_node_info_for_node_id(node_id=node_pubkey) - private_route.append( - RouteEdge( - node_id=node_pubkey, - short_channel_id=short_channel_id, - fee_base_msat=fee_base_msat, - fee_proportional_millionths=fee_proportional_millionths, - cltv_expiry_delta=cltv_expiry_delta, - node_features=node_info.features if node_info else 0)) - prev_node_id = node_pubkey - for edge in private_route[::-1]: - amount_for_node += edge.fee_for_edge(amount_for_node) - if full_path: - # user pre-selected path. check that end of given path coincides with private_route: - if [edge.short_channel_id for edge in full_path[-len(private_path):]] != [edge[1] for edge in private_path]: - continue - path = full_path[:-len(private_path)] - if any(edge.short_channel_id in blacklist for edge in private_route): - continue - try: - route = self.network.path_finder.find_route( - self.node_keypair.pubkey, border_node_pubkey, amount_for_node, - path=path, my_channels=scid_to_my_channels, blacklist=blacklist) - except NoChannelPolicy: - continue - if not route: - continue - route = route + private_route - # test sanity - if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry): - self.logger.info(f"rejecting insane route {route}") - continue - break - else: + fee_base_msat=fee_base_msat, + fee_proportional_millionths=fee_proportional_millionths, + cltv_expiry_delta=cltv_expiry_delta, + node_features=node_info.features if node_info else 0) + if route_edge.short_channel_id not in blacklist: + private_route_edges[route_edge.short_channel_id] = route_edge + start_node = end_node + # now find a route, end to end: between us and the recipient + try: + route = self.network.path_finder.find_route( + nodeA=self.node_keypair.pubkey, + nodeB=invoice_pubkey, + invoice_amount_msat=amount_msat, + path=full_path, + my_channels=scid_to_my_channels, + blacklist=blacklist, + private_route_edges=private_route_edges) + except NoChannelPolicy as e: + raise NoPathFound() from e + if not route: + raise NoPathFound() + # test sanity + if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry): + self.logger.info(f"rejecting insane route {route}") raise NoPathFound() assert len(route) > 0 - if route[-1].node_id != invoice_pubkey: + if route[-1].end_node != invoice_pubkey: raise LNPathInconsistent("last node_id != invoice pubkey") # add features from invoice route[-1].node_features |= invoice_features diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 46e832ceb..9fb07c97d 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -600,17 +600,27 @@ class TestPeer(ElectrumTestCase): peers = graph.all_peers() async def pay(pay_req): with self.subTest(msg="bad path: edges do not chain together"): - path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), - PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] + path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, + end_node=graph.w_c.node_keypair.pubkey, + short_channel_id=graph.chan_ab.short_channel_id), + PathEdge(start_node=graph.w_b.node_keypair.pubkey, + end_node=graph.w_d.node_keypair.pubkey, + short_channel_id=graph.chan_bd.short_channel_id)] with self.assertRaises(LNPathInconsistent): await graph.w_a.pay_invoice(pay_req, full_path=path) with self.subTest(msg="bad path: last node id differs from invoice pubkey"): - path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)] + path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, + end_node=graph.w_b.node_keypair.pubkey, + short_channel_id=graph.chan_ab.short_channel_id)] with self.assertRaises(LNPathInconsistent): await graph.w_a.pay_invoice(pay_req, full_path=path) with self.subTest(msg="good path"): - path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), - PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] + path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, + end_node=graph.w_b.node_keypair.pubkey, + short_channel_id=graph.chan_ab.short_channel_id), + PathEdge(start_node=graph.w_b.node_keypair.pubkey, + end_node=graph.w_d.node_keypair.pubkey, + short_channel_id=graph.chan_bd.short_channel_id)] result, log = await graph.w_a.pay_invoice(pay_req, full_path=path) self.assertTrue(result) self.assertEqual( diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py index dc41284c0..fa99e6015 100644 --- a/electrum/tests/test_lnrouter.py +++ b/electrum/tests/test_lnrouter.py @@ -83,12 +83,14 @@ class Test_LNRouter(TestCaseForTestnet): cdb.add_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) - path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000) - self.assertEqual([PathEdge(node_id=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')), - PathEdge(node_id=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')), + path = path_finder.find_path_for_payment( + nodeA=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + nodeB=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', + invoice_amount_msat=100000) + self.assertEqual([PathEdge(start_node=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', end_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')), + PathEdge(start_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', end_node=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')), ], path) - start_node = b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' - route = path_finder.create_route_from_path(path, start_node) + route = path_finder.create_route_from_path(path) self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id) self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id) diff --git a/electrum/trampoline.py b/electrum/trampoline.py index 07c44bf80..7d9f104ac 100644 --- a/electrum/trampoline.py +++ b/electrum/trampoline.py @@ -61,11 +61,13 @@ def encode_routing_info(r_tags): def create_trampoline_route( + *, amount_msat:int, bucket_amount_msat:int, min_cltv_expiry:int, invoice_pubkey:bytes, invoice_features:int, + my_pubkey: bytes, trampoline_node_id, r_tags, t_tags, trampoline_fee_level, @@ -106,7 +108,8 @@ def create_trampoline_route( # trampoline hop route.append( TrampolineEdge( - node_id=trampoline_node_id, + start_node=my_pubkey, + end_node=trampoline_node_id, fee_base_msat=params['fee_base_msat'], fee_proportional_millionths=params['fee_proportional_millionths'], cltv_expiry_delta=params['cltv_expiry_delta'], @@ -114,7 +117,8 @@ def create_trampoline_route( if trampoline2: route.append( TrampolineEdge( - node_id=trampoline2, + start_node=trampoline_node_id, + end_node=trampoline2, fee_base_msat=params['fee_base_msat'], fee_proportional_millionths=params['fee_proportional_millionths'], cltv_expiry_delta=params['cltv_expiry_delta'], @@ -130,7 +134,8 @@ def create_trampoline_route( if route[-1].node_id != pubkey: route.append( TrampolineEdge( - node_id=pubkey, + start_node=route[-1].node_id, + end_node=pubkey, fee_base_msat=feebase, fee_proportional_millionths=feerate, cltv_expiry_delta=cltv, @@ -138,7 +143,8 @@ def create_trampoline_route( # Fake edge (not part of actual route, needed by calc_hops_data) route.append( TrampolineEdge( - node_id=invoice_pubkey, + start_node=route[-1].end_node, + end_node=invoice_pubkey, fee_base_msat=0, fee_proportional_millionths=0, cltv_expiry_delta=0, @@ -194,6 +200,7 @@ def create_trampoline_route_and_onion( min_cltv_expiry, invoice_pubkey, invoice_features, + my_pubkey: bytes, node_id, r_tags, t_tags, payment_hash, @@ -203,15 +210,17 @@ def create_trampoline_route_and_onion( trampoline2_list): # create route for the trampoline_onion trampoline_route = create_trampoline_route( - amount_msat, - bucket_amount_msat, - min_cltv_expiry, - invoice_pubkey, - invoice_features, - node_id, - r_tags, t_tags, - trampoline_fee_level, - trampoline2_list) + amount_msat=amount_msat, + bucket_amount_msat=bucket_amount_msat, + min_cltv_expiry=min_cltv_expiry, + my_pubkey=my_pubkey, + invoice_pubkey=invoice_pubkey, + invoice_features=invoice_features, + trampoline_node_id=node_id, + r_tags=r_tags, + t_tags=t_tags, + trampoline_fee_level=trampoline_fee_level, + trampoline2_list=trampoline2_list) # compute onion and fees final_cltv = local_height + min_cltv_expiry trampoline_onion, bucket_amount_with_fees, bucket_cltv = create_trampoline_onion(