From 85789d8a09523cd6c5635ec71ec03d099caf0c48 Mon Sep 17 00:00:00 2001 From: Janus Date: Thu, 25 Oct 2018 18:28:18 +0200 Subject: [PATCH] lnbase: mark initialized later, add tests, etc - consistent node_id sorting - require OPTION_DATA_LOSS_PROTECT and test it --- electrum/lnbase.py | 50 +++++++++++++++----------- electrum/tests/test_lnbase.py | 66 +++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 20 deletions(-) create mode 100644 electrum/tests/test_lnbase.py diff --git a/electrum/lnbase.py b/electrum/lnbase.py index 1e2edeadb..1f795b5dc 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -201,6 +201,7 @@ class Peer(PrintError): self.peer_addr = peer_addr self.lnworker = lnworker self.privkey = lnworker.node_keypair.privkey + self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)] self.network = lnworker.network self.lnwatcher = lnworker.network.lnwatcher self.channel_db = lnworker.network.channel_db @@ -218,7 +219,7 @@ class Peer(PrintError): self.localfeatures = LnLocalFeatures(0) if request_initial_sync: self.localfeatures |= LnLocalFeatures.INITIAL_ROUTING_SYNC - self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT + self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ self.attempted_route = {} self.orphan_channel_updates = OrderedDict() @@ -234,7 +235,6 @@ class Peer(PrintError): await transport.handshake() self.transport = transport self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures) - self.initialized.set_result(True) @property def channels(self) -> Dict[bytes, Channel]: @@ -310,6 +310,7 @@ class Peer(PrintError): raise LightningPeerConnectionClosed("remote does not have even flag {}" .format(str(LnLocalFeatures(1 << flag)))) self.localfeatures ^= 1 << flag # disable flag + self.initialized.set_result(True) def on_channel_update(self, payload): try: @@ -349,6 +350,13 @@ class Peer(PrintError): @log_exceptions @handle_disconnect async def main_loop(self): + """ + This is used from the GUI. It is not merged with the other function, + so that we can test if the correct exceptions are getting thrown. + """ + await self._main_loop() + + async def _main_loop(self): try: await asyncio.wait_for(self.initialize(), 10) except (OSError, asyncio.TimeoutError, HandshakeFailed) as e: @@ -757,16 +765,17 @@ class Peer(PrintError): if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h): raise Exception("node_sig invalid in announcement_signatures") - node_sigs = [local_node_sig, remote_node_sig] - bitcoin_sigs = [local_bitcoin_sig, remote_bitcoin_sig] - node_ids = [privkey_to_pubkey(self.privkey), self.peer_addr.pubkey] - bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey] + node_sigs = [remote_node_sig, local_node_sig] + bitcoin_sigs = [remote_bitcoin_sig, local_bitcoin_sig] + bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, chan.config[LOCAL].multisig_key.pubkey] - if node_ids[0] > node_ids[1]: + if self.node_ids[0] > self.node_ids[1]: node_sigs.reverse() bitcoin_sigs.reverse() - node_ids.reverse() + node_ids = list(reversed(self.node_ids)) bitcoin_keys.reverse() + else: + node_ids = self.node_ids self.send_message("channel_announcement", node_signatures_1=node_sigs[0], @@ -793,14 +802,13 @@ class Peer(PrintError): chan.set_state("OPEN") self.network.trigger_callback('channel', chan) # add channel to database - pubkey_ours = self.lnworker.node_keypair.pubkey - pubkey_theirs = self.peer_addr.pubkey - node_ids = [pubkey_theirs, pubkey_ours] bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey] - sorted_node_ids = list(sorted(node_ids)) - if sorted_node_ids != node_ids: + sorted_node_ids = list(sorted(self.node_ids)) + if sorted_node_ids != self.node_ids: node_ids = sorted_node_ids bitcoin_keys.reverse() + else: + node_ids = self.node_ids # note: we inject a channel announcement, and a channel update (for outgoing direction) # This is atm needed for # - finding routes @@ -813,7 +821,10 @@ class Peer(PrintError): 'bitcoin_key_1': bitcoin_keys[0], 'bitcoin_key_2': bitcoin_keys[1]}, trusted=True) # only inject outgoing direction: - channel_flags = b'\x00' if node_ids[0] == pubkey_ours else b'\x01' + if node_ids[0] == privkey_to_pubkey(self.privkey): + channel_flags = b'\x00' + else: + channel_flags = b'\x01' now = int(time.time()).to_bytes(4, byteorder="big") self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'channel_flags': channel_flags, 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01', @@ -832,16 +843,15 @@ class Peer(PrintError): def send_announcement_signatures(self, chan): - bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, - chan.config[REMOTE].multisig_key.pubkey] - - node_ids = [privkey_to_pubkey(self.privkey), - self.peer_addr.pubkey] + bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, + chan.config[LOCAL].multisig_key.pubkey] - sorted_node_ids = list(sorted(node_ids)) + sorted_node_ids = list(sorted(self.node_ids)) if sorted_node_ids != node_ids: node_ids = sorted_node_ids bitcoin_keys.reverse() + else: + node_ids = self.node_ids chan_ann = gen_msg("channel_announcement", len=0, diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py new file mode 100644 index 000000000..074ec3f38 --- /dev/null +++ b/electrum/tests/test_lnbase.py @@ -0,0 +1,66 @@ +from electrum.lnbase import Peer, decode_msg, gen_msg +from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey +from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving +from electrum.ecc import ECPrivkey +from electrum.lnrouter import ChannelDB +import unittest +import asyncio +from electrum import simple_config +import tempfile +from .test_lnchan import create_test_channels + +class MockNetwork: + def __init__(self): + self.lnwatcher = None + user_config = {} + user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-") + self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir) + self.asyncio_loop = asyncio.get_event_loop() + self.channel_db = ChannelDB(self) + self.interface = None + def register_callback(self, cb, trigger_names): + print("callback registered", repr(trigger_names)) + def trigger_callback(self, trigger_name, obj): + print("callback triggered", repr(trigger_name)) + +class MockLNWorker: + def __init__(self, remote_peer_pubkey, chan): + self.chan = chan + self.remote_peer_pubkey = remote_peer_pubkey + priv = ECPrivkey.generate_random_key().get_secret_bytes() + self.node_keypair = Keypair( + pubkey=privkey_to_pubkey(priv), + privkey=priv) + self.network = MockNetwork() + @property + def peers(self): + return {self.remote_peer_pubkey: self.peer} + def channels_for_peer(self, pubkey): + return {self.chan.channel_id: self.chan} + +class MockTransport: + def __init__(self): + self.queue = asyncio.Queue() + async def read_messages(self): + while True: + yield await self.queue.get() + +class BadFeaturesTransport(MockTransport): + def send_bytes(self, data): + decoded = decode_msg(data) + print(decoded) + if decoded[0] == 'init': + self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")) + +class TestPeer(unittest.TestCase): + def setUp(self): + self.alice_channel, self.bob_channel = create_test_channels() + def test_bad_feature_flags(self): + # we should require DATA_LOSS_PROTECT + mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel) + mock_transport = BadFeaturesTransport() + p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport) + mock_lnworker.peer = p1 + with self.assertRaises(LightningPeerConnectionClosed): + asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1)) +