diff --git a/electrum/tests/test_lntransport.py b/electrum/tests/test_lntransport.py index 2dc1b4950..8b7d567ba 100644 --- a/electrum/tests/test_lntransport.py +++ b/electrum/tests/test_lntransport.py @@ -4,6 +4,8 @@ from electrum.ecc import ECPrivkey from electrum.lnutil import LNPeerAddr from electrum.lntransport import LNResponderTransport, LNTransport +from aiorpcx import TaskGroup + from . import ElectrumTestCase from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -46,27 +48,53 @@ class TestLNTransport(ElectrumTestCase): server_shaked = asyncio.Event() responder_key = ECPrivkey.generate_random_key() initiator_key = ECPrivkey.generate_random_key() + messages_sent_by_client = [ + b'hello from client', + b'long data from client ' + bytes(range(256)) * 100 + b'... client done', + b'client is running out of things to say', + ] + messages_sent_by_server = [ + b'hello from server', + b'hello2 from server', + b'long data from server ' + bytes(range(256)) * 100 + b'... server done', + ] + async def read_messages(transport, expected_messages): + ctr = 0 + async for msg in transport.read_messages(): + self.assertEqual(expected_messages[ctr], msg) + ctr += 1 + if ctr == len(expected_messages): + return + async def write_messages(transport, expected_messages): + for msg in expected_messages: + transport.send_bytes(msg) + await asyncio.sleep(0.01) + async def cb(reader, writer): t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer) self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes()) - t.send_bytes(b'hello from server') - self.assertEqual(await t.read_messages().__anext__(), b'hello from client') + async with TaskGroup() as group: + await group.spawn(read_messages(t, messages_sent_by_client)) + await group.spawn(write_messages(t, messages_sent_by_server)) responder_shaked.set() - server_future = asyncio.ensure_future(asyncio.start_server(cb, '127.0.0.1', 42898)) - loop.run_until_complete(server_future) - server = server_future.result() # type: asyncio.Server async def connect(): peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes()) t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None) await t.handshake() - t.send_bytes(b'hello from client') - self.assertEqual(await t.read_messages().__anext__(), b'hello from server') + async with TaskGroup() as group: + await group.spawn(read_messages(t, messages_sent_by_server)) + await group.spawn(write_messages(t, messages_sent_by_client)) server_shaked.set() - try: - connect_future = asyncio.ensure_future(connect()) - loop.run_until_complete(responder_shaked.wait()) - loop.run_until_complete(server_shaked.wait()) - finally: - server.close() - loop.run_until_complete(server.wait_closed()) + async def f(): + server = await asyncio.start_server(cb, '127.0.0.1', 42898) + try: + async with TaskGroup() as group: + await group.spawn(connect()) + await group.spawn(responder_shaked.wait()) + await group.spawn(server_shaked.wait()) + finally: + server.close() + await server.wait_closed() + + loop.run_until_complete(f())