diff --git a/electrum/channel_db.py b/electrum/channel_db.py index 3b8cd578c..89a8abbeb 100644 --- a/electrum/channel_db.py +++ b/electrum/channel_db.py @@ -31,6 +31,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK import binascii import base64 import asyncio +import threading from .sql_db import SqlDB, sql @@ -247,17 +248,21 @@ class ChannelDB(SqlDB): def __init__(self, network: 'Network'): path = os.path.join(get_headers_dir(network.config), 'gossip_db') super().__init__(network, path, commit_interval=100) + self.lock = threading.RLock() self.num_nodes = 0 self.num_channels = 0 self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self.ca_verifier = LNChannelVerifier(network, self) + # initialized in load_data + # note: modify/iterate needs self.lock self._channels = {} # type: Dict[bytes, ChannelInfo] self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo # node_id -> (host, port, ts) self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]] self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]] + self.data_loaded = asyncio.Event() self.network = network # only for callback @@ -268,16 +273,19 @@ class ChannelDB(SqlDB): self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies) def get_channel_ids(self): - return set(self._channels.keys()) + with self.lock: + return set(self._channels.keys()) def add_recent_peer(self, peer: LNPeerAddr): now = int(time.time()) node_id = peer.pubkey - self._addresses[node_id].add((peer.host, peer.port, now)) + with self.lock: + self._addresses[node_id].add((peer.host, peer.port, now)) self.save_node_address(node_id, peer, now) def get_200_randomly_sorted_nodes_not_in(self, node_ids): - unshuffled = set(self._nodes.keys()) - node_ids + with self.lock: + unshuffled = set(self._nodes.keys()) - node_ids return random.sample(unshuffled, min(200, len(unshuffled))) def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: @@ -296,8 +304,10 @@ class ChannelDB(SqlDB): # FIXME this does not reliably return "recent" peers... # Also, the list() cast over the whole dict (thousands of elements), # is really inefficient. + with self.lock: + _addresses_keys = list(self._addresses.keys()) r = [self.get_last_good_address(node_id) - for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]] + for node_id in _addresses_keys[-self.NUM_MAX_RECENT_PEERS:]] return list(reversed(r)) # note: currently channel announcements are trusted by default (trusted=True); @@ -336,9 +346,10 @@ class ChannelDB(SqlDB): except UnknownEvenFeatureBits: return channel_info = channel_info._replace(capacity_sat=capacity_sat) - self._channels[channel_info.short_channel_id] = channel_info - self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) - self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) + with self.lock: + self._channels[channel_info.short_channel_id] = channel_info + self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) + self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) if 'raw' in msg: self.save_channel(channel_info.short_channel_id, msg['raw']) @@ -397,7 +408,8 @@ class ChannelDB(SqlDB): if verify: self.verify_channel_update(payload) policy = Policy.from_msg(payload) - self._policies[key] = policy + with self.lock: + self._policies[key] = policy if 'raw' in payload: self.save_policy(policy.key, payload['raw']) # @@ -492,32 +504,38 @@ class ChannelDB(SqlDB): if node and node.timestamp >= node_info.timestamp: continue # save - self._nodes[node_id] = node_info + with self.lock: + self._nodes[node_id] = node_info if 'raw' in msg_payload: self.save_node_info(node_id, msg_payload['raw']) - for addr in node_addresses: - self._addresses[node_id].add((addr.host, addr.port, 0)) + with self.lock: + for addr in node_addresses: + self._addresses[node_id].add((addr.host, addr.port, 0)) self.save_node_addresses(node_id, node_addresses) self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) self.update_counts() def get_old_policies(self, delta): + with self.lock: + _policies = self._policies.copy() now = int(time.time()) - return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta) + return list(k for k, v in _policies.items() if v.timestamp <= now - delta) def prune_old_policies(self, delta): l = self.get_old_policies(delta) if l: for k in l: - self._policies.pop(k) + with self.lock: + self._policies.pop(k) self.delete_policy(*k) self.update_counts() self.logger.info(f'Deleting {len(l)} old policies') def get_orphaned_channels(self): - ids = set(x[1] for x in self._policies.keys()) - return list(x for x in self._channels.keys() if x not in ids) + with self.lock: + ids = set(x[1] for x in self._policies.keys()) + return list(x for x in self._channels.keys() if x not in ids) def prune_orphaned_channels(self): l = self.get_orphaned_channels() @@ -535,10 +553,11 @@ class ChannelDB(SqlDB): self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload def remove_channel(self, short_channel_id: ShortChannelID): - channel_info = self._channels.pop(short_channel_id, None) - if channel_info: - self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) - self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id) + with self.lock: + channel_info = self._channels.pop(short_channel_id, None) + if channel_info: + self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) + self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id) # delete from database self.delete_channel(short_channel_id) @@ -571,17 +590,19 @@ class ChannelDB(SqlDB): self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}') self.update_counts() - self.count_incomplete_channels() + self.logger.info(f'semi-orphaned channels: {self.get_num_incomplete_channels()}') self.data_loaded.set() - def count_incomplete_channels(self): - out = set() - for short_channel_id, ci in self._channels.items(): + def get_num_incomplete_channels(self) -> int: + found = set() + with self.lock: + _channels = self._channels.copy() + for short_channel_id, ci in _channels.items(): p1 = self.get_policy_for_node(short_channel_id, ci.node1_id) p2 = self.get_policy_for_node(short_channel_id, ci.node2_id) if p1 is None or p2 is not None: - out.add(short_channel_id) - self.logger.info(f'semi-orphaned: {len(out)}') + found.add(short_channel_id) + return len(found) def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: