diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 96ff6df1d..daad1aa6b 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -57,9 +57,7 @@ class Peer(Logger): def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase): self.initialized = asyncio.Event() - self.node_anns = [] - self.chan_anns = [] - self.chan_upds = [] + self.querying_lock = asyncio.Lock() self.transport = transport self.pubkey = pubkey self.lnworker = lnworker @@ -70,6 +68,7 @@ class Peer(Logger): self.lnwatcher = lnworker.network.lnwatcher self.channel_db = lnworker.network.channel_db self.ping_time = 0 + self.reply_channel_range = asyncio.Queue() self.shutdown_received = defaultdict(asyncio.Future) self.channel_accepted = defaultdict(asyncio.Queue) self.channel_reestablished = defaultdict(asyncio.Future) @@ -89,7 +88,7 @@ class Peer(Logger): def send_message(self, message_name: str, **kwargs): assert type(message_name) is str - self.logger.info(f"Sending {message_name.upper()}") + self.logger.debug(f"Sending {message_name.upper()}") self.transport.send_bytes(encode_msg(message_name, **kwargs)) async def initialize(self): @@ -177,13 +176,13 @@ class Peer(Logger): self.initialized.set() def on_node_announcement(self, payload): - self.node_anns.append(payload) + self.channel_db.node_anns.append(payload) def on_channel_update(self, payload): - self.chan_upds.append(payload) + self.channel_db.chan_upds.append(payload) def on_channel_announcement(self, payload): - self.chan_anns.append(payload) + self.channel_db.chan_anns.append(payload) def on_announcement_signatures(self, payload): channel_id = payload['channel_id'] @@ -207,15 +206,11 @@ class Peer(Logger): @handle_disconnect async def main_loop(self): async with aiorpcx.TaskGroup() as group: - await group.spawn(self._gossip_loop()) await group.spawn(self._message_loop()) # kill group if the peer times out await group.spawn(asyncio.wait_for(self.initialized.wait(), 10)) - @log_exceptions - async def _gossip_loop(self): - await self.initialized.wait() - timestamp = self.channel_db.get_last_timestamp() + def request_gossip(self, timestamp=0): if timestamp == 0: self.logger.info('requesting whole channel graph') else: @@ -225,28 +220,47 @@ class Peer(Logger): chain_hash=constants.net.rev_genesis_bytes(), first_timestamp=timestamp, timestamp_range=b'\xff'*4) - while True: - await asyncio.sleep(5) - if self.node_anns: - self.channel_db.on_node_announcement(self.node_anns) - self.node_anns = [] - if self.chan_anns: - self.channel_db.on_channel_announcement(self.chan_anns) - self.chan_anns = [] - if self.chan_upds: - self.channel_db.on_channel_update(self.chan_upds) - self.chan_upds = [] - # todo: enable when db is fixed - #need_to_get = sorted(self.channel_db.missing_short_chan_ids()) - #if need_to_get and not self.receiving_channels: - # self.logger.info(f'missing {len(need_to_get)} channels') - # zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100]))) - # self.send_message( - # 'query_short_channel_ids', - # chain_hash=constants.net.rev_genesis_bytes(), - # len=1+len(zlibencoded), - # encoded_short_ids=b'\x01' + zlibencoded) - # self.receiving_channels = True + + def query_channel_range(self, index, num): + self.logger.info(f'query channel range') + self.send_message( + 'query_channel_range', + chain_hash=constants.net.rev_genesis_bytes(), + first_blocknum=index, + number_of_blocks=num) + + def encode_short_ids(self, ids): + return chr(1) + zlib.compress(bfh(''.join(ids))) + + def decode_short_ids(self, encoded): + if encoded[0] == 0: + decoded = encoded[1:] + elif encoded[0] == 1: + decoded = zlib.decompress(encoded[1:]) + else: + raise BaseException('zlib') + ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)] + return ids + + def on_reply_channel_range(self, payload): + first = int.from_bytes(payload['first_blocknum'], 'big') + num = int.from_bytes(payload['number_of_blocks'], 'big') + complete = bool(payload['complete']) + encoded = payload['encoded_short_ids'] + ids = self.decode_short_ids(encoded) + self.reply_channel_range.put_nowait((first, num, complete, ids)) + + async def query_short_channel_ids(self, ids, compressed=True): + await self.querying_lock.acquire() + #self.logger.info('querying {} short_channel_ids'.format(len(ids))) + s = b''.join(ids) + encoded = zlib.compress(s) if compressed else s + prefix = b'\x01' if compressed else b'\x00' + self.send_message( + 'query_short_channel_ids', + chain_hash=constants.net.rev_genesis_bytes(), + len=1+len(encoded), + encoded_short_ids=prefix+encoded) async def _message_loop(self): try: @@ -260,7 +274,7 @@ class Peer(Logger): self.ping_if_required() def on_reply_short_channel_ids_end(self, payload): - self.receiving_channels = False + self.querying_lock.release() def close_and_cleanup(self): try: diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 50f814ce4..f197a8eaf 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -223,6 +223,20 @@ class ChannelDB(SqlDB): self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self.ca_verifier = LNChannelVerifier(network, self) self.update_counts() + self.node_anns = [] + self.chan_anns = [] + self.chan_upds = [] + + def process_gossip(self): + if self.node_anns: + self.on_node_announcement(self.node_anns) + self.node_anns = [] + if self.chan_anns: + self.on_channel_announcement(self.chan_anns) + self.chan_anns = [] + if self.chan_upds: + self.on_channel_update(self.chan_upds) + self.chan_upds = [] @sql def update_counts(self): @@ -232,7 +246,32 @@ class ChannelDB(SqlDB): self.num_channels = self.DBSession.query(ChannelInfo).count() self.num_policies = self.DBSession.query(Policy).count() self.num_nodes = self.DBSession.query(NodeInfo).count() - self.logger.info(f'update counts {self.num_channels} {self.num_policies}') + + @sql + @profiler + def purge_unknown_channels(self, channel_ids): + ids = [x.hex() for x in channel_ids] + missing = self.DBSession \ + .query(ChannelInfo) \ + .filter(not_(ChannelInfo.short_channel_id.in_(ids))) \ + .all() + if missing: + self.logger.info("deleting {} channels".format(len(missing))) + delete_query = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(ids))) + self.DBSession.execute(delete_query) + self.DBSession.commit() + + @sql + @profiler + def compare_channels(self, channel_ids): + ids = [x.hex() for x in channel_ids] + # I need to get the unknown, and also the channels that need refresh + known = self.DBSession \ + .query(ChannelInfo) \ + .filter(ChannelInfo.short_channel_id.in_(ids)) \ + .all() + known = [bfh(r.short_channel_id) for r in known] + return known @sql def add_recent_peer(self, peer: LNPeerAddr): @@ -276,12 +315,14 @@ class ChannelDB(SqlDB): return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r] @sql - def missing_short_chan_ids(self) -> Set[int]: + def missing_channel_announcements(self) -> Set[int]: expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) - chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) - if chan_ids_from_policy: - return chan_ids_from_policy - return set() + return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) + + @sql + def missing_channel_updates(self) -> Set[int]: + expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id))) + return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) @sql def add_verified_channel_info(self, short_id, capacity): @@ -316,8 +357,8 @@ class ChannelDB(SqlDB): for channel_info in new_channels.values(): self.DBSession.add(channel_info) self.DBSession.commit() - #self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads))) self._update_counts() + self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads))) self.network.trigger_callback('ln_status') @sql @@ -370,7 +411,7 @@ class ChannelDB(SqlDB): self.DBSession.commit() if new_policies: self.logger.info(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}') - self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}') + #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}') self._update_counts() @sql diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 37402cf26..1a9e2bd5c 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -133,9 +133,7 @@ class LNWorker(Logger): self.channel_db = self.network.channel_db self._last_tried_peer = {} # LNPeerAddr -> unix timestamp self._add_peers_from_config() - # wait until we see confirmations asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop) - self.first_timestamp_requested = None def _add_peers_from_config(self): peer_list = self.config.get('lightning_peers', []) @@ -215,9 +213,24 @@ class LNWorker(Logger): self.logger.info('got {} ln peers from dns seed'.format(len(peers))) return peers + @staticmethod + def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]: + assert len(addr_list) >= 1 + # choose first one that is an IP + for addr_in_db in addr_list: + host = addr_in_db.host + port = addr_in_db.port + if is_ip_address(host): + return host, port + # otherwise choose one at random + # TODO maybe filter out onion if not on tor? + choice = random.choice(addr_list) + return choice.host, choice.port class LNGossip(LNWorker): + # height of first channel announcements + first_block = 497000 def __init__(self, network): seed = os.urandom(32) @@ -226,6 +239,61 @@ class LNGossip(LNWorker): super().__init__(xprv) self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ + def start_network(self, network: 'Network'): + super().start_network(network) + asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.gossip_task()), self.network.asyncio_loop) + + async def gossip_task(self): + req_index = self.first_block + req_num = self.network.get_local_height() - req_index + while len(self.peers) == 0: + await asyncio.sleep(1) + continue + # todo: parallelize over peers + peer = list(self.peers.values())[0] + await peer.initialized.wait() + # send channels_range query. peer will reply with several intervals + peer.query_channel_range(req_index, req_num) + intervals = [] + ids = set() + # wait until requested range is covered + while True: + index, num, complete, _ids = await peer.reply_channel_range.get() + ids.update(_ids) + intervals.append((index, index+num)) + intervals.sort() + while len(intervals) > 1: + a,b = intervals[0] + c,d = intervals[1] + if b == c: + intervals = [(a,d)] + intervals[2:] + else: + break + if len(intervals) == 1: + a, b = intervals[0] + if a <= req_index and b >= req_index + req_num: + break + self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete)) + # TODO: filter results by date of last channel update, purge DB + #if complete: + # self.channel_db.purge_unknown_channels(ids) + known = self.channel_db.compare_channels(ids) + unknown = list(ids - set(known)) + total = len(unknown) + N = 500 + while unknown: + self.channel_db.process_gossip() + await peer.query_short_channel_ids(unknown[0:N]) + unknown = unknown[N:] + self.logger.info(f'Querying channels: {total - len(unknown)}/{total}. Count: {self.channel_db.num_channels}') + + # request gossip fromm current time + now = int(time.time()) + peer.request_gossip(now) + while True: + await asyncio.sleep(5) + self.channel_db.process_gossip() + class LNWallet(LNWorker): @@ -548,20 +616,6 @@ class LNWallet(LNWorker): def on_channels_updated(self): self.network.trigger_callback('channels') - @staticmethod - def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]: - assert len(addr_list) >= 1 - # choose first one that is an IP - for addr_in_db in addr_list: - host = addr_in_db.host - port = addr_in_db.port - if is_ip_address(host): - return host, port - # otherwise choose one at random - # TODO maybe filter out onion if not on tor? - choice = random.choice(addr_list) - return choice.host, choice.port - def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=20): node_id, rest = extract_nodeid(connect_contents) peer = self.peers.get(node_id)