Browse Source

lnworker: run create_route_for_payment end-to-end, incl private edges

We pass the private edges to lnrouter, and let it find routes end-to-end.
Previously the edge_cost heuristics didn't apply to the private edges
and we were just randomly picking one of the route hints and use that.
So e.g. cheaper private edges were not preferred, but they are now.

PathEdge now stores both start_node and end_node; not just end_node.
patch-4
SomberNight 4 years ago
parent
commit
750d8cfab5
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 87
      electrum/channel_db.py
  2. 206
      electrum/lnrouter.py
  3. 122
      electrum/lnworker.py
  4. 20
      electrum/tests/test_lnpeer.py
  5. 12
      electrum/tests/test_lnrouter.py
  6. 35
      electrum/trampoline.py

87
electrum/channel_db.py

@ -48,6 +48,7 @@ from .lnmsg import decode_msg
if TYPE_CHECKING:
from .network import Network
from .lnchannel import Channel
from .lnrouter import RouteEdge
FLAG_DISABLE = 1 << 1
@ -81,6 +82,16 @@ class ChannelInfo(NamedTuple):
payload_dict = decode_msg(raw)[1]
return ChannelInfo.from_msg(payload_dict)
@staticmethod
def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
return ChannelInfo(
short_channel_id=route_edge.short_channel_id,
node1_id=node1_id,
node2_id=node2_id,
capacity_sat=None,
)
class Policy(NamedTuple):
key: bytes
@ -113,6 +124,20 @@ class Policy(NamedTuple):
payload['start_node'] = key[8:]
return Policy.from_msg(payload)
@staticmethod
def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
return Policy(
key=route_edge.short_channel_id + route_edge.start_node,
cltv_expiry_delta=route_edge.cltv_expiry_delta,
htlc_minimum_msat=0,
htlc_maximum_msat=None,
fee_base_msat=route_edge.fee_base_msat,
fee_proportional_millionths=route_edge.fee_proportional_millionths,
channel_flags=0,
message_flags=0,
timestamp=0,
)
def is_disabled(self):
return self.channel_flags & FLAG_DISABLE
@ -216,6 +241,8 @@ class CategorizedChannelUpdates(NamedTuple):
def get_mychannel_info(short_channel_id: ShortChannelID,
my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
chan = my_channels.get(short_channel_id)
if not chan:
return
ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
return ci._replace(capacity_sat=chan.constraints.capacity)
@ -724,8 +751,14 @@ class ChannelDB(SqlDB):
nchans_with_2p = len(self._chans_with_2_policies)
return nchans_with_0p, nchans_with_1p, nchans_with_2p
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
def get_policy_for_node(
self,
short_channel_id: bytes,
node_id: bytes,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Optional['Policy']:
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None: # publicly announced channel
policy = self._policies.get((node_id, short_channel_id))
@ -737,28 +770,56 @@ class ChannelDB(SqlDB):
return Policy.from_msg(chan_upd_dict)
# check if it's one of our own channels
if my_channels:
return get_mychannel_policy(short_channel_id, node_id, my_channels)
def get_channel_info(self, short_channel_id: ShortChannelID, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
if policy:
return policy
if private_route_edges:
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge:
return Policy.from_route_edge(route_edge)
def get_channel_info(
self,
short_channel_id: ShortChannelID,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Optional[ChannelInfo]:
ret = self._channels.get(short_channel_id)
if ret:
return ret
# check if it's one of our own channels
if my_channels:
return get_mychannel_info(short_channel_id, my_channels)
def get_channels_for_node(self, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
channel_info = get_mychannel_info(short_channel_id, my_channels)
if channel_info:
return channel_info
if private_route_edges:
route_edge = private_route_edges.get(short_channel_id)
if route_edge:
return ChannelInfo.from_route_edge(route_edge)
def get_channels_for_node(
self,
node_id: bytes,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Set[bytes]:
"""Returns the set of short channel IDs where node_id is one of the channel participants."""
if not self.data_loaded.is_set():
raise Exception("channelDB data not loaded yet!")
relevant_channels = self._channels_for_node.get(node_id) or set()
relevant_channels = set(relevant_channels) # copy
# add our own channels # TODO maybe slow?
for chan in (my_channels.values() or []):
if node_id in (chan.node_id, chan.get_local_pubkey()):
relevant_channels.add(chan.short_channel_id)
if my_channels:
for chan in my_channels.values():
if node_id in (chan.node_id, chan.get_local_pubkey()):
relevant_channels.add(chan.short_channel_id)
# add private channels # TODO maybe slow?
if private_route_edges:
for route_edge in private_route_edges.values():
if node_id in (route_edge.start_node, route_edge.end_node):
relevant_channels.add(route_edge.short_channel_id)
return relevant_channels
def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,

206
electrum/lnrouter.py

@ -55,10 +55,14 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
@attr.s(slots=True)
class PathEdge:
"""if you travel through short_channel_id, you will reach node_id"""
node_id = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val))
@property
def node_id(self) -> bytes:
# legacy compat # TODO rm
return self.end_node
@attr.s
class RouteEdge(PathEdge):
@ -73,17 +77,26 @@ class RouteEdge(PathEdge):
fee_proportional_millionths=self.fee_proportional_millionths)
@classmethod
def from_channel_policy(cls, channel_policy: 'Policy',
short_channel_id: bytes, end_node: bytes, *,
node_info: Optional[NodeInfo]) -> 'RouteEdge':
def from_channel_policy(
cls,
*,
channel_policy: 'Policy',
short_channel_id: bytes,
start_node: bytes,
end_node: bytes,
node_info: Optional[NodeInfo], # for end_node
) -> 'RouteEdge':
assert isinstance(short_channel_id, bytes)
assert type(start_node) is bytes
assert type(end_node) is bytes
return RouteEdge(node_id=end_node,
short_channel_id=ShortChannelID.normalize(short_channel_id),
fee_base_msat=channel_policy.fee_base_msat,
fee_proportional_millionths=channel_policy.fee_proportional_millionths,
cltv_expiry_delta=channel_policy.cltv_expiry_delta,
node_features=node_info.features if node_info else 0)
return RouteEdge(
start_node=start_node,
end_node=end_node,
short_channel_id=ShortChannelID.normalize(short_channel_id),
fee_base_msat=channel_policy.fee_base_msat,
fee_proportional_millionths=channel_policy.fee_proportional_millionths,
cltv_expiry_delta=channel_policy.cltv_expiry_delta,
node_features=node_info.features if node_info else 0)
def is_sane_to_use(self, amount_msat: int) -> bool:
# TODO revise ad-hoc heuristics
@ -155,21 +168,37 @@ class LNPathFinder(Logger):
Logger.__init__(self)
self.channel_db = channel_db
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
def _edge_cost(
self,
*,
short_channel_id: bytes,
start_node: bytes,
end_node: bytes,
payment_amt_msat: int,
ignore_costs=False,
is_mine=False,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Tuple[float, int]:
"""Heuristic cost (distance metric) of going through a channel.
Returns (heuristic_cost, fee_for_edge_msat).
"""
channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
if private_route_edges is None:
private_route_edges = {}
channel_info = self.channel_db.get_channel_info(
short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
if channel_info is None:
return float('inf'), 0
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
channel_policy = self.channel_db.get_policy_for_node(
short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
if channel_policy is None:
return float('inf'), 0
# channels that did not publish both policies often return temporary channel failure
if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
and not is_mine:
channel_policy_backwards = self.channel_db.get_policy_for_node(
short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
if (channel_policy_backwards is None
and not is_mine
and short_channel_id not in private_route_edges):
return float('inf'), 0
if channel_policy.is_disabled():
return float('inf'), 0
@ -181,9 +210,15 @@ class LNPathFinder(Logger):
if channel_policy.htlc_maximum_msat is not None and \
payment_amt_msat > channel_policy.htlc_maximum_msat:
return float('inf'), 0 # payment amount too large
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node,
node_info=node_info)
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge is None:
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
route_edge = RouteEdge.from_channel_policy(
channel_policy=channel_policy,
short_channel_id=short_channel_id,
start_node=start_node,
end_node=end_node,
node_info=node_info)
if not route_edge.is_sane_to_use(payment_amt_msat):
return float('inf'), 0 # thanks but no thanks
@ -201,9 +236,16 @@ class LNPathFinder(Logger):
overall_cost = base_cost + fee_msat + cltv_cost
return overall_cost, fee_msat
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]:
def get_distances(
self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Dict[bytes, PathEdge]:
# note: we don't lock self.channel_db, so while the path finding runs,
# the underlying graph could potentially change... (not good but maybe ~OK?)
@ -226,11 +268,13 @@ class LNPathFinder(Logger):
# so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
for edge_channel_id in self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
assert isinstance(edge_channel_id, bytes)
if blacklist and edge_channel_id in blacklist:
continue
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
channel_info = self.channel_db.get_channel_info(
edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
is_mine = edge_channel_id in my_channels
if is_mine:
@ -242,29 +286,37 @@ class LNPathFinder(Logger):
if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
continue
edge_cost, fee_for_edge_msat = self._edge_cost(
edge_channel_id,
short_channel_id=edge_channel_id,
start_node=edge_startnode,
end_node=edge_endnode,
payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA),
is_mine=is_mine,
my_channels=my_channels)
my_channels=my_channels,
private_route_edges=private_route_edges)
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour
prev_node[edge_startnode] = PathEdge(node_id=edge_endnode,
short_channel_id=ShortChannelID(edge_channel_id))
prev_node[edge_startnode] = PathEdge(
start_node=edge_startnode,
end_node=edge_endnode,
short_channel_id=ShortChannelID(edge_channel_id))
amount_to_forward_msat = amount_msat + fee_for_edge_msat
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
return prev_node
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) \
-> Optional[LNPaymentPath]:
def find_path_for_payment(
self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Optional[LNPaymentPath]:
"""Return a path from nodeA to nodeB."""
assert type(nodeA) is bytes
assert type(nodeB) is bytes
@ -272,7 +324,13 @@ class LNPathFinder(Logger):
if my_channels is None:
my_channels = {}
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
prev_node = self.get_distances(
nodeA=nodeA,
nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
if nodeA not in prev_node:
return None # no path found
@ -287,34 +345,66 @@ class LNPathFinder(Logger):
edge_startnode = edge.node_id
return path
def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
assert isinstance(from_node_id, bytes)
def create_route_from_path(
self,
path: Optional[LNPaymentPath],
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> LNPaymentRoute:
if path is None:
raise Exception('cannot create route from None path')
if private_route_edges is None:
private_route_edges = {}
route = []
prev_node_id = from_node_id
for edge in path:
node_id = edge.node_id
short_channel_id = edge.short_channel_id
prev_end_node = path[0].start_node
for path_edge in path:
short_channel_id = path_edge.short_channel_id
_endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
if _endnodes and sorted(_endnodes) != sorted([prev_node_id, node_id]):
if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
if path_edge.start_node != prev_end_node:
raise LNPathInconsistent("edges do not chain together")
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
node_id=prev_node_id,
my_channels=my_channels)
if channel_policy is None:
raise NoChannelPolicy(short_channel_id)
node_info = self.channel_db.get_node_info_for_node_id(node_id=node_id)
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id,
node_info=node_info))
prev_node_id = node_id
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge is None:
channel_policy = self.channel_db.get_policy_for_node(
short_channel_id=short_channel_id,
node_id=path_edge.start_node,
my_channels=my_channels)
if channel_policy is None:
raise NoChannelPolicy(short_channel_id)
node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
route_edge = RouteEdge.from_channel_policy(
channel_policy=channel_policy,
short_channel_id=short_channel_id,
start_node=path_edge.start_node,
end_node=path_edge.end_node,
node_info=node_info)
route.append(route_edge)
prev_end_node = path_edge.end_node
return route
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]:
def find_route(
self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
path = None,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Optional[LNPaymentRoute]:
route = None
if not path:
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
path = self.find_path_for_payment(
nodeA=nodeA,
nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
if path:
return self.create_route_from_path(path, nodeA, my_channels=my_channels)
route = self.create_route_from_path(
path, my_channels=my_channels, private_route_edges=private_route_edges)
return route

122
electrum/lnworker.py

@ -1320,6 +1320,7 @@ class LNWallet(LNWorker):
amount_msat=amount_msat,
bucket_amount_msat=amount_msat,
min_cltv_expiry=min_cltv_expiry,
my_pubkey=self.node_keypair.pubkey,
invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features,
node_id=chan.node_id,
@ -1336,7 +1337,8 @@ class LNWallet(LNWorker):
continue
route = [
RouteEdge(
node_id=chan.node_id,
start_node=self.node_keypair.pubkey,
end_node=chan.node_id,
short_channel_id=chan.short_channel_id,
fee_base_msat=0,
fee_proportional_millionths=0,
@ -1383,6 +1385,7 @@ class LNWallet(LNWorker):
amount_msat=amount_msat,
bucket_amount_msat=bucket_amount_msat,
min_cltv_expiry=min_cltv_expiry,
my_pubkey=self.node_keypair.pubkey,
invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features,
node_id=node_id,
@ -1404,7 +1407,8 @@ class LNWallet(LNWorker):
trampoline_fee -= delta_fee
route = [
RouteEdge(
node_id=node_id,
start_node=self.node_keypair.pubkey,
end_node=node_id,
short_channel_id=chan.short_channel_id,
fee_base_msat=0,
fee_proportional_millionths=0,
@ -1447,77 +1451,65 @@ class LNWallet(LNWorker):
full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
channels = [outgoing_channel] if outgoing_channel else list(self.channels.values())
route = None
scid_to_my_channels = {
chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None
}
blacklist = self.network.channel_blacklist.get_current_list()
# first try with routing hints, then without
for private_path in r_tags + [[]]:
private_route = []
amount_for_node = amount_msat
path = full_path
if len(private_path) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
continue
if len(private_path) == 0:
border_node_pubkey = invoice_pubkey
else:
border_node_pubkey = private_path[0][0]
# we need to shift the node pubkey by one towards the destination:
private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]
private_path_rest = [edge[1:] for edge in private_path]
prev_node_id = border_node_pubkey
for node_pubkey, edge_rest in zip(private_path_nodes, private_path_rest):
short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest
short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure
channel_policy = self.channel_db.get_policy_for_node(
# Collect all private edges from route hints.
# Note: if some route hints are multiple edges long, and these paths cross each other,
# we allow our path finding to cross the paths; i.e. the route hints are not isolated.
private_route_edges = {} # type: Dict[ShortChannelID, RouteEdge]
for private_path in r_tags:
# we need to shift the node pubkey by one towards the destination:
private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]
private_path_rest = [edge[1:] for edge in private_path]
start_node = private_path[0][0]
for end_node, edge_rest in zip(private_path_nodes, private_path_rest):
short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest
short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure
channel_policy = self.channel_db.get_policy_for_node(
short_channel_id=short_channel_id,
node_id=start_node,
my_channels=scid_to_my_channels)
if channel_policy:
fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths
cltv_expiry_delta = channel_policy.cltv_expiry_delta
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
route_edge = RouteEdge(
start_node=start_node,
end_node=end_node,
short_channel_id=short_channel_id,
node_id=prev_node_id,
my_channels=scid_to_my_channels)
if channel_policy:
fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths
cltv_expiry_delta = channel_policy.cltv_expiry_delta
node_info = self.channel_db.get_node_info_for_node_id(node_id=node_pubkey)
private_route.append(
RouteEdge(
node_id=node_pubkey,
short_channel_id=short_channel_id,
fee_base_msat=fee_base_msat,
fee_proportional_millionths=fee_proportional_millionths,
cltv_expiry_delta=cltv_expiry_delta,
node_features=node_info.features if node_info else 0))
prev_node_id = node_pubkey
for edge in private_route[::-1]:
amount_for_node += edge.fee_for_edge(amount_for_node)
if full_path:
# user pre-selected path. check that end of given path coincides with private_route:
if [edge.short_channel_id for edge in full_path[-len(private_path):]] != [edge[1] for edge in private_path]:
continue
path = full_path[:-len(private_path)]
if any(edge.short_channel_id in blacklist for edge in private_route):
continue
try:
route = self.network.path_finder.find_route(
self.node_keypair.pubkey, border_node_pubkey, amount_for_node,
path=path, my_channels=scid_to_my_channels, blacklist=blacklist)
except NoChannelPolicy:
continue
if not route:
continue
route = route + private_route
# test sanity
if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry):
self.logger.info(f"rejecting insane route {route}")
continue
break
else:
fee_base_msat=fee_base_msat,
fee_proportional_millionths=fee_proportional_millionths,
cltv_expiry_delta=cltv_expiry_delta,
node_features=node_info.features if node_info else 0)
if route_edge.short_channel_id not in blacklist:
private_route_edges[route_edge.short_channel_id] = route_edge
start_node = end_node
# now find a route, end to end: between us and the recipient
try:
route = self.network.path_finder.find_route(
nodeA=self.node_keypair.pubkey,
nodeB=invoice_pubkey,
invoice_amount_msat=amount_msat,
path=full_path,
my_channels=scid_to_my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
except NoChannelPolicy as e:
raise NoPathFound() from e
if not route:
raise NoPathFound()
# test sanity
if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry):
self.logger.info(f"rejecting insane route {route}")
raise NoPathFound()
assert len(route) > 0
if route[-1].node_id != invoice_pubkey:
if route[-1].end_node != invoice_pubkey:
raise LNPathInconsistent("last node_id != invoice pubkey")
# add features from invoice
route[-1].node_features |= invoice_features

20
electrum/tests/test_lnpeer.py

@ -600,17 +600,27 @@ class TestPeer(ElectrumTestCase):
peers = graph.all_peers()
async def pay(pay_req):
with self.subTest(msg="bad path: edges do not chain together"):
path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)]
path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
end_node=graph.w_c.node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(start_node=graph.w_b.node_keypair.pubkey,
end_node=graph.w_d.node_keypair.pubkey,
short_channel_id=graph.chan_bd.short_channel_id)]
with self.assertRaises(LNPathInconsistent):
await graph.w_a.pay_invoice(pay_req, full_path=path)
with self.subTest(msg="bad path: last node id differs from invoice pubkey"):
path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)]
path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
end_node=graph.w_b.node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id)]
with self.assertRaises(LNPathInconsistent):
await graph.w_a.pay_invoice(pay_req, full_path=path)
with self.subTest(msg="good path"):
path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)]
path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
end_node=graph.w_b.node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(start_node=graph.w_b.node_keypair.pubkey,
end_node=graph.w_d.node_keypair.pubkey,
short_channel_id=graph.chan_bd.short_channel_id)]
result, log = await graph.w_a.pay_invoice(pay_req, full_path=path)
self.assertTrue(result)
self.assertEqual(

12
electrum/tests/test_lnrouter.py

@ -83,12 +83,14 @@ class Test_LNRouter(TestCaseForTestnet):
cdb.add_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)
self.assertEqual([PathEdge(node_id=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')),
PathEdge(node_id=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')),
path = path_finder.find_path_for_payment(
nodeA=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
nodeB=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
invoice_amount_msat=100000)
self.assertEqual([PathEdge(start_node=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', end_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')),
PathEdge(start_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', end_node=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')),
], path)
start_node = b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
route = path_finder.create_route_from_path(path, start_node)
route = path_finder.create_route_from_path(path)
self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id)
self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id)

35
electrum/trampoline.py

@ -61,11 +61,13 @@ def encode_routing_info(r_tags):
def create_trampoline_route(
*,
amount_msat:int,
bucket_amount_msat:int,
min_cltv_expiry:int,
invoice_pubkey:bytes,
invoice_features:int,
my_pubkey: bytes,
trampoline_node_id,
r_tags, t_tags,
trampoline_fee_level,
@ -106,7 +108,8 @@ def create_trampoline_route(
# trampoline hop
route.append(
TrampolineEdge(
node_id=trampoline_node_id,
start_node=my_pubkey,
end_node=trampoline_node_id,
fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'],
@ -114,7 +117,8 @@ def create_trampoline_route(
if trampoline2:
route.append(
TrampolineEdge(
node_id=trampoline2,
start_node=trampoline_node_id,
end_node=trampoline2,
fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'],
@ -130,7 +134,8 @@ def create_trampoline_route(
if route[-1].node_id != pubkey:
route.append(
TrampolineEdge(
node_id=pubkey,
start_node=route[-1].node_id,
end_node=pubkey,
fee_base_msat=feebase,
fee_proportional_millionths=feerate,
cltv_expiry_delta=cltv,
@ -138,7 +143,8 @@ def create_trampoline_route(
# Fake edge (not part of actual route, needed by calc_hops_data)
route.append(
TrampolineEdge(
node_id=invoice_pubkey,
start_node=route[-1].end_node,
end_node=invoice_pubkey,
fee_base_msat=0,
fee_proportional_millionths=0,
cltv_expiry_delta=0,
@ -194,6 +200,7 @@ def create_trampoline_route_and_onion(
min_cltv_expiry,
invoice_pubkey,
invoice_features,
my_pubkey: bytes,
node_id,
r_tags, t_tags,
payment_hash,
@ -203,15 +210,17 @@ def create_trampoline_route_and_onion(
trampoline2_list):
# create route for the trampoline_onion
trampoline_route = create_trampoline_route(
amount_msat,
bucket_amount_msat,
min_cltv_expiry,
invoice_pubkey,
invoice_features,
node_id,
r_tags, t_tags,
trampoline_fee_level,
trampoline2_list)
amount_msat=amount_msat,
bucket_amount_msat=bucket_amount_msat,
min_cltv_expiry=min_cltv_expiry,
my_pubkey=my_pubkey,
invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features,
trampoline_node_id=node_id,
r_tags=r_tags,
t_tags=t_tags,
trampoline_fee_level=trampoline_fee_level,
trampoline2_list=trampoline2_list)
# compute onion and fees
final_cltv = local_height + min_cltv_expiry
trampoline_onion, bucket_amount_with_fees, bucket_cltv = create_trampoline_onion(

Loading…
Cancel
Save