@ -8,7 +8,7 @@ import logging
import concurrent
from concurrent import futures
import unittest
from typing import Iterable , NamedTuple , Tuple , List
from typing import Iterable , NamedTuple , Tuple , List , Dict
from aiorpcx import TaskGroup , timeout_after , TaskTimeout
@ -223,6 +223,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
is_trampoline_peer = LNWallet . is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet . wait_for_received_pending_htlcs_to_get_removed
on_proxy_changed = LNWallet . on_proxy_changed
_decode_channel_update_msg = LNWallet . _decode_channel_update_msg
_handle_chanupd_from_failed_htlc = LNWallet . _handle_chanupd_from_failed_htlc
class MockTransport :
@ -349,12 +351,38 @@ class TestPeer(ElectrumTestCase):
p2 . mark_open ( bob_channel )
return p1 , p2 , w1 , w2 , q1 , q2
def prepare_chans_and_peers_in_square ( self ) - > SquareGraph :
def prepare_chans_and_peers_in_square ( self , funds_distribution : Dict [ str , Tuple [ int , int ] ] = None ) - > SquareGraph :
if not funds_distribution :
funds_distribution = { }
key_a , key_b , key_c , key_d = [ keypair ( ) for i in range ( 4 ) ]
chan_ab , chan_ba = create_test_channels ( alice_name = " alice " , bob_name = " bob " , alice_pubkey = key_a . pubkey , bob_pubkey = key_b . pubkey )
chan_ac , chan_ca = create_test_channels ( alice_name = " alice " , bob_name = " carol " , alice_pubkey = key_a . pubkey , bob_pubkey = key_c . pubkey )
chan_bd , chan_db = create_test_channels ( alice_name = " bob " , bob_name = " dave " , alice_pubkey = key_b . pubkey , bob_pubkey = key_d . pubkey )
chan_cd , chan_dc = create_test_channels ( alice_name = " carol " , bob_name = " dave " , alice_pubkey = key_c . pubkey , bob_pubkey = key_d . pubkey )
local_balance , remote_balance = funds_distribution . get ( ' ab ' ) or ( None , None )
chan_ab , chan_ba = create_test_channels (
alice_name = " alice " , bob_name = " bob " ,
alice_pubkey = key_a . pubkey , bob_pubkey = key_b . pubkey ,
local_msat = local_balance ,
remote_msat = remote_balance ,
)
local_balance , remote_balance = funds_distribution . get ( ' ac ' ) or ( None , None )
chan_ac , chan_ca = create_test_channels (
alice_name = " alice " , bob_name = " carol " ,
alice_pubkey = key_a . pubkey , bob_pubkey = key_c . pubkey ,
local_msat = local_balance ,
remote_msat = remote_balance ,
)
local_balance , remote_balance = funds_distribution . get ( ' bd ' ) or ( None , None )
chan_bd , chan_db = create_test_channels (
alice_name = " bob " , bob_name = " dave " ,
alice_pubkey = key_b . pubkey , bob_pubkey = key_d . pubkey ,
local_msat = local_balance ,
remote_msat = remote_balance ,
)
local_balance , remote_balance = funds_distribution . get ( ' cd ' ) or ( None , None )
chan_cd , chan_dc = create_test_channels (
alice_name = " carol " , bob_name = " dave " ,
alice_pubkey = key_c . pubkey , bob_pubkey = key_d . pubkey ,
local_msat = local_balance ,
remote_msat = remote_balance ,
)
trans_ab , trans_ba = transport_pair ( key_a , key_b , chan_ab . name , chan_ba . name )
trans_ac , trans_ca = transport_pair ( key_a , key_c , chan_ac . name , chan_ca . name )
trans_bd , trans_db = transport_pair ( key_b , key_d , chan_bd . name , chan_db . name )
@ -778,6 +806,43 @@ class TestPeer(ElectrumTestCase):
with self . assertRaises ( PaymentDone ) :
run ( f ( ) )
@needs_test_with_all_chacha20_implementations
def test_payment_with_temp_channel_failure ( self ) :
# prepare channels such that a temporary channel failure happens at c->d
funds_distribution = {
' ac ' : ( 200_000_000 , 200_000_000 ) , # low fees
' cd ' : ( 50_000_000 , 200_000_000 ) , # low fees
' ab ' : ( 200_000_000 , 200_000_000 ) , # high fees
' bd ' : ( 200_000_000 , 200_000_000 ) , # high fees
}
# the payment happens in three attempts:
# 1. along ac->cd due to low fees with temp channel failure:
# with chanupd: ORPHANED, private channel update
# 2. along ac->cd with temp channel failure:
# with chanupd: ORPHANED, private channel update, but already received, channel gets blacklisted
# 3. along ab->bd with success
amount_to_pay = 100_000_000
graph = self . prepare_chans_and_peers_in_square ( funds_distribution )
peers = graph . all_peers ( )
async def pay ( lnaddr , pay_req ) :
self . assertEqual ( PR_UNPAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
result , log = await graph . w_a . pay_invoice ( pay_req , attempts = 3 )
self . assertTrue ( result )
self . assertEqual ( PR_PAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
self . assertEqual ( OnionFailureCode . TEMPORARY_CHANNEL_FAILURE , log [ 0 ] . failure_msg . code )
self . assertEqual ( OnionFailureCode . TEMPORARY_CHANNEL_FAILURE , log [ 1 ] . failure_msg . code )
raise PaymentDone ( )
async def f ( ) :
async with TaskGroup ( ) as group :
for peer in peers :
await group . spawn ( peer . _message_loop ( ) )
await group . spawn ( peer . htlc_switch ( ) )
await asyncio . sleep ( 0.2 )
lnaddr , pay_req = await self . prepare_invoice ( graph . w_d , amount_msat = amount_to_pay , include_routing_hints = True )
await group . spawn ( pay ( lnaddr , pay_req ) )
with self . assertRaises ( PaymentDone ) :
run ( f ( ) )
def _run_mpp ( self , graph , kwargs1 , kwargs2 ) :
self . assertEqual ( 500_000_000_000 , graph . chan_ab . balance ( LOCAL ) )
self . assertEqual ( 500_000_000_000 , graph . chan_ac . balance ( LOCAL ) )