Browse Source

Merge pull request #7113 from bitromortac/2103-temp-chan-fail

forwarding: temp chan fail on insufficient funds
patch-4
ThomasV 4 years ago
committed by GitHub
parent
commit
bf5aa1d690
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 25
      electrum/lnpeer.py
  2. 77
      electrum/tests/test_lnpeer.py

25
electrum/lnpeer.py

@ -1360,34 +1360,36 @@ class Peer(Logger):
raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:] outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big") outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
outgoing_chan_upd_message = outgoing_chan_upd_len + outgoing_chan_upd
if not next_chan.can_send_update_add_htlc(): if not next_chan.can_send_update_add_htlc():
self.logger.info(f"cannot forward htlc. next_chan {next_chan_scid} cannot send ctx updates. " self.logger.info(f"cannot forward htlc. next_chan {next_chan_scid} cannot send ctx updates. "
f"chan state {next_chan.get_state()!r}, peer state: {next_chan.peer_state!r}") f"chan state {next_chan.get_state()!r}, peer state: {next_chan.peer_state!r}")
data = outgoing_chan_upd_len + outgoing_chan_upd raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data) try:
next_amount_msat_htlc = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
except:
raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
if not next_chan.can_pay(next_amount_msat_htlc):
self.logger.info(f"cannot forward htlc due to transient errors (likely due to insufficient funds)")
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
try: try:
next_cltv_expiry = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"] next_cltv_expiry = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"]
except: except:
raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
if htlc.cltv_expiry - next_cltv_expiry < next_chan.forwarding_cltv_expiry_delta: if htlc.cltv_expiry - next_cltv_expiry < next_chan.forwarding_cltv_expiry_delta:
data = htlc.cltv_expiry.to_bytes(4, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd data = htlc.cltv_expiry.to_bytes(4, byteorder="big") + outgoing_chan_upd_message
raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY, data=data) raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY, data=data)
if htlc.cltv_expiry - lnutil.MIN_FINAL_CLTV_EXPIRY_ACCEPTED <= local_height \ if htlc.cltv_expiry - lnutil.MIN_FINAL_CLTV_EXPIRY_ACCEPTED <= local_height \
or next_cltv_expiry <= local_height: or next_cltv_expiry <= local_height:
data = outgoing_chan_upd_len + outgoing_chan_upd raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_SOON, data=outgoing_chan_upd_message)
raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_SOON, data=data)
if max(htlc.cltv_expiry, next_cltv_expiry) > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE: if max(htlc.cltv_expiry, next_cltv_expiry) > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'')
try:
next_amount_msat_htlc = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
except:
raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
forwarding_fees = fee_for_edge_msat( forwarding_fees = fee_for_edge_msat(
forwarded_amount_msat=next_amount_msat_htlc, forwarded_amount_msat=next_amount_msat_htlc,
fee_base_msat=next_chan.forwarding_fee_base_msat, fee_base_msat=next_chan.forwarding_fee_base_msat,
fee_proportional_millionths=next_chan.forwarding_fee_proportional_millionths) fee_proportional_millionths=next_chan.forwarding_fee_proportional_millionths)
if htlc.amount_msat - next_amount_msat_htlc < forwarding_fees: if htlc.amount_msat - next_amount_msat_htlc < forwarding_fees:
data = next_amount_msat_htlc.to_bytes(8, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd data = next_amount_msat_htlc.to_bytes(8, byteorder="big") + outgoing_chan_upd_message
raise OnionRoutingFailure(code=OnionFailureCode.FEE_INSUFFICIENT, data=data) raise OnionRoutingFailure(code=OnionFailureCode.FEE_INSUFFICIENT, data=data)
self.logger.info(f'forwarding htlc to {next_chan.node_id}') self.logger.info(f'forwarding htlc to {next_chan.node_id}')
next_htlc = UpdateAddHtlc( next_htlc = UpdateAddHtlc(
@ -1409,8 +1411,7 @@ class Peer(Logger):
) )
except BaseException as e: except BaseException as e:
self.logger.info(f"failed to forward htlc: error sending message. {e}") self.logger.info(f"failed to forward htlc: error sending message. {e}")
data = outgoing_chan_upd_len + outgoing_chan_upd raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data)
return next_chan_scid, next_htlc.htlc_id return next_chan_scid, next_htlc.htlc_id
def maybe_forward_trampoline( def maybe_forward_trampoline(

77
electrum/tests/test_lnpeer.py

@ -8,7 +8,7 @@ import logging
import concurrent import concurrent
from concurrent import futures from concurrent import futures
import unittest import unittest
from typing import Iterable, NamedTuple, Tuple, List from typing import Iterable, NamedTuple, Tuple, List, Dict
from aiorpcx import TaskGroup, timeout_after, TaskTimeout from aiorpcx import TaskGroup, timeout_after, TaskTimeout
@ -221,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
is_trampoline_peer = LNWallet.is_trampoline_peer is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
on_proxy_changed = LNWallet.on_proxy_changed on_proxy_changed = LNWallet.on_proxy_changed
_decode_channel_update_msg = LNWallet._decode_channel_update_msg
_handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc
class MockTransport: class MockTransport:
@ -347,12 +349,38 @@ class TestPeer(TestCaseForTestnet):
p2.mark_open(bob_channel) p2.mark_open(bob_channel)
return p1, p2, w1, w2, q1, q2 return p1, p2, w1, w2, q1, q2
def prepare_chans_and_peers_in_square(self) -> SquareGraph: def prepare_chans_and_peers_in_square(self, funds_distribution: Dict[str, Tuple[int, int]]=None) -> SquareGraph:
if not funds_distribution:
funds_distribution = {}
key_a, key_b, key_c, key_d = [keypair() for i in range(4)] key_a, key_b, key_c, key_d = [keypair() for i in range(4)]
chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey) local_balance, remote_balance = funds_distribution.get('ab') or (None, None)
chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey) chan_ab, chan_ba = create_test_channels(
chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey) alice_name="alice", bob_name="bob",
chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey) alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey,
local_msat=local_balance,
remote_msat=remote_balance,
)
local_balance, remote_balance = funds_distribution.get('ac') or (None, None)
chan_ac, chan_ca = create_test_channels(
alice_name="alice", bob_name="carol",
alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey,
local_msat=local_balance,
remote_msat=remote_balance,
)
local_balance, remote_balance = funds_distribution.get('bd') or (None, None)
chan_bd, chan_db = create_test_channels(
alice_name="bob", bob_name="dave",
alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey,
local_msat=local_balance,
remote_msat=remote_balance,
)
local_balance, remote_balance = funds_distribution.get('cd') or (None, None)
chan_cd, chan_dc = create_test_channels(
alice_name="carol", bob_name="dave",
alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey,
local_msat=local_balance,
remote_msat=remote_balance,
)
trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name) trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name)
trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name) trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name)
trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name)
@ -776,6 +804,43 @@ class TestPeer(TestCaseForTestnet):
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
@needs_test_with_all_chacha20_implementations
def test_payment_with_temp_channel_failure(self):
# prepare channels such that a temporary channel failure happens at c->d
funds_distribution = {
'ac': (200_000_000, 200_000_000), # low fees
'cd': (50_000_000, 200_000_000), # low fees
'ab': (200_000_000, 200_000_000), # high fees
'bd': (200_000_000, 200_000_000), # high fees
}
# the payment happens in three attempts:
# 1. along ac->cd due to low fees with temp channel failure:
# with chanupd: ORPHANED, private channel update
# 2. along ac->cd with temp channel failure:
# with chanupd: ORPHANED, private channel update, but already received, channel gets blacklisted
# 3. along ab->bd with success
amount_to_pay = 100_000_000
graph = self.prepare_chans_and_peers_in_square(funds_distribution)
peers = graph.all_peers()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
result, log = await graph.w_a.pay_invoice(pay_req, attempts=3)
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code)
self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[1].failure_msg.code)
raise PaymentDone()
async def f():
async with TaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, amount_msat=amount_to_pay, include_routing_hints=True)
await group.spawn(pay(lnaddr, pay_req))
with self.assertRaises(PaymentDone):
run(f())
def _run_mpp(self, graph, kwargs1, kwargs2): def _run_mpp(self, graph, kwargs1, kwargs2):
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))

Loading…
Cancel
Save