|
|
@ -4,6 +4,8 @@ from electrum.ecc import ECPrivkey |
|
|
|
from electrum.lnutil import LNPeerAddr |
|
|
|
from electrum.lntransport import LNResponderTransport, LNTransport |
|
|
|
|
|
|
|
from aiorpcx import TaskGroup |
|
|
|
|
|
|
|
from . import ElectrumTestCase |
|
|
|
from .test_bitcoin import needs_test_with_all_chacha20_implementations |
|
|
|
|
|
|
@ -46,27 +48,53 @@ class TestLNTransport(ElectrumTestCase): |
|
|
|
server_shaked = asyncio.Event() |
|
|
|
responder_key = ECPrivkey.generate_random_key() |
|
|
|
initiator_key = ECPrivkey.generate_random_key() |
|
|
|
messages_sent_by_client = [ |
|
|
|
b'hello from client', |
|
|
|
b'long data from client ' + bytes(range(256)) * 100 + b'... client done', |
|
|
|
b'client is running out of things to say', |
|
|
|
] |
|
|
|
messages_sent_by_server = [ |
|
|
|
b'hello from server', |
|
|
|
b'hello2 from server', |
|
|
|
b'long data from server ' + bytes(range(256)) * 100 + b'... server done', |
|
|
|
] |
|
|
|
async def read_messages(transport, expected_messages): |
|
|
|
ctr = 0 |
|
|
|
async for msg in transport.read_messages(): |
|
|
|
self.assertEqual(expected_messages[ctr], msg) |
|
|
|
ctr += 1 |
|
|
|
if ctr == len(expected_messages): |
|
|
|
return |
|
|
|
async def write_messages(transport, expected_messages): |
|
|
|
for msg in expected_messages: |
|
|
|
transport.send_bytes(msg) |
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
|
|
|
async def cb(reader, writer): |
|
|
|
t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer) |
|
|
|
self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes()) |
|
|
|
t.send_bytes(b'hello from server') |
|
|
|
self.assertEqual(await t.read_messages().__anext__(), b'hello from client') |
|
|
|
async with TaskGroup() as group: |
|
|
|
await group.spawn(read_messages(t, messages_sent_by_client)) |
|
|
|
await group.spawn(write_messages(t, messages_sent_by_server)) |
|
|
|
responder_shaked.set() |
|
|
|
server_future = asyncio.ensure_future(asyncio.start_server(cb, '127.0.0.1', 42898)) |
|
|
|
loop.run_until_complete(server_future) |
|
|
|
server = server_future.result() # type: asyncio.Server |
|
|
|
async def connect(): |
|
|
|
peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes()) |
|
|
|
t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None) |
|
|
|
await t.handshake() |
|
|
|
t.send_bytes(b'hello from client') |
|
|
|
self.assertEqual(await t.read_messages().__anext__(), b'hello from server') |
|
|
|
async with TaskGroup() as group: |
|
|
|
await group.spawn(read_messages(t, messages_sent_by_server)) |
|
|
|
await group.spawn(write_messages(t, messages_sent_by_client)) |
|
|
|
server_shaked.set() |
|
|
|
|
|
|
|
try: |
|
|
|
connect_future = asyncio.ensure_future(connect()) |
|
|
|
loop.run_until_complete(responder_shaked.wait()) |
|
|
|
loop.run_until_complete(server_shaked.wait()) |
|
|
|
finally: |
|
|
|
server.close() |
|
|
|
loop.run_until_complete(server.wait_closed()) |
|
|
|
async def f(): |
|
|
|
server = await asyncio.start_server(cb, '127.0.0.1', 42898) |
|
|
|
try: |
|
|
|
async with TaskGroup() as group: |
|
|
|
await group.spawn(connect()) |
|
|
|
await group.spawn(responder_shaked.wait()) |
|
|
|
await group.spawn(server_shaked.wait()) |
|
|
|
finally: |
|
|
|
server.close() |
|
|
|
await server.wait_closed() |
|
|
|
|
|
|
|
loop.run_until_complete(f()) |
|
|
|