Browse Source

tests: rework lntransport test a bit

send multiple messages, and not only short ones
patch-4
SomberNight 4 years ago
parent
commit
4f13c451c7
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 50
      electrum/tests/test_lntransport.py

50
electrum/tests/test_lntransport.py

@ -4,6 +4,8 @@ from electrum.ecc import ECPrivkey
from electrum.lnutil import LNPeerAddr from electrum.lnutil import LNPeerAddr
from electrum.lntransport import LNResponderTransport, LNTransport from electrum.lntransport import LNResponderTransport, LNTransport
from aiorpcx import TaskGroup
from . import ElectrumTestCase from . import ElectrumTestCase
from .test_bitcoin import needs_test_with_all_chacha20_implementations from .test_bitcoin import needs_test_with_all_chacha20_implementations
@ -46,27 +48,53 @@ class TestLNTransport(ElectrumTestCase):
server_shaked = asyncio.Event() server_shaked = asyncio.Event()
responder_key = ECPrivkey.generate_random_key() responder_key = ECPrivkey.generate_random_key()
initiator_key = ECPrivkey.generate_random_key() initiator_key = ECPrivkey.generate_random_key()
messages_sent_by_client = [
b'hello from client',
b'long data from client ' + bytes(range(256)) * 100 + b'... client done',
b'client is running out of things to say',
]
messages_sent_by_server = [
b'hello from server',
b'hello2 from server',
b'long data from server ' + bytes(range(256)) * 100 + b'... server done',
]
async def read_messages(transport, expected_messages):
ctr = 0
async for msg in transport.read_messages():
self.assertEqual(expected_messages[ctr], msg)
ctr += 1
if ctr == len(expected_messages):
return
async def write_messages(transport, expected_messages):
for msg in expected_messages:
transport.send_bytes(msg)
await asyncio.sleep(0.01)
async def cb(reader, writer): async def cb(reader, writer):
t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer) t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer)
self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes()) self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes())
t.send_bytes(b'hello from server') async with TaskGroup() as group:
self.assertEqual(await t.read_messages().__anext__(), b'hello from client') await group.spawn(read_messages(t, messages_sent_by_client))
await group.spawn(write_messages(t, messages_sent_by_server))
responder_shaked.set() responder_shaked.set()
server_future = asyncio.ensure_future(asyncio.start_server(cb, '127.0.0.1', 42898))
loop.run_until_complete(server_future)
server = server_future.result() # type: asyncio.Server
async def connect(): async def connect():
peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes()) peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes())
t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None) t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None)
await t.handshake() await t.handshake()
t.send_bytes(b'hello from client') async with TaskGroup() as group:
self.assertEqual(await t.read_messages().__anext__(), b'hello from server') await group.spawn(read_messages(t, messages_sent_by_server))
await group.spawn(write_messages(t, messages_sent_by_client))
server_shaked.set() server_shaked.set()
async def f():
server = await asyncio.start_server(cb, '127.0.0.1', 42898)
try: try:
connect_future = asyncio.ensure_future(connect()) async with TaskGroup() as group:
loop.run_until_complete(responder_shaked.wait()) await group.spawn(connect())
loop.run_until_complete(server_shaked.wait()) await group.spawn(responder_shaked.wait())
await group.spawn(server_shaked.wait())
finally: finally:
server.close() server.close()
loop.run_until_complete(server.wait_closed()) await server.wait_closed()
loop.run_until_complete(f())

Loading…
Cancel
Save