diff --git a/electrum/lnonion.py b/electrum/lnonion.py index 151993203..3e8814c9b 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -161,6 +161,9 @@ class OnionHopsDataSingle: # called HopData in lnd assert len(ret.hmac) == PER_HOP_HMAC_SIZE return ret + def __repr__(self): + return f"" + class OnionPacket: @@ -265,22 +268,24 @@ def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, final_ if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: raise PaymentFailure(f"too long route ({len(route)} edges)") + # payload that will be seen by the last hop: amt = amount_msat cltv = final_cltv hop_payload = { "amt_to_forward": {"amt_to_forward": amt}, "outgoing_cltv_value": {"outgoing_cltv_value": cltv}, - "short_channel_id": {"short_channel_id": b"\x00" * 8}, # TODO omit if tlv } - hops_data = [OnionHopsDataSingle(is_tlv_payload=False, # TODO + hops_data = [OnionHopsDataSingle(is_tlv_payload=route[-1].has_feature_varonion(), payload=hop_payload)] - for route_edge in reversed(route[1:]): + # payloads, backwards from last hop (but excluding the first edge): + for edge_index in range(len(route) - 1, 0, -1): + route_edge = route[edge_index] hop_payload = { "amt_to_forward": {"amt_to_forward": amt}, "outgoing_cltv_value": {"outgoing_cltv_value": cltv}, "short_channel_id": {"short_channel_id": route_edge.short_channel_id}, } - hops_data += [OnionHopsDataSingle(is_tlv_payload=False, # TODO + hops_data += [OnionHopsDataSingle(is_tlv_payload=route[edge_index-1].has_feature_varonion(), payload=hop_payload)] amt += route_edge.fee_for_edge(amt) cltv += route_edge.cltv_expiry_delta diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 4e7c95f47..fffc1e5fd 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1040,12 +1040,15 @@ class Peer(Logger): def pay(self, route: 'LNPaymentRoute', chan: Channel, amount_msat: int, payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc: assert amount_msat > 0, "amount_msat is not greater zero" + assert len(route) > 0 if not chan.can_send_update_add_htlc(): raise PaymentFailure("Channel cannot send update_add_htlc") + # add features learned during "init" for direct neighbour: + route[0].node_features |= self.features local_height = self.network.get_local_height() # create onion packet final_cltv = local_height + min_final_cltv_expiry - hops_data, amount_msat, cltv = calc_hops_data_for_payment(route, amount_msat, final_cltv) # TODO varonion + hops_data, amount_msat, cltv = calc_hops_data_for_payment(route, amount_msat, final_cltv) assert final_cltv <= cltv, (final_cltv, cltv) secret_key = os.urandom(32) onion = new_onion_packet([x.node_id for x in route], secret_key, hops_data, associated_data=payment_hash) @@ -1055,7 +1058,8 @@ class Peer(Logger): htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv, timestamp=int(time.time())) htlc = chan.add_htlc(htlc) chan.set_onion_key(htlc.htlc_id, secret_key) - self.logger.info(f"starting payment. len(route)={len(route)}. route: {route}. htlc: {htlc}") + self.logger.info(f"starting payment. len(route)={len(route)}. route: {route}. " + f"htlc: {htlc}. hops_data={hops_data!r}") self.send_message( "update_add_htlc", channel_id=chan.channel_id, diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index eb38e46ec..bf266de21 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -27,11 +27,13 @@ import queue from collections import defaultdict from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set +import attr + from .util import bh2u, profiler from .logging import Logger -from .lnutil import NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID -from .channel_db import ChannelDB, Policy -from .lnutil import NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE +from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures, + NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE) +from .channel_db import ChannelDB, Policy, NodeInfo if TYPE_CHECKING: from .lnchannel import Channel @@ -48,13 +50,15 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor + (forwarded_amount_msat * fee_proportional_millionths // 1_000_000) -class RouteEdge(NamedTuple): +@attr.s +class RouteEdge: """if you travel through short_channel_id, you will reach node_id""" - node_id: bytes - short_channel_id: ShortChannelID - fee_base_msat: int - fee_proportional_millionths: int - cltv_expiry_delta: int + node_id = attr.ib(type=bytes, kw_only=True) + short_channel_id = attr.ib(type=ShortChannelID, kw_only=True) + fee_base_msat = attr.ib(type=int, kw_only=True) + fee_proportional_millionths = attr.ib(type=int, kw_only=True) + cltv_expiry_delta = attr.ib(type=int, kw_only=True) + node_features = attr.ib(type=int, kw_only=True) # note: for end node! def fee_for_edge(self, amount_msat: int) -> int: return fee_for_edge_msat(forwarded_amount_msat=amount_msat, @@ -63,14 +67,16 @@ class RouteEdge(NamedTuple): @classmethod def from_channel_policy(cls, channel_policy: 'Policy', - short_channel_id: bytes, end_node: bytes) -> 'RouteEdge': + short_channel_id: bytes, end_node: bytes, *, + node_info: Optional[NodeInfo]) -> 'RouteEdge': assert isinstance(short_channel_id, bytes) assert type(end_node) is bytes - return RouteEdge(end_node, - ShortChannelID.normalize(short_channel_id), - channel_policy.fee_base_msat, - channel_policy.fee_proportional_millionths, - channel_policy.cltv_expiry_delta) + return RouteEdge(node_id=end_node, + short_channel_id=ShortChannelID.normalize(short_channel_id), + fee_base_msat=channel_policy.fee_base_msat, + fee_proportional_millionths=channel_policy.fee_proportional_millionths, + cltv_expiry_delta=channel_policy.cltv_expiry_delta, + node_features=node_info.features if node_info else 0) def is_sane_to_use(self, amount_msat: int) -> bool: # TODO revise ad-hoc heuristics @@ -82,6 +88,10 @@ class RouteEdge(NamedTuple): return False return True + def has_feature_varonion(self) -> bool: + features = self.node_features + return bool(features & LnFeatures.VAR_ONION_REQ or features & LnFeatures.VAR_ONION_OPT) + LNPaymentRoute = Sequence[RouteEdge] @@ -154,7 +164,9 @@ class LNPathFinder(Logger): if channel_policy.htlc_maximum_msat is not None and \ payment_amt_msat > channel_policy.htlc_maximum_msat: return float('inf'), 0 # payment amount too large - route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node) + node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) + route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node, + node_info=node_info) if not route_edge.is_sane_to_use(payment_amt_msat): return float('inf'), 0 # thanks but no thanks @@ -268,6 +280,8 @@ class LNPathFinder(Logger): my_channels=my_channels) if channel_policy is None: raise NoChannelPolicy(short_channel_id) - route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id)) + node_info = self.channel_db.get_node_info_for_node_id(node_id=node_id) + route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id, + node_info=node_info)) prev_node_id = node_id return route diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 9038b5e92..e4855d3dd 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -8,8 +8,8 @@ import json from collections import namedtuple, defaultdict from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence import re -import attr +import attr from aiorpcx import NetAddress from .util import bfh, bh2u, inv_dict, UserFacingException @@ -838,6 +838,7 @@ LN_FEATURES_IMPLEMENTED = ( | LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ | LnFeatures.GOSSIP_QUERIES_OPT | LnFeatures.GOSSIP_QUERIES_REQ | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ + | LnFeatures.VAR_ONION_OPT | LnFeatures.VAR_ONION_REQ ) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 592bb5951..833f048cc 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -150,6 +150,7 @@ class LNWorker(Logger): self.features = LnFeatures(0) self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT + self.features |= LnFeatures.VAR_ONION_OPT def channels_for_peer(self, node_id): return {} @@ -1047,7 +1048,7 @@ class LNWallet(LNWorker): return addr @profiler - def _create_route_from_invoice(self, decoded_invoice) -> LNPaymentRoute: + def _create_route_from_invoice(self, decoded_invoice: 'LnAddr') -> LNPaymentRoute: amount_msat = int(decoded_invoice.amount * COIN * 1000) invoice_pubkey = decoded_invoice.pubkey.serialize() # use 'r' field from invoice @@ -1091,8 +1092,13 @@ class LNWallet(LNWorker): fee_base_msat = channel_policy.fee_base_msat fee_proportional_millionths = channel_policy.fee_proportional_millionths cltv_expiry_delta = channel_policy.cltv_expiry_delta - route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths, - cltv_expiry_delta)) + node_info = self.channel_db.get_node_info_for_node_id(node_id=node_pubkey) + route.append(RouteEdge(node_id=node_pubkey, + short_channel_id=short_channel_id, + fee_base_msat=fee_base_msat, + fee_proportional_millionths=fee_proportional_millionths, + cltv_expiry_delta=cltv_expiry_delta, + node_features=node_info.features if node_info else 0)) prev_node_id = node_pubkey # test sanity if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): @@ -1111,6 +1117,11 @@ class LNWallet(LNWorker): if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): self.logger.info(f"rejecting insane route {route}") raise NoPathFound() + assert len(route) > 0 + assert route[-1].node_id == invoice_pubkey + # add features from invoice + invoice_features = decoded_invoice.get_tag('9') or 0 + route[-1].node_features |= invoice_features return route def add_request(self, amount_sat, message, expiry): @@ -1141,7 +1152,8 @@ class LNWallet(LNWorker): lnaddr = LnAddr(payment_hash, amount_btc, tags=[('d', message), ('c', MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), - ('x', expiry)] + ('x', expiry), + ('9', self.features.for_invoice())] + routing_hints, date = timestamp) invoice = lnencode(lnaddr, self.node_keypair.privkey)