diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index c4c77fb09..d79372627 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -59,7 +59,6 @@ class Peer(PrintError): self.node_anns = [] self.chan_anns = [] self.chan_upds = [] - self.last_chan_db_upd = time.time() self.transport = transport self.pubkey = pubkey self.lnworker = lnworker @@ -209,15 +208,31 @@ class Peer(PrintError): @log_exceptions @handle_disconnect async def main_loop(self): - """ - This is used in LNWorker and is necessary so that we don't kill the main - task group. It is not merged with _main_loop, so that we can test if the - correct exceptions are getting thrown using _main_loop. - """ - await self._main_loop() + async with aiorpcx.TaskGroup() as group: + await group.spawn(self._gossip_loop()) + await group.spawn(self._message_loop()) - async def _main_loop(self): - """This is separate from main_loop for the tests.""" + async def _gossip_loop(self): + await self.initialized.wait() + while True: + await asyncio.sleep(5) + if self.node_anns: + self.channel_db.on_node_announcement(self.node_anns) + self.node_anns = [] + if self.chan_anns: + self.channel_db.on_channel_announcement(self.chan_anns) + self.chan_anns = [] + if self.chan_upds: + self.channel_db.on_channel_update(self.chan_upds) + self.chan_upds = [] + need_to_get = self.channel_db.missing_short_chan_ids() #type: Set[int] + if need_to_get and not self.receiving_channels: + self.print_error('QUERYING SHORT CHANNEL IDS; missing', len(need_to_get), 'channels') + zlibencoded = zlib.compress(bfh(''.join(need_to_get))) + self.send_message('query_short_channel_ids', chain_hash=bytes.fromhex(bitcoin.rev_hex(constants.net.GENESIS)), len=1+len(zlibencoded), encoded_short_ids=b'\x01' + zlibencoded) + self.receiving_channels = True + + async def _message_loop(self): try: await asyncio.wait_for(self.initialize(), 10) except (OSError, asyncio.TimeoutError, HandshakeFailed) as e: @@ -227,21 +242,6 @@ class Peer(PrintError): async for msg in self.transport.read_messages(): self.process_message(msg) await asyncio.sleep(.01) - if time.time() - self.last_chan_db_upd > 5: - self.last_chan_db_upd = time.time() - self.channel_db.on_node_announcement(self.node_anns) - self.node_anns = [] - self.channel_db.on_channel_announcement(self.chan_anns) - self.chan_anns = [] - self.channel_db.on_channel_update(self.chan_upds) - self.chan_upds = [] - need_to_get = self.channel_db.missing_short_chan_ids() #type: Set[int] - if need_to_get and not self.receiving_channels: - self.print_error('QUERYING SHORT CHANNEL IDS; ', len(need_to_get)) - zlibencoded = zlib.compress(b"".join(x.to_bytes(byteorder='big', length=8) for x in need_to_get)) - self.send_message('query_short_channel_ids', chain_hash=bytes.fromhex(bitcoin.rev_hex(constants.net.GENESIS)), len=1+len(zlibencoded), encoded_short_ids=b'\x01' + zlibencoded) - self.receiving_channels = True - self.ping_if_required() def on_reply_short_channel_ids_end(self, payload): diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 61235abc8..850f8c6f9 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -347,7 +347,19 @@ class ChannelDB: def missing_short_chan_ids(self) -> Set[int]: expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id))) - return set(DBSession.query(Policy.short_channel_id).filter(expr).all()) + chan_ids_from_policy = set(x[0] for x in DBSession.query(Policy.short_channel_id).filter(expr).all()) + if chan_ids_from_policy: + return chan_ids_from_policy + # fetch channels for node_ids missing in node_info. that will also give us node_announcement + expr = not_(ChannelInfo.node1_id.in_(DBSession.query(NodeInfo.node_id))) + chan_ids_from_id1 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) + if chan_ids_from_id1: + return chan_ids_from_id1 + expr = not_(ChannelInfo.node2_id.in_(DBSession.query(NodeInfo.node_id))) + chan_ids_from_id2 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) + if chan_ids_from_id2: + return chan_ids_from_id2 + return set() def add_verified_channel_info(self, short_id, capacity): # called from lnchannelverifier @@ -390,6 +402,8 @@ class ChannelDB: if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: continue channel_info = channel_infos.get(short_channel_id) + if not channel_info: + continue channel_info.on_channel_update(msg_payload, trusted=trusted) DBSession.commit() diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 8edaf6f67..de0ee5c10 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -173,7 +173,7 @@ class TestPeer(unittest.TestCase): p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False) mock_lnworker.peer = p1 with self.assertRaises(LightningPeerConnectionClosed): - run(asyncio.wait_for(p1._main_loop(), 1)) + run(asyncio.wait_for(p1._message_loop(), 1)) def prepare_peers(self): k1, k2 = keypair(), keypair() @@ -231,7 +231,7 @@ class TestPeer(unittest.TestCase): print("HTLC ADDED") self.assertEqual(await fut, 'Payment received') gath.cancel() - gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop()) + gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) with self.assertRaises(asyncio.CancelledError): run(gath) @@ -254,7 +254,7 @@ class TestPeer(unittest.TestCase): # AssertionError is ok since we shouldn't use old routes, and the # route finding should fail when channel is closed with self.assertRaises(AssertionError): - run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._main_loop(), p2._main_loop())) + run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._message_loop(), p2._message_loop())) def run(coro): return asyncio.get_event_loop().run_until_complete(coro)