Browse Source

Trampoline MPP aggregation:

- trampoline node is the final recipient of MPP
 - each trampoline receives a bucket of HTLCs
 - if a HTLC from a bucket fails, wait for the entire bucket to fail
 - move trampoline route and onion code into trampoline module
patch-4
ThomasV 4 years ago
parent
commit
259dacd56f
  1. 81
      electrum/lnpeer.py
  2. 370
      electrum/lnworker.py
  3. 9
      electrum/tests/test_lnpeer.py
  4. 228
      electrum/trampoline.py

81
electrum/lnpeer.py

@ -1195,9 +1195,16 @@ class Peer(Logger):
sig_64, htlc_sigs = chan.sign_next_commitment() sig_64, htlc_sigs = chan.sign_next_commitment()
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
def pay(self, *, route: 'LNPaymentRoute', chan: Channel, amount_msat: int, def pay(self, *,
total_msat: int, payment_hash: bytes, min_final_cltv_expiry: int, route: 'LNPaymentRoute',
payment_secret: bytes = None, fwd_trampoline_onion=None) -> UpdateAddHtlc: chan: Channel,
amount_msat: int,
total_msat: int,
payment_hash: bytes,
min_final_cltv_expiry: int,
payment_secret: bytes = None,
trampoline_onion=None) -> UpdateAddHtlc:
assert amount_msat > 0, "amount_msat is not greater zero" assert amount_msat > 0, "amount_msat is not greater zero"
assert len(route) > 0 assert len(route) > 0
if not chan.can_send_update_add_htlc(): if not chan.can_send_update_add_htlc():
@ -1211,78 +1218,26 @@ class Peer(Logger):
amount_msat, amount_msat,
final_cltv, final_cltv,
total_msat=total_msat, total_msat=total_msat,
payment_secret=payment_secret payment_secret=payment_secret)
) num_hops = len(hops_data)
self.logger.info(f"lnpeer.pay len(route)={len(route)}") self.logger.info(f"lnpeer.pay len(route)={len(route)}")
for i in range(len(route)): for i in range(len(route)):
self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}") self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}")
assert final_cltv <= cltv, (final_cltv, cltv) assert final_cltv <= cltv, (final_cltv, cltv)
session_key = os.urandom(32) # session_key session_key = os.urandom(32) # session_key
# detect trampoline hops
payment_path_pubkeys = [x.node_id for x in route]
num_hops = len(payment_path_pubkeys)
for i in range(num_hops-1):
route_edge = route[i]
next_edge = route[i+1]
if route_edge.is_trampoline():
assert next_edge.is_trampoline()
self.logger.info(f'trampoline hop at position {i}')
hops_data[i].payload["outgoing_node_id"] = {"outgoing_node_id":next_edge.node_id}
if route_edge.invoice_features:
hops_data[i].payload["invoice_features"] = {"invoice_features":route_edge.invoice_features}
if route_edge.invoice_routing_info:
hops_data[i].payload["invoice_routing_info"] = {"invoice_routing_info":route_edge.invoice_routing_info}
# only for final, legacy
if i == num_hops - 2:
self.logger.info(f'adding payment secret for legacy trampoline')
hops_data[i].payload["payment_data"] = {
"payment_secret":payment_secret,
"total_msat": amount_msat,
}
# if we are forwarding a trampoline payment, add trampoline onion # if we are forwarding a trampoline payment, add trampoline onion
if fwd_trampoline_onion: if trampoline_onion:
self.logger.info(f'adding trampoline onion to final payload') self.logger.info(f'adding trampoline onion to final payload')
trampoline_payload = hops_data[num_hops-2].payload trampoline_payload = hops_data[num_hops-2].payload
trampoline_payload["trampoline_onion_packet"] = { trampoline_payload["trampoline_onion_packet"] = {
"version": fwd_trampoline_onion.version, "version": trampoline_onion.version,
"public_key": fwd_trampoline_onion.public_key, "public_key": trampoline_onion.public_key,
"hops_data": fwd_trampoline_onion.hops_data, "hops_data": trampoline_onion.hops_data,
"hmac": fwd_trampoline_onion.hmac "hmac": trampoline_onion.hmac
} }
# create trampoline onion
for i in range(num_hops):
route_edge = route[i]
if route_edge.is_trampoline():
self.logger.info(f'first trampoline hop at position {i}')
self.logger.info(f'inner onion: {hops_data[i:]}')
trampoline_session_key = os.urandom(32)
trampoline_onion = new_onion_packet(payment_path_pubkeys[i:], trampoline_session_key, hops_data[i:], associated_data=payment_hash, trampoline=True)
# drop hop_data
payment_path_pubkeys = payment_path_pubkeys[:i]
hops_data = hops_data[:i]
# we must generate a different secret for the outer onion
outer_payment_secret = os.urandom(32)
# trampoline_payload is a final payload
trampoline_payload = hops_data[i-1].payload
p = trampoline_payload.pop('short_channel_id')
amt_to_forward = trampoline_payload["amt_to_forward"]["amt_to_forward"]
trampoline_payload["payment_data"] = {
"payment_secret":outer_payment_secret,
"total_msat": amt_to_forward
}
trampoline_payload["trampoline_onion_packet"] = {
"version": trampoline_onion.version,
"public_key": trampoline_onion.public_key,
"hops_data": trampoline_onion.hops_data,
"hmac": trampoline_onion.hmac
}
break
# create onion packet # create onion packet
payment_path_pubkeys = [x.node_id for x in route]
onion = new_onion_packet(payment_path_pubkeys, session_key, hops_data, associated_data=payment_hash) # must use another sessionkey onion = new_onion_packet(payment_path_pubkeys, session_key, hops_data, associated_data=payment_hash) # must use another sessionkey
self.logger.info(f"starting payment. len(route)={len(hops_data)}.") self.logger.info(f"starting payment. len(route)={len(hops_data)}.")
# create htlc # create htlc
if cltv > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE: if cltv > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:

370
electrum/lnworker.py

@ -78,6 +78,7 @@ from .channel_db import get_mychannel_info, get_mychannel_policy
from .submarine_swaps import SwapManager from .submarine_swaps import SwapManager
from .channel_db import ChannelInfo, Policy from .channel_db import ChannelInfo, Policy
from .mpp_split import suggest_splits from .mpp_split import suggest_splits
from .trampoline import create_trampoline_route_and_onion
if TYPE_CHECKING: if TYPE_CHECKING:
from .network import Network from .network import Network
@ -163,45 +164,6 @@ def trampolines_by_id():
is_hardcoded_trampoline = lambda node_id: node_id in trampolines_by_id().keys() is_hardcoded_trampoline = lambda node_id: node_id in trampolines_by_id().keys()
# trampoline nodes are supposed to advertise their fee and cltv in node_update message
TRAMPOLINE_FEES = [
{
'fee_base_msat': 0,
'fee_proportional_millionths': 0,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 1000,
'fee_proportional_millionths': 100,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 3000,
'fee_proportional_millionths': 100,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 5000,
'fee_proportional_millionths': 500,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 7000,
'fee_proportional_millionths': 1000,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 12000,
'fee_proportional_millionths': 3000,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 100000,
'fee_proportional_millionths': 3000,
'cltv_expiry_delta': 576,
},
]
class PaymentInfo(NamedTuple): class PaymentInfo(NamedTuple):
payment_hash: bytes payment_hash: bytes
@ -658,7 +620,8 @@ class LNWallet(LNWorker):
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route, amount_for_receiver self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat
self.sent_buckets = defaultdict(set)
self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
@ -1019,6 +982,8 @@ class LNWallet(LNWorker):
r_tags=decoded_invoice.get_routing_info('r'), r_tags=decoded_invoice.get_routing_info('r'),
t_tags=decoded_invoice.get_routing_info('t'), t_tags=decoded_invoice.get_routing_info('t'),
invoice_features=decoded_invoice.get_tag('9') or 0, invoice_features=decoded_invoice.get_tag('9') or 0,
payment_hash=decoded_invoice.paymenthash,
payment_secret=decoded_invoice.payment_secret,
full_path=full_path) full_path=full_path)
@log_exceptions @log_exceptions
@ -1113,17 +1078,25 @@ class LNWallet(LNWorker):
# note: path-finding runs in a separate thread so that we don't block the asyncio loop # note: path-finding runs in a separate thread so that we don't block the asyncio loop
# graph updates might occur during the computation # graph updates might occur during the computation
routes = await run_in_thread(partial( routes = await run_in_thread(partial(
self.create_routes_for_payment, amount_to_send, node_pubkey, self.create_routes_for_payment,
min_cltv_expiry, r_tags, t_tags, invoice_features, full_path=full_path)) amount_msat=amount_to_send,
invoice_pubkey=node_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
t_tags=t_tags,
invoice_features=invoice_features,
full_path=full_path,
payment_hash=payment_hash,
payment_secret=payment_secret))
# 2. send htlcs # 2. send htlcs
for route, amount_msat in routes: for route, amount_msat, total_msat, cltv_delta, bucket_payment_secret, trampoline_onion in routes:
await self.pay_to_route( await self.pay_to_route(
route=route, route=route,
amount_msat=amount_msat, amount_msat=amount_msat,
total_msat=amount_to_pay, total_msat=total_msat,
payment_hash=payment_hash, payment_hash=payment_hash,
payment_secret=payment_secret, payment_secret=bucket_payment_secret,
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=cltv_delta,
trampoline_onion=trampoline_onion) trampoline_onion=trampoline_onion)
amount_inflight += amount_msat amount_inflight += amount_msat
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex()) util.trigger_callback('invoice_status', self.wallet, payment_hash.hex())
@ -1165,8 +1138,11 @@ class LNWallet(LNWorker):
payment_hash=payment_hash, payment_hash=payment_hash,
min_final_cltv_expiry=min_cltv_expiry, min_final_cltv_expiry=min_cltv_expiry,
payment_secret=payment_secret, payment_secret=payment_secret,
fwd_trampoline_onion=trampoline_onion) trampoline_onion=trampoline_onion)
self.sent_htlcs_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route, amount_msat
key = (payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_routes[key] = route, payment_secret, amount_msat, total_msat
self.sent_buckets[payment_secret] = total_msat
util.trigger_callback('htlc_added', chan, htlc, SENT) util.trigger_callback('htlc_added', chan, htlc, SENT)
def handle_error_code_from_failed_htlc(self, htlc_log): def handle_error_code_from_failed_htlc(self, htlc_log):
@ -1293,16 +1269,6 @@ class LNWallet(LNWorker):
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}")) f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
return addr return addr
def encode_routing_info(self, r_tags):
import bitstring
result = bitstring.BitArray()
for route in r_tags:
result.append(bitstring.pack('uint:8', len(route)))
for step in route:
pubkey, channel, feebase, feerate, cltv = step
result.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
return result.tobytes()
def is_trampoline_peer(self, node_id: bytes) -> bool: def is_trampoline_peer(self, node_id: bytes) -> bool:
# until trampoline is advertised in lnfeatures, check against hardcoded list # until trampoline is advertised in lnfeatures, check against hardcoded list
if is_hardcoded_trampoline(node_id): if is_hardcoded_trampoline(node_id):
@ -1318,178 +1284,150 @@ class LNWallet(LNWorker):
else: else:
return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
def create_trampoline_route(
self, amount_msat:int,
min_cltv_expiry:int,
invoice_pubkey:bytes,
invoice_features:int,
channels: List[Channel],
r_tags, t_tags) -> LNPaymentRoute:
""" return the route that leads to trampoline, and the trampoline fake edge"""
invoice_features = LnFeatures(invoice_features)
# We do not set trampoline_routing_opt in our invoices, because the spec is not ready
# Do not use t_tags if the flag is set, because we the format is not decided yet
if invoice_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT):
is_legacy = False
if len(r_tags) > 0 and len(r_tags[0]) == 1:
pubkey, scid, feebase, feerate, cltv = r_tags[0][0]
t_tag = pubkey, feebase, feerate, cltv
else:
t_tag = None
elif len(t_tags) > 0:
is_legacy = False
t_tag = t_tags[0]
else:
is_legacy = True
# fee level. the same fee is used for all trampolines
if self.trampoline_fee_level < len(TRAMPOLINE_FEES):
params = TRAMPOLINE_FEES[self.trampoline_fee_level]
else:
raise NoPathFound()
# Find a trampoline. We assume we have a direct channel to trampoline
for chan in channels:
if not self.is_trampoline_peer(chan.node_id):
continue
trampoline_short_channel_id = chan.short_channel_id
trampoline_node_id = chan.node_id
# use attempt number to decide fee and second trampoline
# we need a state with the list of nodes we have not tried
# add optional second trampoline
trampoline2 = None
if is_legacy:
for node_id in self.trampoline2_list:
if node_id != trampoline_node_id:
trampoline2 = node_id
break
# node_features is only used to determine is_tlv
trampoline_features = LnFeatures.VAR_ONION_OPT
# hop to trampoline
route = [
RouteEdge(
node_id=trampoline_node_id,
short_channel_id=trampoline_short_channel_id,
fee_base_msat=0,
fee_proportional_millionths=0,
cltv_expiry_delta=0,
node_features=trampoline_features)
]
# trampoline hop
route.append(
TrampolineEdge(
node_id=trampoline_node_id,
fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'],
node_features=trampoline_features))
if trampoline2:
route.append(
TrampolineEdge(
node_id=trampoline2,
fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'],
node_features=trampoline_features))
# add routing info
if is_legacy:
invoice_routing_info = self.encode_routing_info(r_tags)
route[-1].invoice_routing_info = invoice_routing_info
route[-1].invoice_features = invoice_features
else:
if t_tag:
pubkey, feebase, feerate, cltv = t_tag
if route[-1].node_id != pubkey:
route.append(
TrampolineEdge(
node_id=pubkey,
fee_base_msat=feebase,
fee_proportional_millionths=feerate,
cltv_expiry_delta=cltv,
node_features=trampoline_features))
# Fake edge (not part of actual route, needed by calc_hops_data)
route.append(
TrampolineEdge(
node_id=invoice_pubkey,
fee_base_msat=0,
fee_proportional_millionths=0,
cltv_expiry_delta=0,
node_features=trampoline_features))
# check that we can pay amount and fees
for edge in route[::-1]:
amount_msat += edge.fee_for_edge(amount_msat)
if not chan.can_pay(amount_msat, check_frozen=True):
continue
if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry):
continue
break
else:
raise NoPathFound()
self.logger.info(f'created route with trampoline: fee_level={self.trampoline_fee_level}, is legacy: {is_legacy}')
self.logger.info(f'first trampoline: {trampoline_node_id.hex()}')
self.logger.info(f'second trampoline: {trampoline2.hex() if trampoline2 else None}')
self.logger.info(f'params: {params}')
return route
@profiler @profiler
def create_routes_for_payment( def create_routes_for_payment(
self, self, *,
amount_msat: int, amount_msat: int,
invoice_pubkey, invoice_pubkey,
min_cltv_expiry, min_cltv_expiry,
r_tags, t_tags, r_tags, t_tags,
invoice_features: int, invoice_features: int,
*, payment_hash,
full_path: LNPaymentPath = None, payment_secret,
) -> Sequence[Tuple[LNPaymentRoute, int]]: full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]:
"""Creates multiple routes for splitting a payment over the available """Creates multiple routes for splitting a payment over the available
private channels. private channels.
We first try to conduct the payment over a single channel. If that fails We first try to conduct the payment over a single channel. If that fails
and mpp is supported by the receiver, we will split the payment.""" and mpp is supported by the receiver, we will split the payment."""
# It could happen that the pathfinding uses a channel
# in the graph multiple times, meaning we could exhaust
# its capacity. This could be dealt with by temporarily
# iteratively blacklisting channels for this mpp attempt.
invoice_features = LnFeatures(invoice_features) invoice_features = LnFeatures(invoice_features)
# try to send over a single channel trampoline_features = LnFeatures.VAR_ONION_OPT
local_height = self.network.get_local_height()
try: try:
routes = [self.create_route_for_payment( # try to send over a single channel
amount_msat=amount_msat, if not self.channel_db:
invoice_pubkey=invoice_pubkey, for chan in self.channels.values():
min_cltv_expiry=min_cltv_expiry, if not self.is_trampoline_peer(chan.node_id):
r_tags=r_tags, continue
t_tags=t_tags, if chan.is_frozen_for_sending():
invoice_features=invoice_features, continue
outgoing_channel=None, trampoline_onion, trampoline_fee, amount_with_fees, cltv_delta = create_trampoline_route_and_onion(
full_path=full_path, amount_msat=amount_msat,
)] bucket_amount_msat=amount_msat,
min_cltv_expiry=min_cltv_expiry,
invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features,
node_id=chan.node_id,
r_tags=r_tags,
t_tags=t_tags,
payment_hash=payment_hash,
payment_secret=payment_secret,
local_height=local_height,
trampoline_fee_level=self.trampoline_fee_level,
trampoline2_list=self.trampoline2_list)
trampoline_payment_secret = os.urandom(32)
amount_to_send = amount_with_fees + trampoline_fee
if chan.available_to_spend(LOCAL, strict=True) < amount_to_send:
continue
route = [
RouteEdge(
node_id=chan.node_id,
short_channel_id=chan.short_channel_id,
fee_base_msat=0,
fee_proportional_millionths=0,
cltv_expiry_delta=0,
node_features=trampoline_features)
]
routes = [(route, amount_to_send, amount_to_send, cltv_delta, trampoline_payment_secret, trampoline_onion)]
break
else:
raise NoPathFound()
else:
route = self.create_route_for_payment(
amount_msat=amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags, t_tags=t_tags,
invoice_features=invoice_features,
outgoing_channel=None, full_path=full_path)
routes = [(route, amount_msat, amount_msat, min_cltv_expiry, payment_secret, None)]
except NoPathFound: except NoPathFound:
if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT): if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
raise raise
channels_with_funds = dict([ channels_with_funds = dict([
(cid, int(chan.available_to_spend(HTLCOwner.LOCAL))) (cid, int(chan.available_to_spend(HTLCOwner.LOCAL)))
for cid, chan in self._channels.items()]) for cid, chan in self._channels.items() if not chan.is_frozen_for_sending()])
self.logger.info(f"channels_with_funds: {channels_with_funds}")
# Create split configurations that are rated according to our # Create split configurations that are rated according to our
# preference -funds = (low rating=high preference). # preference -funds = (low rating=high preference).
split_configurations = suggest_splits(amount_msat, channels_with_funds) split_configurations = suggest_splits(amount_msat, channels_with_funds)
self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations')
for s in split_configurations: for s in split_configurations:
self.logger.info(f"trying split configuration: {s[0].values()} rating: {s[1]}")
routes = [] routes = []
try: try:
for chanid, part_amount_msat in s[0].items(): buckets = defaultdict(list)
for chan_id, part_amount_msat in s[0].items():
chan = self.channels[chan_id]
if part_amount_msat: if part_amount_msat:
channel = self.channels[chanid] buckets[chan.node_id].append((chan_id, part_amount_msat))
# It could happen that the pathfinding uses a channel for node_id, bucket in buckets.items():
# in the graph multiple times, meaning we could exhaust bucket_amount_msat = sum([x[1] for x in bucket])
# its capacity. This could be dealt with by temporarily if not self.channel_db:
# iteratively blacklisting channels for this mpp attempt. trampoline_onion, trampoline_fee, bucket_amount_with_fees, bucket_cltv_delta = create_trampoline_route_and_onion(
route, amt = self.create_route_for_payment( amount_msat=amount_msat,
amount_msat=part_amount_msat, bucket_amount_msat=bucket_amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=min_cltv_expiry,
invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features,
node_id=node_id,
r_tags=r_tags, r_tags=r_tags,
t_tags=t_tags, t_tags=t_tags,
invoice_features=invoice_features, payment_hash=payment_hash,
outgoing_channel=channel, payment_secret=payment_secret,
full_path=None) local_height=local_height,
routes.append((route, amt)) trampoline_fee_level=self.trampoline_fee_level,
trampoline2_list=self.trampoline2_list)
self.logger.info(f'trampoline fee {trampoline_fee}')
# node_features is only used to determine is_tlv
bucket_payment_secret = os.urandom(32)
for chan_id, part_amount_msat in bucket:
chan = self.channels[chan_id]
margin = chan.available_to_spend(LOCAL, strict=True) - part_amount_msat
delta_fee = min(trampoline_fee, margin)
part_amount_msat_with_fees = part_amount_msat + delta_fee
trampoline_fee -= delta_fee
route = [
RouteEdge(
node_id=node_id,
short_channel_id=chan.short_channel_id,
fee_base_msat=0,
fee_proportional_millionths=0,
cltv_expiry_delta=0,
node_features=trampoline_features)
]
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
routes.append((route, part_amount_msat_with_fees, bucket_amount_with_fees, bucket_cltv_delta, bucket_payment_secret, trampoline_onion))
if trampoline_fee > 0:
self.logger.info('not enough marging to pay trampoline fee')
raise NoPathFound()
else:
# then we need bucket_amount_msat that includes the trampoline fees.. then create small routes here
for chan_id, part_amount_msat in bucket:
channel = self.channels[chan_id]
route = self.create_route_for_payment(
amount_msat=part_amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags, t_tags=t_tags,
invoice_features=invoice_features,
outgoing_channel=channel, full_path=None)
routes.append((route, part_amount_msat, bucket_amount_msat, bucket_payment_secret))
self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}") self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}")
break break
except NoPathFound: except NoPathFound:
@ -1499,24 +1437,16 @@ class LNWallet(LNWorker):
return routes return routes
def create_route_for_payment( def create_route_for_payment(
self, self, *,
*,
amount_msat: int, amount_msat: int,
invoice_pubkey: bytes, invoice_pubkey: bytes,
min_cltv_expiry: int, min_cltv_expiry: int,
r_tags, r_tags, t_tags,
t_tags,
invoice_features: int, invoice_features: int,
outgoing_channel: Channel = None, outgoing_channel: Channel = None,
full_path: Optional[LNPaymentPath], full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
) -> Tuple[LNPaymentRoute, int]:
channels = [outgoing_channel] if outgoing_channel else list(self.channels.values()) channels = [outgoing_channel] if outgoing_channel else list(self.channels.values())
if not self.channel_db:
route = self.create_trampoline_route(
amount_msat, min_cltv_expiry, invoice_pubkey, invoice_features, channels, r_tags, t_tags)
return route, amount_msat
route = None route = None
scid_to_my_channels = { scid_to_my_channels = {
chan.short_channel_id: chan for chan in channels chan.short_channel_id: chan for chan in channels
@ -1591,7 +1521,7 @@ class LNWallet(LNWorker):
raise LNPathInconsistent("last node_id != invoice pubkey") raise LNPathInconsistent("last node_id != invoice pubkey")
# add features from invoice # add features from invoice
route[-1].node_features |= invoice_features route[-1].node_features |= invoice_features
return route, amount_msat return route
def add_request(self, amount_sat, message, expiry) -> str: def add_request(self, amount_sat, message, expiry) -> str:
coro = self._add_request_coro(amount_sat, message, expiry) coro = self._add_request_coro(amount_sat, message, expiry)
@ -1743,7 +1673,7 @@ class LNWallet(LNWorker):
util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id) util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id)
q = self.sent_htlcs.get(payment_hash) q = self.sent_htlcs.get(payment_hash)
if q: if q:
route, amount_msat = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)] route, payment_secret, amount_msat, bucket_msat = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
htlc_log = HtlcLog( htlc_log = HtlcLog(
success=True, success=True,
route=route, route=route,
@ -1765,7 +1695,10 @@ class LNWallet(LNWorker):
util.trigger_callback('htlc_failed', payment_hash, chan.channel_id) util.trigger_callback('htlc_failed', payment_hash, chan.channel_id)
q = self.sent_htlcs.get(payment_hash) q = self.sent_htlcs.get(payment_hash)
if q: if q:
route, amount_msat = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)] # detect if it is part of a bucket
# if yes, wait until the bucket completely failed
key = (payment_hash, chan.short_channel_id, htlc_id)
route, payment_secret, amount_msat, bucket_msat = self.sent_htlcs_routes[key]
if error_bytes: if error_bytes:
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
try: try:
@ -1778,6 +1711,13 @@ class LNWallet(LNWorker):
assert failure_message is not None assert failure_message is not None
sender_idx = None sender_idx = None
self.logger.info(f"htlc_failed {failure_message}") self.logger.info(f"htlc_failed {failure_message}")
if payment_secret in self.sent_buckets:
self.sent_buckets[payment_secret] -= amount_msat
if self.sent_buckets[payment_secret] > 0:
return
else:
amount_msat = bucket_msat
htlc_log = HtlcLog( htlc_log = HtlcLog(
success=False, success=False,
route=route, route=route,

9
electrum/tests/test_lnpeer.py

@ -135,6 +135,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.received_htlcs = dict() self.received_htlcs = dict()
self.sent_htlcs = defaultdict(asyncio.Queue) self.sent_htlcs = defaultdict(asyncio.Queue)
self.sent_htlcs_routes = dict() self.sent_htlcs_routes = dict()
self.sent_buckets = defaultdict(set)
def get_invoice_status(self, key): def get_invoice_status(self, key):
pass pass
@ -497,7 +498,7 @@ class TestPeer(ElectrumTestCase):
q2 = w2.sent_htlcs[lnaddr1.paymenthash] q2 = w2.sent_htlcs[lnaddr1.paymenthash]
# alice sends htlc BUT NOT COMMITMENT_SIGNED # alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None p1.maybe_send_commitment = lambda x: None
route1, amount_msat1 = w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2)[0] route1, amount_msat1 = w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2)[0][0:2]
await w1.pay_to_route( await w1.pay_to_route(
route=route1, route=route1,
amount_msat=lnaddr2.get_amount_msat(), amount_msat=lnaddr2.get_amount_msat(),
@ -509,7 +510,7 @@ class TestPeer(ElectrumTestCase):
p1.maybe_send_commitment = _maybe_send_commitment1 p1.maybe_send_commitment = _maybe_send_commitment1
# bob sends htlc BUT NOT COMMITMENT_SIGNED # bob sends htlc BUT NOT COMMITMENT_SIGNED
p2.maybe_send_commitment = lambda x: None p2.maybe_send_commitment = lambda x: None
route2, amount_msat2 = w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1)[0] route2, amount_msat2 = w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1)[0][0:2]
await w2.pay_to_route( await w2.pay_to_route(
route=route2, route=route2,
amount_msat=lnaddr1.get_amount_msat(), amount_msat=lnaddr1.get_amount_msat(),
@ -664,7 +665,7 @@ class TestPeer(ElectrumTestCase):
await asyncio.wait_for(p1.initialized, 1) await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1) await asyncio.wait_for(p2.initialized, 1)
# alice sends htlc # alice sends htlc
route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0] route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2]
htlc = p1.pay(route=route, htlc = p1.pay(route=route,
chan=alice_channel, chan=alice_channel,
amount_msat=lnaddr.get_amount_msat(), amount_msat=lnaddr.get_amount_msat(),
@ -760,7 +761,7 @@ class TestPeer(ElectrumTestCase):
pay_req = run(self.prepare_invoice(w2)) pay_req = run(self.prepare_invoice(w2))
lnaddr = w1._check_invoice(pay_req) lnaddr = w1._check_invoice(pay_req)
route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0] route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2]
assert amount_msat == lnaddr.get_amount_msat() assert amount_msat == lnaddr.get_amount_msat()
run(w1.force_close_channel(alice_channel.channel_id)) run(w1.force_close_channel(alice_channel.channel_id))

228
electrum/trampoline.py

@ -0,0 +1,228 @@
import os
import bitstring
from .lnutil import LnFeatures
from .lnonion import calc_hops_data_for_payment, new_onion_packet
from .lnrouter import RouteEdge, TrampolineEdge, LNPaymentRoute, is_route_sane_to_use
from .lnutil import NoPathFound
from .logging import get_logger, Logger
_logger = get_logger(__name__)
# trampoline nodes are supposed to advertise their fee and cltv in node_update message
TRAMPOLINE_FEES = [
{
'fee_base_msat': 0,
'fee_proportional_millionths': 0,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 1000,
'fee_proportional_millionths': 100,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 3000,
'fee_proportional_millionths': 100,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 5000,
'fee_proportional_millionths': 500,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 7000,
'fee_proportional_millionths': 1000,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 12000,
'fee_proportional_millionths': 3000,
'cltv_expiry_delta': 576,
},
{
'fee_base_msat': 100000,
'fee_proportional_millionths': 3000,
'cltv_expiry_delta': 576,
},
]
def encode_routing_info(r_tags):
result = bitstring.BitArray()
for route in r_tags:
result.append(bitstring.pack('uint:8', len(route)))
for step in route:
pubkey, channel, feebase, feerate, cltv = step
result.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
return result.tobytes()
def create_trampoline_route(
amount_msat:int,
bucket_amount_msat:int,
min_cltv_expiry:int,
invoice_pubkey:bytes,
invoice_features:int,
trampoline_node_id,
r_tags, t_tags,
trampoline_fee_level,
trampoline2_list) -> LNPaymentRoute:
invoice_features = LnFeatures(invoice_features)
# We do not set trampoline_routing_opt in our invoices, because the spec is not ready
# Do not use t_tags if the flag is set, because we the format is not decided yet
if invoice_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT):
is_legacy = False
if len(r_tags) > 0 and len(r_tags[0]) == 1:
pubkey, scid, feebase, feerate, cltv = r_tags[0][0]
t_tag = pubkey, feebase, feerate, cltv
else:
t_tag = None
elif len(t_tags) > 0:
is_legacy = False
t_tag = t_tags[0]
else:
is_legacy = True
# fee level. the same fee is used for all trampolines
if trampoline_fee_level < len(TRAMPOLINE_FEES):
params = TRAMPOLINE_FEES[trampoline_fee_level]
else:
raise NoPathFound()
# add optional second trampoline
trampoline2 = None
if is_legacy:
for node_id in trampoline2_list:
if node_id != trampoline_node_id:
trampoline2 = node_id
break
# node_features is only used to determine is_tlv
trampoline_features = LnFeatures.VAR_ONION_OPT
# hop to trampoline
route = []
# trampoline hop
route.append(
TrampolineEdge(
node_id=trampoline_node_id,
fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'],
node_features=trampoline_features))
if trampoline2:
route.append(
TrampolineEdge(
node_id=trampoline2,
fee_base_msat=params['fee_base_msat'],
fee_proportional_millionths=params['fee_proportional_millionths'],
cltv_expiry_delta=params['cltv_expiry_delta'],
node_features=trampoline_features))
# add routing info
if is_legacy:
invoice_routing_info = encode_routing_info(r_tags)
route[-1].invoice_routing_info = invoice_routing_info
route[-1].invoice_features = invoice_features
else:
if t_tag:
pubkey, feebase, feerate, cltv = t_tag
if route[-1].node_id != pubkey:
route.append(
TrampolineEdge(
node_id=pubkey,
fee_base_msat=feebase,
fee_proportional_millionths=feerate,
cltv_expiry_delta=cltv,
node_features=trampoline_features))
# Fake edge (not part of actual route, needed by calc_hops_data)
route.append(
TrampolineEdge(
node_id=invoice_pubkey,
fee_base_msat=0,
fee_proportional_millionths=0,
cltv_expiry_delta=0,
node_features=trampoline_features))
# check that we can pay amount and fees
for edge in route[::-1]:
amount_msat += edge.fee_for_edge(amount_msat)
if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry):
raise NoPathFound()
_logger.info(f'created route with trampoline: fee_level={trampoline_fee_level}, is legacy: {is_legacy}')
_logger.info(f'first trampoline: {trampoline_node_id.hex()}')
_logger.info(f'second trampoline: {trampoline2.hex() if trampoline2 else None}')
_logger.info(f'params: {params}')
return route
def create_trampoline_onion(route, amount_msat, final_cltv, total_msat, payment_hash, payment_secret):
# all edges are trampoline
hops_data, amount_msat, cltv = calc_hops_data_for_payment(
route,
amount_msat,
final_cltv,
total_msat=total_msat,
payment_secret=payment_secret)
# detect trampoline hops.
payment_path_pubkeys = [x.node_id for x in route]
num_hops = len(payment_path_pubkeys)
for i in range(num_hops-1):
route_edge = route[i]
next_edge = route[i+1]
assert route_edge.is_trampoline()
assert next_edge.is_trampoline()
hops_data[i].payload["outgoing_node_id"] = {"outgoing_node_id":next_edge.node_id}
if route_edge.invoice_features:
hops_data[i].payload["invoice_features"] = {"invoice_features":route_edge.invoice_features}
if route_edge.invoice_routing_info:
hops_data[i].payload["invoice_routing_info"] = {"invoice_routing_info":route_edge.invoice_routing_info}
# only for final, legacy
if i == num_hops - 2:
hops_data[i].payload["payment_data"] = {
"payment_secret":payment_secret,
"total_msat": total_msat,
}
trampoline_session_key = os.urandom(32)
trampoline_onion = new_onion_packet(payment_path_pubkeys, trampoline_session_key, hops_data, associated_data=payment_hash, trampoline=True)
return trampoline_onion, amount_msat, cltv
def create_trampoline_route_and_onion(
*,
amount_msat,
bucket_amount_msat,
min_cltv_expiry,
invoice_pubkey,
invoice_features,
node_id,
r_tags, t_tags,
payment_hash,
payment_secret,
local_height:int,
trampoline_fee_level,
trampoline2_list):
# create route for the trampoline_onion
trampoline_route = create_trampoline_route(
amount_msat,
bucket_amount_msat,
min_cltv_expiry,
invoice_pubkey,
invoice_features,
node_id,
r_tags, t_tags,
trampoline_fee_level,
trampoline2_list)
# compute onion and fees
final_cltv = local_height + min_cltv_expiry
trampoline_onion, bucket_amount_with_fees, bucket_cltv = create_trampoline_onion(
trampoline_route,
bucket_amount_msat,
final_cltv,
amount_msat,
payment_hash,
payment_secret)
bucket_cltv_delta = bucket_cltv - local_height
bucket_cltv_delta += trampoline_route[0].cltv_expiry_delta
# trampoline fee for this very trampoline
trampoline_fee = trampoline_route[0].fee_for_edge(bucket_amount_with_fees)
return trampoline_onion, trampoline_fee, bucket_amount_with_fees, bucket_cltv_delta
Loading…
Cancel
Save