From 97393d05aa55a2095b7e6323aa7d4fc5a6014723 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Mon, 8 Oct 2018 20:36:46 +0200 Subject: [PATCH] use 'r' field in invoice when making payments (routing hints) --- electrum/lnbase.py | 7 +++---- electrum/lnrouter.py | 25 ++++++++++++++----------- electrum/lnworker.py | 39 ++++++++++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 22 deletions(-) diff --git a/electrum/lnbase.py b/electrum/lnbase.py index 2dd6577d9..3e75a0257 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -1018,19 +1018,18 @@ class Peer(PrintError): await self.receive_commitment(chan) self.revoke(chan) - async def pay(self, path, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry): + async def pay(self, route: List[RouteEdge], chan, amount_msat, payment_hash, min_final_cltv_expiry): assert chan.get_state() == "OPEN", chan.get_state() assert amount_msat > 0, "amount_msat is not greater zero" height = self.network.get_local_height() - route = self.network.path_finder.create_route_from_path(path, self.lnworker.node_keypair.pubkey) hops_data = [] - sum_of_deltas = sum(route_edge.channel_policy.cltv_expiry_delta for route_edge in route[1:]) + sum_of_deltas = sum(route_edge.cltv_expiry_delta for route_edge in route[1:]) total_fee = 0 final_cltv_expiry_without_deltas = (height + min_final_cltv_expiry) final_cltv_expiry_with_deltas = final_cltv_expiry_without_deltas + sum_of_deltas for idx, route_edge in enumerate(route[1:]): hops_data += [OnionHopsDataSingle(OnionPerHop(route_edge.short_channel_id, amount_msat.to_bytes(8, "big"), final_cltv_expiry_without_deltas.to_bytes(4, "big")))] - total_fee += route_edge.channel_policy.fee_base_msat + ( amount_msat * route_edge.channel_policy.fee_proportional_millionths // 1000000 ) + total_fee += route_edge.fee_base_msat + ( amount_msat * route_edge.fee_proportional_millionths // 1000000 ) associated_data = payment_hash secret_key = os.urandom(32) hops_data += [OnionHopsDataSingle(OnionPerHop(b"\x00"*8, amount_msat.to_bytes(8, "big"), (final_cltv_expiry_without_deltas).to_bytes(4, "big")))] diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 04611e4ab..e9e72f292 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -28,7 +28,7 @@ import os import json import threading from collections import namedtuple, defaultdict -from typing import Sequence, Union, Tuple, Optional +from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple import binascii import base64 import asyncio @@ -478,14 +478,13 @@ class ChannelDB(JsonDB): direction)) -class RouteEdge: - - def __init__(self, node_id: bytes, short_channel_id: bytes, - channel_policy: ChannelInfoDirectedPolicy): - # "if you travel through short_channel_id, you will reach node_id" - self.node_id = node_id - self.short_channel_id = short_channel_id - self.channel_policy = channel_policy +class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), + ('short_channel_id', bytes), + ('fee_base_msat', int), + ('fee_proportional_millionths', int), + ('cltv_expiry_delta', int)])): + """if you travel through short_channel_id, you will reach node_id""" + pass class LNPathFinder(PrintError): @@ -578,7 +577,7 @@ class LNPathFinder(PrintError): path.reverse() return path - def create_route_from_path(self, path, from_node_id: bytes) -> Sequence[RouteEdge]: + def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]: assert type(from_node_id) is bytes if path is None: raise Exception('cannot create route from None path') @@ -591,6 +590,10 @@ class LNPathFinder(PrintError): channel_policy = channel_info.get_policy_for_node(prev_node_id) if channel_policy is None: raise Exception('cannot find channel policy for short_channel_id: {}'.format(bh2u(short_channel_id))) - route.append(RouteEdge(node_id, short_channel_id, channel_policy)) + route.append(RouteEdge(node_id, + short_channel_id, + channel_policy.fee_base_msat, + channel_policy.fee_proportional_millionths, + channel_policy.cltv_expiry_delta)) prev_node_id = node_id return route diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 7c1fe7086..fdca280cb 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -27,6 +27,7 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, from .lnutil import LOCAL, REMOTE from .lnaddr import lndecode from .i18n import _ +from .lnrouter import RouteEdge NUM_PEERS_TARGET = 4 @@ -237,16 +238,12 @@ class LNWorker(PrintError): def pay(self, invoice, amount_sat=None): addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) payment_hash = addr.paymenthash - invoice_pubkey = addr.pubkey.serialize() amount_sat = (addr.amount * COIN) if addr.amount else amount_sat if amount_sat is None: raise InvoiceError(_("Missing amount")) amount_msat = int(amount_sat * 1000) - # TODO use 'r' field from invoice - path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat) - if path is None: - raise PaymentFailure(_("No path found")) - node_id, short_channel_id = path[0] + route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat) + node_id, short_channel_id = route[0].node_id, route[0].short_channel_id peer = self.peers[node_id] with self.lock: channels = list(self.channels.values()) @@ -255,9 +252,37 @@ class LNWorker(PrintError): break else: raise Exception("ChannelDB returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id))) - coro = peer.pay(path, chan, amount_msat, payment_hash, invoice_pubkey, addr.min_final_cltv_expiry) + coro = peer.pay(route, chan, amount_msat, payment_hash, addr.min_final_cltv_expiry) return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) + def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]: + invoice_pubkey = decoded_invoice.pubkey.serialize() + # use 'r' field from invoice + route = None # type: List[RouteEdge] + for tag_type, data in decoded_invoice.tags: + if tag_type != 'r': continue + private_route = data + if len(private_route) == 0: continue + border_node_pubkey = private_route[0][0] + path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat) + if path is None: continue + route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) + # we need to shift the node pubkey by one towards the destination: + private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey] + private_route_rest = [edge[1:] for edge in private_route] + for node_pubkey, edge_rest in zip(private_route_nodes, private_route_rest): + short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest + route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths, + cltv_expiry_delta)) + break + # if could not find route using any hint; try without hint now + if route is None: + path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat) + if path is None: + raise PaymentFailure(_("No path found")) + route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) + return route + def add_invoice(self, amount_sat, message): payment_preimage = os.urandom(32) RHASH = sha256(payment_preimage)