diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 286f4a343..c4c77fb09 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -4,6 +4,7 @@ # Distributed under the MIT software license, see the accompanying # file LICENCE or http://www.opensource.org/licenses/mit-license.php +import zlib from collections import OrderedDict, defaultdict import json import asyncio @@ -86,6 +87,7 @@ class Peer(PrintError): self.recv_commitment_for_ctn_last = defaultdict(lambda: None) # type: Dict[Channel, Optional[int]] self._local_changed_events = defaultdict(asyncio.Event) self._remote_changed_events = defaultdict(asyncio.Event) + self.receiving_channels = False def send_message(self, message_name: str, **kwargs): assert type(message_name) is str @@ -233,8 +235,18 @@ class Peer(PrintError): self.chan_anns = [] self.channel_db.on_channel_update(self.chan_upds) self.chan_upds = [] + need_to_get = self.channel_db.missing_short_chan_ids() #type: Set[int] + if need_to_get and not self.receiving_channels: + self.print_error('QUERYING SHORT CHANNEL IDS; ', len(need_to_get)) + zlibencoded = zlib.compress(b"".join(x.to_bytes(byteorder='big', length=8) for x in need_to_get)) + self.send_message('query_short_channel_ids', chain_hash=bytes.fromhex(bitcoin.rev_hex(constants.net.GENESIS)), len=1+len(zlibencoded), encoded_short_ids=b'\x01' + zlibencoded) + self.receiving_channels = True + self.ping_if_required() + def on_reply_short_channel_ids_end(self, payload): + self.receiving_channels = False + def close_and_cleanup(self): try: if self.transport: diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index c4ae62343..0a2bd3775 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -30,7 +30,7 @@ import os import json import threading from collections import defaultdict -from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING +from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set import binascii import base64 import asyncio @@ -345,6 +345,10 @@ class ChannelDB: rows = DBSession.query(ChannelInfoInDB).filter(condition).all() return [bytes.fromhex(x.short_channel_id) for x in rows] + def missing_short_chan_ids(self) -> Set[int]: + expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfoInDB.short_channel_id))) + return set(DBSession.query(Policy.short_channel_id).filter(expr).all()) + def add_verified_channel_info(self, short_id, capacity): # called from lnchannelverifier channel_info = self.get_channel_info(short_id)