diff --git a/electrum/lnbase.py b/electrum/lnbase.py index 4bfb36bfa..690821d7d 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -350,6 +350,11 @@ class Peer(PrintError): @log_exceptions @handle_disconnect async def main_loop(self): + """ + This is used in LNWorker and is necessary so that we don't kill the main + task group. It is not merged with _main_loop, so that we can test if the + correct exceptions are getting thrown using _main_loop. + """ await self._main_loop() async def _main_loop(self): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 65049b9da..5640a6aa8 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -32,7 +32,6 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, generate_keypair, LnKeyFamily, LOCAL, REMOTE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, NUM_MAX_EDGES_IN_PAYMENT_PATH) -from .lnaddr import lndecode from .i18n import _ from .lnrouter import RouteEdge, is_route_sane_to_use @@ -258,6 +257,15 @@ class LNWorker(PrintError): return bh2u(chan.node_id) def pay(self, invoice, amount_sat=None): + """ + This is not merged with _pay so that we can run the test with + one thread only. + """ + addr, peer, coro = self._pay(invoice, amount_sat) + fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) + return addr, peer, fut + + def _pay(self, invoice, amount_sat=None): addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) payment_hash = addr.paymenthash amount_sat = (addr.amount * COIN) if addr.amount else amount_sat @@ -279,7 +287,7 @@ class LNWorker(PrintError): raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id))) peer = self.peers[node_id] coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry()) - return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) + return addr, peer, coro def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]: invoice_pubkey = decoded_invoice.pubkey.serialize() diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py index 074ec3f38..c64e1ced0 100644 --- a/electrum/tests/test_lnbase.py +++ b/electrum/tests/test_lnbase.py @@ -1,16 +1,40 @@ -from electrum.lnbase import Peer, decode_msg, gen_msg -from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey -from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving -from electrum.ecc import ECPrivkey -from electrum.lnrouter import ChannelDB import unittest import asyncio -from electrum import simple_config import tempfile +from decimal import Decimal +import os +from contextlib import contextmanager +from collections import defaultdict + +from electrum.network import Network +from electrum.ecc import ECPrivkey +from electrum import simple_config, lnutil +from electrum.lnaddr import lnencode, LnAddr, lndecode +from electrum.bitcoin import COIN, sha256 +from electrum.util import bh2u + +from electrum.lnbase import Peer, decode_msg, gen_msg +from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey +from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving +from electrum.lnrouter import ChannelDB, LNPathFinder +from electrum.lnworker import LNWorker + from .test_lnchan import create_test_channels +def keypair(): + priv = ECPrivkey.generate_random_key().get_secret_bytes() + k1 = Keypair( + pubkey=privkey_to_pubkey(priv), + privkey=priv) + return k1 + +@contextmanager +def noop_lock(): + yield + class MockNetwork: def __init__(self): + self.callbacks = defaultdict(list) self.lnwatcher = None user_config = {} user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-") @@ -18,49 +42,132 @@ class MockNetwork: self.asyncio_loop = asyncio.get_event_loop() self.channel_db = ChannelDB(self) self.interface = None - def register_callback(self, cb, trigger_names): - print("callback registered", repr(trigger_names)) - def trigger_callback(self, trigger_name, obj): - print("callback triggered", repr(trigger_name)) + self.path_finder = LNPathFinder(self.channel_db) + + @property + def callback_lock(self): + return noop_lock() + + register_callback = Network.register_callback + unregister_callback = Network.unregister_callback + trigger_callback = Network.trigger_callback + + def get_local_height(self): + return 0 class MockLNWorker: - def __init__(self, remote_peer_pubkey, chan): + def __init__(self, remote_keypair, local_keypair, chan): self.chan = chan - self.remote_peer_pubkey = remote_peer_pubkey - priv = ECPrivkey.generate_random_key().get_secret_bytes() - self.node_keypair = Keypair( - pubkey=privkey_to_pubkey(priv), - privkey=priv) + self.remote_keypair = remote_keypair + self.node_keypair = local_keypair self.network = MockNetwork() + self.channels = {self.chan.channel_id: self.chan} + self.invoices = {} + + @property + def lock(self): + return noop_lock() + @property def peers(self): - return {self.remote_peer_pubkey: self.peer} + return {self.remote_keypair.pubkey: self.peer} + def channels_for_peer(self, pubkey): - return {self.chan.channel_id: self.chan} + return self.channels + + def save_channel(self, chan): + pass + + get_invoice = LNWorker.get_invoice + _create_route_from_invoice = LNWorker._create_route_from_invoice class MockTransport: def __init__(self): self.queue = asyncio.Queue() + async def read_messages(self): while True: yield await self.queue.get() -class BadFeaturesTransport(MockTransport): +class NoFeaturesTransport(MockTransport): + """ + This answers the init message with a init that doesn't signal any features. + Used for testing that we require DATA_LOSS_PROTECT. + """ def send_bytes(self, data): decoded = decode_msg(data) print(decoded) if decoded[0] == 'init': self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")) +class PutIntoOthersQueueTransport(MockTransport): + def __init__(self): + super().__init__() + self.other_mock_transport = None + + def send_bytes(self, data): + self.other_mock_transport.queue.put_nowait(data) + +def transport_pair(): + t1 = PutIntoOthersQueueTransport() + t2 = PutIntoOthersQueueTransport() + t1.other_mock_transport = t2 + t2.other_mock_transport = t1 + return t1, t2 + class TestPeer(unittest.TestCase): def setUp(self): self.alice_channel, self.bob_channel = create_test_channels() - def test_bad_feature_flags(self): - # we should require DATA_LOSS_PROTECT - mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel) - mock_transport = BadFeaturesTransport() - p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport) + + def test_require_data_loss_protect(self): + mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel) + mock_transport = NoFeaturesTransport() + p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport) mock_lnworker.peer = p1 with self.assertRaises(LightningPeerConnectionClosed): asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1)) + def test_payment(self): + k1, k2 = keypair(), keypair() + t1, t2 = transport_pair() + w1 = MockLNWorker(k1, k2, self.alice_channel) + w2 = MockLNWorker(k2, k1, self.bob_channel) + p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey), + request_initial_sync=False, transport=t1) + p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey), + request_initial_sync=False, transport=t2) + w1.peer = p1 + w2.peer = p2 + # mark_open won't work if state is already OPEN. + # so set it to OPENING + self.alice_channel.set_state("OPENING") + self.bob_channel.set_state("OPENING") + # this populates the channel graph: + p1.mark_open(self.alice_channel) + p2.mark_open(self.bob_channel) + amount_btc = 100000/Decimal(COIN) + payment_preimage = os.urandom(32) + RHASH = sha256(payment_preimage) + addr = LnAddr( + RHASH, + amount_btc, + tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), + ('d', 'coffee') + ]) + pay_req = lnencode(addr, w2.node_keypair.privkey) + w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req) + l = asyncio.get_event_loop() + async def pay(): + fut = asyncio.Future() + def evt_set(event, _lnworker, msg): + fut.set_result(msg) + w2.network.register_callback(evt_set, ['ln_message']) + + addr, peer, coro = LNWorker._pay(w1, pay_req) + await coro + print("HTLC ADDED") + self.assertEqual(await fut, 'Payment received') + gath.cancel() + gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop()) + with self.assertRaises(asyncio.CancelledError): + l.run_until_complete(gath)