Browse Source

lnworker: run create_route_for_payment end-to-end, incl private edges

We pass the private edges to lnrouter, and let it find routes end-to-end.
Previously the edge_cost heuristics didn't apply to the private edges
and we were just randomly picking one of the route hints and use that.
So e.g. cheaper private edges were not preferred, but they are now.

PathEdge now stores both start_node and end_node; not just end_node.
patch-4
SomberNight 4 years ago
parent
commit
750d8cfab5
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 83
      electrum/channel_db.py
  2. 180
      electrum/lnrouter.py
  3. 76
      electrum/lnworker.py
  4. 20
      electrum/tests/test_lnpeer.py
  5. 12
      electrum/tests/test_lnrouter.py
  6. 35
      electrum/trampoline.py

83
electrum/channel_db.py

@ -48,6 +48,7 @@ from .lnmsg import decode_msg
if TYPE_CHECKING: if TYPE_CHECKING:
from .network import Network from .network import Network
from .lnchannel import Channel from .lnchannel import Channel
from .lnrouter import RouteEdge
FLAG_DISABLE = 1 << 1 FLAG_DISABLE = 1 << 1
@ -81,6 +82,16 @@ class ChannelInfo(NamedTuple):
payload_dict = decode_msg(raw)[1] payload_dict = decode_msg(raw)[1]
return ChannelInfo.from_msg(payload_dict) return ChannelInfo.from_msg(payload_dict)
@staticmethod
def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
return ChannelInfo(
short_channel_id=route_edge.short_channel_id,
node1_id=node1_id,
node2_id=node2_id,
capacity_sat=None,
)
class Policy(NamedTuple): class Policy(NamedTuple):
key: bytes key: bytes
@ -113,6 +124,20 @@ class Policy(NamedTuple):
payload['start_node'] = key[8:] payload['start_node'] = key[8:]
return Policy.from_msg(payload) return Policy.from_msg(payload)
@staticmethod
def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
return Policy(
key=route_edge.short_channel_id + route_edge.start_node,
cltv_expiry_delta=route_edge.cltv_expiry_delta,
htlc_minimum_msat=0,
htlc_maximum_msat=None,
fee_base_msat=route_edge.fee_base_msat,
fee_proportional_millionths=route_edge.fee_proportional_millionths,
channel_flags=0,
message_flags=0,
timestamp=0,
)
def is_disabled(self): def is_disabled(self):
return self.channel_flags & FLAG_DISABLE return self.channel_flags & FLAG_DISABLE
@ -216,6 +241,8 @@ class CategorizedChannelUpdates(NamedTuple):
def get_mychannel_info(short_channel_id: ShortChannelID, def get_mychannel_info(short_channel_id: ShortChannelID,
my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]: my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
chan = my_channels.get(short_channel_id) chan = my_channels.get(short_channel_id)
if not chan:
return
ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs()) ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
return ci._replace(capacity_sat=chan.constraints.capacity) return ci._replace(capacity_sat=chan.constraints.capacity)
@ -724,8 +751,14 @@ class ChannelDB(SqlDB):
nchans_with_2p = len(self._chans_with_2_policies) nchans_with_2p = len(self._chans_with_2_policies)
return nchans_with_0p, nchans_with_1p, nchans_with_2p return nchans_with_0p, nchans_with_1p, nchans_with_2p
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *, def get_policy_for_node(
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: self,
short_channel_id: bytes,
node_id: bytes,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Optional['Policy']:
channel_info = self.get_channel_info(short_channel_id) channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None: # publicly announced channel if channel_info is not None: # publicly announced channel
policy = self._policies.get((node_id, short_channel_id)) policy = self._policies.get((node_id, short_channel_id))
@ -737,28 +770,56 @@ class ChannelDB(SqlDB):
return Policy.from_msg(chan_upd_dict) return Policy.from_msg(chan_upd_dict)
# check if it's one of our own channels # check if it's one of our own channels
if my_channels: if my_channels:
return get_mychannel_policy(short_channel_id, node_id, my_channels) policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
if policy:
def get_channel_info(self, short_channel_id: ShortChannelID, *, return policy
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]: if private_route_edges:
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge:
return Policy.from_route_edge(route_edge)
def get_channel_info(
self,
short_channel_id: ShortChannelID,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Optional[ChannelInfo]:
ret = self._channels.get(short_channel_id) ret = self._channels.get(short_channel_id)
if ret: if ret:
return ret return ret
# check if it's one of our own channels # check if it's one of our own channels
if my_channels: if my_channels:
return get_mychannel_info(short_channel_id, my_channels) channel_info = get_mychannel_info(short_channel_id, my_channels)
if channel_info:
def get_channels_for_node(self, node_id: bytes, *, return channel_info
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]: if private_route_edges:
route_edge = private_route_edges.get(short_channel_id)
if route_edge:
return ChannelInfo.from_route_edge(route_edge)
def get_channels_for_node(
self,
node_id: bytes,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Set[bytes]:
"""Returns the set of short channel IDs where node_id is one of the channel participants.""" """Returns the set of short channel IDs where node_id is one of the channel participants."""
if not self.data_loaded.is_set(): if not self.data_loaded.is_set():
raise Exception("channelDB data not loaded yet!") raise Exception("channelDB data not loaded yet!")
relevant_channels = self._channels_for_node.get(node_id) or set() relevant_channels = self._channels_for_node.get(node_id) or set()
relevant_channels = set(relevant_channels) # copy relevant_channels = set(relevant_channels) # copy
# add our own channels # TODO maybe slow? # add our own channels # TODO maybe slow?
for chan in (my_channels.values() or []): if my_channels:
for chan in my_channels.values():
if node_id in (chan.node_id, chan.get_local_pubkey()): if node_id in (chan.node_id, chan.get_local_pubkey()):
relevant_channels.add(chan.short_channel_id) relevant_channels.add(chan.short_channel_id)
# add private channels # TODO maybe slow?
if private_route_edges:
for route_edge in private_route_edges.values():
if node_id in (route_edge.start_node, route_edge.end_node):
relevant_channels.add(route_edge.short_channel_id)
return relevant_channels return relevant_channels
def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *, def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,

180
electrum/lnrouter.py

@ -55,10 +55,14 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
@attr.s(slots=True) @attr.s(slots=True)
class PathEdge: class PathEdge:
"""if you travel through short_channel_id, you will reach node_id""" start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
node_id = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex()) end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val)) short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val))
@property
def node_id(self) -> bytes:
# legacy compat # TODO rm
return self.end_node
@attr.s @attr.s
class RouteEdge(PathEdge): class RouteEdge(PathEdge):
@ -73,12 +77,21 @@ class RouteEdge(PathEdge):
fee_proportional_millionths=self.fee_proportional_millionths) fee_proportional_millionths=self.fee_proportional_millionths)
@classmethod @classmethod
def from_channel_policy(cls, channel_policy: 'Policy', def from_channel_policy(
short_channel_id: bytes, end_node: bytes, *, cls,
node_info: Optional[NodeInfo]) -> 'RouteEdge': *,
channel_policy: 'Policy',
short_channel_id: bytes,
start_node: bytes,
end_node: bytes,
node_info: Optional[NodeInfo], # for end_node
) -> 'RouteEdge':
assert isinstance(short_channel_id, bytes) assert isinstance(short_channel_id, bytes)
assert type(start_node) is bytes
assert type(end_node) is bytes assert type(end_node) is bytes
return RouteEdge(node_id=end_node, return RouteEdge(
start_node=start_node,
end_node=end_node,
short_channel_id=ShortChannelID.normalize(short_channel_id), short_channel_id=ShortChannelID.normalize(short_channel_id),
fee_base_msat=channel_policy.fee_base_msat, fee_base_msat=channel_policy.fee_base_msat,
fee_proportional_millionths=channel_policy.fee_proportional_millionths, fee_proportional_millionths=channel_policy.fee_proportional_millionths,
@ -155,21 +168,37 @@ class LNPathFinder(Logger):
Logger.__init__(self) Logger.__init__(self)
self.channel_db = channel_db self.channel_db = channel_db
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, def _edge_cost(
payment_amt_msat: int, ignore_costs=False, is_mine=False, *, self,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]: *,
short_channel_id: bytes,
start_node: bytes,
end_node: bytes,
payment_amt_msat: int,
ignore_costs=False,
is_mine=False,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Tuple[float, int]:
"""Heuristic cost (distance metric) of going through a channel. """Heuristic cost (distance metric) of going through a channel.
Returns (heuristic_cost, fee_for_edge_msat). Returns (heuristic_cost, fee_for_edge_msat).
""" """
channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels) if private_route_edges is None:
private_route_edges = {}
channel_info = self.channel_db.get_channel_info(
short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
if channel_info is None: if channel_info is None:
return float('inf'), 0 return float('inf'), 0
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels) channel_policy = self.channel_db.get_policy_for_node(
short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
if channel_policy is None: if channel_policy is None:
return float('inf'), 0 return float('inf'), 0
# channels that did not publish both policies often return temporary channel failure # channels that did not publish both policies often return temporary channel failure
if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \ channel_policy_backwards = self.channel_db.get_policy_for_node(
and not is_mine: short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
if (channel_policy_backwards is None
and not is_mine
and short_channel_id not in private_route_edges):
return float('inf'), 0 return float('inf'), 0
if channel_policy.is_disabled(): if channel_policy.is_disabled():
return float('inf'), 0 return float('inf'), 0
@ -181,8 +210,14 @@ class LNPathFinder(Logger):
if channel_policy.htlc_maximum_msat is not None and \ if channel_policy.htlc_maximum_msat is not None and \
payment_amt_msat > channel_policy.htlc_maximum_msat: payment_amt_msat > channel_policy.htlc_maximum_msat:
return float('inf'), 0 # payment amount too large return float('inf'), 0 # payment amount too large
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge is None:
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node, route_edge = RouteEdge.from_channel_policy(
channel_policy=channel_policy,
short_channel_id=short_channel_id,
start_node=start_node,
end_node=end_node,
node_info=node_info) node_info=node_info)
if not route_edge.is_sane_to_use(payment_amt_msat): if not route_edge.is_sane_to_use(payment_amt_msat):
return float('inf'), 0 # thanks but no thanks return float('inf'), 0 # thanks but no thanks
@ -201,9 +236,16 @@ class LNPathFinder(Logger):
overall_cost = base_cost + fee_msat + cltv_cost overall_cost = base_cost + fee_msat + cltv_cost
return overall_cost, fee_msat return overall_cost, fee_msat
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, def get_distances(
self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
my_channels: Dict[ShortChannelID, 'Channel'] = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]: blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Dict[bytes, PathEdge]:
# note: we don't lock self.channel_db, so while the path finding runs, # note: we don't lock self.channel_db, so while the path finding runs,
# the underlying graph could potentially change... (not good but maybe ~OK?) # the underlying graph could potentially change... (not good but maybe ~OK?)
@ -226,11 +268,13 @@ class LNPathFinder(Logger):
# so instead of decreasing priorities, we add items again into the queue. # so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now: # so there are duplicates in the queue, that we discard now:
continue continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels): for edge_channel_id in self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
assert isinstance(edge_channel_id, bytes) assert isinstance(edge_channel_id, bytes)
if blacklist and edge_channel_id in blacklist: if blacklist and edge_channel_id in blacklist:
continue continue
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels) channel_info = self.channel_db.get_channel_info(
edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
is_mine = edge_channel_id in my_channels is_mine = edge_channel_id in my_channels
if is_mine: if is_mine:
@ -242,17 +286,20 @@ class LNPathFinder(Logger):
if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True): if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
continue continue
edge_cost, fee_for_edge_msat = self._edge_cost( edge_cost, fee_for_edge_msat = self._edge_cost(
edge_channel_id, short_channel_id=edge_channel_id,
start_node=edge_startnode, start_node=edge_startnode,
end_node=edge_endnode, end_node=edge_endnode,
payment_amt_msat=amount_msat, payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA), ignore_costs=(edge_startnode == nodeA),
is_mine=is_mine, is_mine=is_mine,
my_channels=my_channels) my_channels=my_channels,
private_route_edges=private_route_edges)
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]: if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour distance_from_start[edge_startnode] = alt_dist_to_neighbour
prev_node[edge_startnode] = PathEdge(node_id=edge_endnode, prev_node[edge_startnode] = PathEdge(
start_node=edge_startnode,
end_node=edge_endnode,
short_channel_id=ShortChannelID(edge_channel_id)) short_channel_id=ShortChannelID(edge_channel_id))
amount_to_forward_msat = amount_msat + fee_for_edge_msat amount_to_forward_msat = amount_msat + fee_for_edge_msat
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
@ -260,11 +307,16 @@ class LNPathFinder(Logger):
return prev_node return prev_node
@profiler @profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, def find_path_for_payment(
invoice_amount_msat: int, *, self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
my_channels: Dict[ShortChannelID, 'Channel'] = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) \ blacklist: Set[ShortChannelID] = None,
-> Optional[LNPaymentPath]: private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Optional[LNPaymentPath]:
"""Return a path from nodeA to nodeB.""" """Return a path from nodeA to nodeB."""
assert type(nodeA) is bytes assert type(nodeA) is bytes
assert type(nodeB) is bytes assert type(nodeB) is bytes
@ -272,7 +324,13 @@ class LNPathFinder(Logger):
if my_channels is None: if my_channels is None:
my_channels = {} my_channels = {}
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) prev_node = self.get_distances(
nodeA=nodeA,
nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
if nodeA not in prev_node: if nodeA not in prev_node:
return None # no path found return None # no path found
@ -287,34 +345,66 @@ class LNPathFinder(Logger):
edge_startnode = edge.node_id edge_startnode = edge.node_id
return path return path
def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *, def create_route_from_path(
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute: self,
assert isinstance(from_node_id, bytes) path: Optional[LNPaymentPath],
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> LNPaymentRoute:
if path is None: if path is None:
raise Exception('cannot create route from None path') raise Exception('cannot create route from None path')
if private_route_edges is None:
private_route_edges = {}
route = [] route = []
prev_node_id = from_node_id prev_end_node = path[0].start_node
for edge in path: for path_edge in path:
node_id = edge.node_id short_channel_id = path_edge.short_channel_id
short_channel_id = edge.short_channel_id
_endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels) _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
if _endnodes and sorted(_endnodes) != sorted([prev_node_id, node_id]): if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
if path_edge.start_node != prev_end_node:
raise LNPathInconsistent("edges do not chain together") raise LNPathInconsistent("edges do not chain together")
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id, route_edge = private_route_edges.get(short_channel_id, None)
node_id=prev_node_id, if route_edge is None:
channel_policy = self.channel_db.get_policy_for_node(
short_channel_id=short_channel_id,
node_id=path_edge.start_node,
my_channels=my_channels) my_channels=my_channels)
if channel_policy is None: if channel_policy is None:
raise NoChannelPolicy(short_channel_id) raise NoChannelPolicy(short_channel_id)
node_info = self.channel_db.get_node_info_for_node_id(node_id=node_id) node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id, route_edge = RouteEdge.from_channel_policy(
node_info=node_info)) channel_policy=channel_policy,
prev_node_id = node_id short_channel_id=short_channel_id,
start_node=path_edge.start_node,
end_node=path_edge.end_node,
node_info=node_info)
route.append(route_edge)
prev_end_node = path_edge.end_node
return route return route
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, def find_route(
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None, self,
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]: *,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
path = None,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Optional[LNPaymentRoute]:
route = None
if not path: if not path:
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) path = self.find_path_for_payment(
nodeA=nodeA,
nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
if path: if path:
return self.create_route_from_path(path, nodeA, my_channels=my_channels) route = self.create_route_from_path(
path, my_channels=my_channels, private_route_edges=private_route_edges)
return route

76
electrum/lnworker.py

@ -1320,6 +1320,7 @@ class LNWallet(LNWorker):
amount_msat=amount_msat, amount_msat=amount_msat,
bucket_amount_msat=amount_msat, bucket_amount_msat=amount_msat,
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=min_cltv_expiry,
my_pubkey=self.node_keypair.pubkey,
invoice_pubkey=invoice_pubkey, invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features, invoice_features=invoice_features,
node_id=chan.node_id, node_id=chan.node_id,
@ -1336,7 +1337,8 @@ class LNWallet(LNWorker):
continue continue
route = [ route = [
RouteEdge( RouteEdge(
node_id=chan.node_id, start_node=self.node_keypair.pubkey,
end_node=chan.node_id,
short_channel_id=chan.short_channel_id, short_channel_id=chan.short_channel_id,
fee_base_msat=0, fee_base_msat=0,
fee_proportional_millionths=0, fee_proportional_millionths=0,
@ -1383,6 +1385,7 @@ class LNWallet(LNWorker):
amount_msat=amount_msat, amount_msat=amount_msat,
bucket_amount_msat=bucket_amount_msat, bucket_amount_msat=bucket_amount_msat,
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=min_cltv_expiry,
my_pubkey=self.node_keypair.pubkey,
invoice_pubkey=invoice_pubkey, invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features, invoice_features=invoice_features,
node_id=node_id, node_id=node_id,
@ -1404,7 +1407,8 @@ class LNWallet(LNWorker):
trampoline_fee -= delta_fee trampoline_fee -= delta_fee
route = [ route = [
RouteEdge( RouteEdge(
node_id=node_id, start_node=self.node_keypair.pubkey,
end_node=node_id,
short_channel_id=chan.short_channel_id, short_channel_id=chan.short_channel_id,
fee_base_msat=0, fee_base_msat=0,
fee_proportional_millionths=0, fee_proportional_millionths=0,
@ -1447,77 +1451,65 @@ class LNWallet(LNWorker):
full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]: full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
channels = [outgoing_channel] if outgoing_channel else list(self.channels.values()) channels = [outgoing_channel] if outgoing_channel else list(self.channels.values())
route = None
scid_to_my_channels = { scid_to_my_channels = {
chan.short_channel_id: chan for chan in channels chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None if chan.short_channel_id is not None
} }
blacklist = self.network.channel_blacklist.get_current_list() blacklist = self.network.channel_blacklist.get_current_list()
# first try with routing hints, then without # Collect all private edges from route hints.
for private_path in r_tags + [[]]: # Note: if some route hints are multiple edges long, and these paths cross each other,
private_route = [] # we allow our path finding to cross the paths; i.e. the route hints are not isolated.
amount_for_node = amount_msat private_route_edges = {} # type: Dict[ShortChannelID, RouteEdge]
path = full_path for private_path in r_tags:
if len(private_path) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
continue
if len(private_path) == 0:
border_node_pubkey = invoice_pubkey
else:
border_node_pubkey = private_path[0][0]
# we need to shift the node pubkey by one towards the destination: # we need to shift the node pubkey by one towards the destination:
private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey] private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey]
private_path_rest = [edge[1:] for edge in private_path] private_path_rest = [edge[1:] for edge in private_path]
prev_node_id = border_node_pubkey start_node = private_path[0][0]
for node_pubkey, edge_rest in zip(private_path_nodes, private_path_rest): for end_node, edge_rest in zip(private_path_nodes, private_path_rest):
short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest
short_channel_id = ShortChannelID(short_channel_id) short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence, # if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure # as it is likely from a previous failure
channel_policy = self.channel_db.get_policy_for_node( channel_policy = self.channel_db.get_policy_for_node(
short_channel_id=short_channel_id, short_channel_id=short_channel_id,
node_id=prev_node_id, node_id=start_node,
my_channels=scid_to_my_channels) my_channels=scid_to_my_channels)
if channel_policy: if channel_policy:
fee_base_msat = channel_policy.fee_base_msat fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths fee_proportional_millionths = channel_policy.fee_proportional_millionths
cltv_expiry_delta = channel_policy.cltv_expiry_delta cltv_expiry_delta = channel_policy.cltv_expiry_delta
node_info = self.channel_db.get_node_info_for_node_id(node_id=node_pubkey) node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
private_route.append( route_edge = RouteEdge(
RouteEdge( start_node=start_node,
node_id=node_pubkey, end_node=end_node,
short_channel_id=short_channel_id, short_channel_id=short_channel_id,
fee_base_msat=fee_base_msat, fee_base_msat=fee_base_msat,
fee_proportional_millionths=fee_proportional_millionths, fee_proportional_millionths=fee_proportional_millionths,
cltv_expiry_delta=cltv_expiry_delta, cltv_expiry_delta=cltv_expiry_delta,
node_features=node_info.features if node_info else 0)) node_features=node_info.features if node_info else 0)
prev_node_id = node_pubkey if route_edge.short_channel_id not in blacklist:
for edge in private_route[::-1]: private_route_edges[route_edge.short_channel_id] = route_edge
amount_for_node += edge.fee_for_edge(amount_for_node) start_node = end_node
if full_path: # now find a route, end to end: between us and the recipient
# user pre-selected path. check that end of given path coincides with private_route:
if [edge.short_channel_id for edge in full_path[-len(private_path):]] != [edge[1] for edge in private_path]:
continue
path = full_path[:-len(private_path)]
if any(edge.short_channel_id in blacklist for edge in private_route):
continue
try: try:
route = self.network.path_finder.find_route( route = self.network.path_finder.find_route(
self.node_keypair.pubkey, border_node_pubkey, amount_for_node, nodeA=self.node_keypair.pubkey,
path=path, my_channels=scid_to_my_channels, blacklist=blacklist) nodeB=invoice_pubkey,
except NoChannelPolicy: invoice_amount_msat=amount_msat,
continue path=full_path,
my_channels=scid_to_my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
except NoChannelPolicy as e:
raise NoPathFound() from e
if not route: if not route:
continue raise NoPathFound()
route = route + private_route
# test sanity # test sanity
if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry): if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry):
self.logger.info(f"rejecting insane route {route}") self.logger.info(f"rejecting insane route {route}")
continue
break
else:
raise NoPathFound() raise NoPathFound()
assert len(route) > 0 assert len(route) > 0
if route[-1].node_id != invoice_pubkey: if route[-1].end_node != invoice_pubkey:
raise LNPathInconsistent("last node_id != invoice pubkey") raise LNPathInconsistent("last node_id != invoice pubkey")
# add features from invoice # add features from invoice
route[-1].node_features |= invoice_features route[-1].node_features |= invoice_features

20
electrum/tests/test_lnpeer.py

@ -600,17 +600,27 @@ class TestPeer(ElectrumTestCase):
peers = graph.all_peers() peers = graph.all_peers()
async def pay(pay_req): async def pay(pay_req):
with self.subTest(msg="bad path: edges do not chain together"): with self.subTest(msg="bad path: edges do not chain together"):
path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] end_node=graph.w_c.node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(start_node=graph.w_b.node_keypair.pubkey,
end_node=graph.w_d.node_keypair.pubkey,
short_channel_id=graph.chan_bd.short_channel_id)]
with self.assertRaises(LNPathInconsistent): with self.assertRaises(LNPathInconsistent):
await graph.w_a.pay_invoice(pay_req, full_path=path) await graph.w_a.pay_invoice(pay_req, full_path=path)
with self.subTest(msg="bad path: last node id differs from invoice pubkey"): with self.subTest(msg="bad path: last node id differs from invoice pubkey"):
path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)] path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
end_node=graph.w_b.node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id)]
with self.assertRaises(LNPathInconsistent): with self.assertRaises(LNPathInconsistent):
await graph.w_a.pay_invoice(pay_req, full_path=path) await graph.w_a.pay_invoice(pay_req, full_path=path)
with self.subTest(msg="good path"): with self.subTest(msg="good path"):
path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] end_node=graph.w_b.node_keypair.pubkey,
short_channel_id=graph.chan_ab.short_channel_id),
PathEdge(start_node=graph.w_b.node_keypair.pubkey,
end_node=graph.w_d.node_keypair.pubkey,
short_channel_id=graph.chan_bd.short_channel_id)]
result, log = await graph.w_a.pay_invoice(pay_req, full_path=path) result, log = await graph.w_a.pay_invoice(pay_req, full_path=path)
self.assertTrue(result) self.assertTrue(result)
self.assertEqual( self.assertEqual(

12
electrum/tests/test_lnrouter.py

@ -83,12 +83,14 @@ class Test_LNRouter(TestCaseForTestnet):
cdb.add_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000) path = path_finder.find_path_for_payment(
self.assertEqual([PathEdge(node_id=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')), nodeA=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
PathEdge(node_id=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')), nodeB=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
invoice_amount_msat=100000)
self.assertEqual([PathEdge(start_node=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', end_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')),
PathEdge(start_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', end_node=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')),
], path) ], path)
start_node = b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' route = path_finder.create_route_from_path(path)
route = path_finder.create_route_from_path(path, start_node)
self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id) self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id)
self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id) self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id)

35
electrum/trampoline.py

@ -61,11 +61,13 @@ def encode_routing_info(r_tags):
def create_trampoline_route( def create_trampoline_route(
*,
amount_msat:int, amount_msat:int,
bucket_amount_msat:int, bucket_amount_msat:int,
min_cltv_expiry:int, min_cltv_expiry:int,
invoice_pubkey:bytes, invoice_pubkey:bytes,
invoice_features:int, invoice_features:int,
my_pubkey: bytes,
trampoline_node_id, trampoline_node_id,
r_tags, t_tags, r_tags, t_tags,
trampoline_fee_level, trampoline_fee_level,
@ -106,7 +108,8 @@ def create_trampoline_route(
# trampoline hop # trampoline hop
route.append( route.append(
TrampolineEdge( TrampolineEdge(
node_id=trampoline_node_id, start_node=my_pubkey,
end_node=trampoline_node_id,
fee_base_msat=params['fee_base_msat'], fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'], fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'], cltv_expiry_delta=params['cltv_expiry_delta'],
@ -114,7 +117,8 @@ def create_trampoline_route(
if trampoline2: if trampoline2:
route.append( route.append(
TrampolineEdge( TrampolineEdge(
node_id=trampoline2, start_node=trampoline_node_id,
end_node=trampoline2,
fee_base_msat=params['fee_base_msat'], fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'], fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'], cltv_expiry_delta=params['cltv_expiry_delta'],
@ -130,7 +134,8 @@ def create_trampoline_route(
if route[-1].node_id != pubkey: if route[-1].node_id != pubkey:
route.append( route.append(
TrampolineEdge( TrampolineEdge(
node_id=pubkey, start_node=route[-1].node_id,
end_node=pubkey,
fee_base_msat=feebase, fee_base_msat=feebase,
fee_proportional_millionths=feerate, fee_proportional_millionths=feerate,
cltv_expiry_delta=cltv, cltv_expiry_delta=cltv,
@ -138,7 +143,8 @@ def create_trampoline_route(
# Fake edge (not part of actual route, needed by calc_hops_data) # Fake edge (not part of actual route, needed by calc_hops_data)
route.append( route.append(
TrampolineEdge( TrampolineEdge(
node_id=invoice_pubkey, start_node=route[-1].end_node,
end_node=invoice_pubkey,
fee_base_msat=0, fee_base_msat=0,
fee_proportional_millionths=0, fee_proportional_millionths=0,
cltv_expiry_delta=0, cltv_expiry_delta=0,
@ -194,6 +200,7 @@ def create_trampoline_route_and_onion(
min_cltv_expiry, min_cltv_expiry,
invoice_pubkey, invoice_pubkey,
invoice_features, invoice_features,
my_pubkey: bytes,
node_id, node_id,
r_tags, t_tags, r_tags, t_tags,
payment_hash, payment_hash,
@ -203,15 +210,17 @@ def create_trampoline_route_and_onion(
trampoline2_list): trampoline2_list):
# create route for the trampoline_onion # create route for the trampoline_onion
trampoline_route = create_trampoline_route( trampoline_route = create_trampoline_route(
amount_msat, amount_msat=amount_msat,
bucket_amount_msat, bucket_amount_msat=bucket_amount_msat,
min_cltv_expiry, min_cltv_expiry=min_cltv_expiry,
invoice_pubkey, my_pubkey=my_pubkey,
invoice_features, invoice_pubkey=invoice_pubkey,
node_id, invoice_features=invoice_features,
r_tags, t_tags, trampoline_node_id=node_id,
trampoline_fee_level, r_tags=r_tags,
trampoline2_list) t_tags=t_tags,
trampoline_fee_level=trampoline_fee_level,
trampoline2_list=trampoline2_list)
# compute onion and fees # compute onion and fees
final_cltv = local_height + min_cltv_expiry final_cltv = local_height + min_cltv_expiry
trampoline_onion, bucket_amount_with_fees, bucket_cltv = create_trampoline_onion( trampoline_onion, bucket_amount_with_fees, bucket_cltv = create_trampoline_onion(

Loading…
Cancel
Save