@ -30,10 +30,11 @@ from electrum.channel_db import ChannelDB
from electrum . lnworker import LNWallet , NoPathFound
from electrum . lnmsg import encode_msg , decode_msg
from electrum . logging import console_stderr_handler , Logger
from electrum . lnworker import PaymentInfo , RECEIVED , PR_UNPAID
from electrum . lnworker import PaymentInfo , RECEIVED
from electrum . lnonion import OnionFailureCode
from electrum . lnutil import ChannelBlackList , derive_payment_secret_from_payment_preimage
from electrum . lnutil import LOCAL , REMOTE
from electrum . invoices import PR_PAID , PR_UNPAID
from . test_lnchannel import create_test_channels
from . test_bitcoin import needs_test_with_all_chacha20_implementations
@ -112,7 +113,8 @@ class MockWallet:
class MockLNWallet ( Logger , NetworkRetryManager [ LNPeerAddr ] ) :
def __init__ ( self , * , local_keypair : Keypair , chans : Iterable [ ' Channel ' ] , tx_queue ) :
def __init__ ( self , * , local_keypair : Keypair , chans : Iterable [ ' Channel ' ] , tx_queue , name ) :
self . name = name
Logger . __init__ ( self )
NetworkRetryManager . __init__ ( self , max_retry_delay_normal = 1 , init_retry_delay_normal = 1 )
self . node_keypair = local_keypair
@ -173,6 +175,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
def save_channel ( self , chan ) :
print ( " Ignoring channel save " )
def diagnostic_name ( self ) :
return self . name
get_payments = LNWallet . get_payments
get_payment_info = LNWallet . get_payment_info
save_payment_info = LNWallet . save_payment_info
@ -298,8 +303,8 @@ class TestPeer(ElectrumTestCase):
bob_channel . node_id = k1 . pubkey
t1 , t2 = transport_pair ( k1 , k2 , alice_channel . name , bob_channel . name )
q1 , q2 = asyncio . Queue ( ) , asyncio . Queue ( )
w1 = MockLNWallet ( local_keypair = k1 , chans = [ alice_channel ] , tx_queue = q1 )
w2 = MockLNWallet ( local_keypair = k2 , chans = [ bob_channel ] , tx_queue = q2 )
w1 = MockLNWallet ( local_keypair = k1 , chans = [ alice_channel ] , tx_queue = q1 , name = bob_channel . name )
w2 = MockLNWallet ( local_keypair = k2 , chans = [ bob_channel ] , tx_queue = q2 , name = alice_channel . name )
p1 = Peer ( w1 , k2 . pubkey , t1 )
p2 = Peer ( w2 , k1 . pubkey , t2 )
w1 . _peers [ p1 . pubkey ] = p1
@ -324,10 +329,10 @@ class TestPeer(ElectrumTestCase):
trans_bd , trans_db = transport_pair ( key_b , key_d , chan_bd . name , chan_db . name )
trans_cd , trans_dc = transport_pair ( key_c , key_d , chan_cd . name , chan_dc . name )
txq_a , txq_b , txq_c , txq_d = [ asyncio . Queue ( ) for i in range ( 4 ) ]
w_a = MockLNWallet ( local_keypair = key_a , chans = [ chan_ab , chan_ac ] , tx_queue = txq_a )
w_b = MockLNWallet ( local_keypair = key_b , chans = [ chan_ba , chan_bd ] , tx_queue = txq_b )
w_c = MockLNWallet ( local_keypair = key_c , chans = [ chan_ca , chan_cd ] , tx_queue = txq_c )
w_d = MockLNWallet ( local_keypair = key_d , chans = [ chan_db , chan_dc ] , tx_queue = txq_d )
w_a = MockLNWallet ( local_keypair = key_a , chans = [ chan_ab , chan_ac ] , tx_queue = txq_a , name = " alice " )
w_b = MockLNWallet ( local_keypair = key_b , chans = [ chan_ba , chan_bd ] , tx_queue = txq_b , name = " bob " )
w_c = MockLNWallet ( local_keypair = key_c , chans = [ chan_ca , chan_cd ] , tx_queue = txq_c , name = " carol " )
w_d = MockLNWallet ( local_keypair = key_d , chans = [ chan_db , chan_dc ] , tx_queue = txq_d , name = " dave " )
peer_ab = Peer ( w_a , key_b . pubkey , trans_ab )
peer_ac = Peer ( w_a , key_c . pubkey , trans_ac )
peer_ba = Peer ( w_b , key_a . pubkey , trans_ba )
@ -489,11 +494,14 @@ class TestPeer(ElectrumTestCase):
@needs_test_with_all_chacha20_implementations
def test_payment ( self ) :
""" Alice pays Bob a single HTLC via direct channel. """
alice_channel , bob_channel = create_test_channels ( )
p1 , p2 , w1 , w2 , _q1 , _q2 = self . prepare_peers ( alice_channel , bob_channel )
async def pay ( pay_req ) :
async def pay ( lnaddr , pay_req ) :
self . assertEqual ( PR_UNPAID , w2 . get_payment_status ( lnaddr . paymenthash ) )
result , log = await w1 . pay_invoice ( pay_req )
self . assertTrue ( result )
self . assertEqual ( PR_PAID , w2 . get_payment_status ( lnaddr . paymenthash ) )
raise PaymentDone ( )
async def f ( ) :
async with TaskGroup ( ) as group :
@ -503,7 +511,9 @@ class TestPeer(ElectrumTestCase):
await group . spawn ( p2 . htlc_switch ( ) )
await asyncio . sleep ( 0.01 )
lnaddr , pay_req = await self . prepare_invoice ( w2 )
await group . spawn ( pay ( pay_req ) )
invoice_features = lnaddr . get_features ( )
self . assertFalse ( invoice_features . supports ( LnFeatures . BASIC_MPP_OPT ) )
await group . spawn ( pay ( lnaddr , pay_req ) )
with self . assertRaises ( PaymentDone ) :
run ( f ( ) )
@ -614,9 +624,11 @@ class TestPeer(ElectrumTestCase):
def test_payment_multihop ( self ) :
graph = self . prepare_chans_and_peers_in_square ( )
peers = graph . all_peers ( )
async def pay ( pay_req ) :
async def pay ( lnaddr , pay_req ) :
self . assertEqual ( PR_UNPAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
result , log = await graph . w_a . pay_invoice ( pay_req )
self . assertTrue ( result )
self . assertEqual ( PR_PAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
raise PaymentDone ( )
async def f ( ) :
async with TaskGroup ( ) as group :
@ -625,7 +637,7 @@ class TestPeer(ElectrumTestCase):
await group . spawn ( peer . htlc_switch ( ) )
await asyncio . sleep ( 0.2 )
lnaddr , pay_req = await self . prepare_invoice ( graph . w_d , include_routing_hints = True )
await group . spawn ( pay ( pay_req ) )
await group . spawn ( pay ( lnaddr , pay_req ) )
with self . assertRaises ( PaymentDone ) :
run ( f ( ) )
@ -679,9 +691,11 @@ class TestPeer(ElectrumTestCase):
graph . w_b . network . config . set_key ( ' test_fail_htlcs_with_temp_node_failure ' , True )
graph . w_c . network . config . set_key ( ' test_fail_htlcs_with_temp_node_failure ' , True )
peers = graph . all_peers ( )
async def pay ( pay_req ) :
async def pay ( lnaddr , pay_req ) :
self . assertEqual ( PR_UNPAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
result , log = await graph . w_a . pay_invoice ( pay_req )
self . assertFalse ( result )
self . assertEqual ( PR_UNPAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
self . assertEqual ( OnionFailureCode . TEMPORARY_NODE_FAILURE , log [ 0 ] . failure_msg . code )
raise PaymentDone ( )
async def f ( ) :
@ -691,7 +705,7 @@ class TestPeer(ElectrumTestCase):
await group . spawn ( peer . htlc_switch ( ) )
await asyncio . sleep ( 0.2 )
lnaddr , pay_req = await self . prepare_invoice ( graph . w_d , include_routing_hints = True )
await group . spawn ( pay ( pay_req ) )
await group . spawn ( pay ( lnaddr , pay_req ) )
with self . assertRaises ( PaymentDone ) :
run ( f ( ) )
@ -702,12 +716,14 @@ class TestPeer(ElectrumTestCase):
graph = self . prepare_chans_and_peers_in_square ( )
graph . w_c . network . config . set_key ( ' test_fail_htlcs_with_temp_node_failure ' , True )
peers = graph . all_peers ( )
async def pay ( pay_req ) :
async def pay ( lnaddr , pay_req ) :
self . assertEqual ( 500000000000 , graph . chan_ab . balance ( LOCAL ) )
self . assertEqual ( 500000000000 , graph . chan_db . balance ( LOCAL ) )
self . assertEqual ( PR_UNPAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
result , log = await graph . w_a . pay_invoice ( pay_req , attempts = 2 )
self . assertEqual ( 2 , len ( log ) )
self . assertTrue ( result )
self . assertEqual ( PR_PAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
self . assertEqual ( [ graph . chan_ac . short_channel_id , graph . chan_cd . short_channel_id ] ,
[ edge . short_channel_id for edge in log [ 0 ] . route ] )
self . assertEqual ( [ graph . chan_ab . short_channel_id , graph . chan_bd . short_channel_id ] ,
@ -726,7 +742,7 @@ class TestPeer(ElectrumTestCase):
lnaddr , pay_req = await self . prepare_invoice ( graph . w_d , include_routing_hints = True )
invoice_features = lnaddr . get_features ( )
self . assertFalse ( invoice_features . supports ( LnFeatures . BASIC_MPP_OPT ) )
await group . spawn ( pay ( pay_req ) )
await group . spawn ( pay ( lnaddr , pay_req ) )
with self . assertRaises ( PaymentDone ) :
run ( f ( ) )
@ -737,8 +753,10 @@ class TestPeer(ElectrumTestCase):
peers = graph . all_peers ( )
async def pay ( ) :
lnaddr , pay_req = await self . prepare_invoice ( graph . w_d , include_routing_hints = True , amount_msat = amount_to_pay )
self . assertEqual ( PR_UNPAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
result , log = await graph . w_a . pay_invoice ( pay_req , attempts = attempts )
if result :
self . assertEqual ( PR_PAID , graph . w_d . get_payment_status ( lnaddr . paymenthash ) )
raise PaymentDone ( )
else :
raise NoPathFound ( )