Browse Source

ChannelDB: add self.lock and make it thread-safe

hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
fd56fb9189
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 71
      electrum/channel_db.py

71
electrum/channel_db.py

@ -31,6 +31,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
import asyncio
import threading
from .sql_db import SqlDB, sql
@ -247,17 +248,21 @@ class ChannelDB(SqlDB):
def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'gossip_db')
super().__init__(network, path, commit_interval=100)
self.lock = threading.RLock()
self.num_nodes = 0
self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
# initialized in load_data
# note: modify/iterate needs self.lock
self._channels = {} # type: Dict[bytes, ChannelInfo]
self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
# node_id -> (host, port, ts)
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
self.data_loaded = asyncio.Event()
self.network = network # only for callback
@ -268,16 +273,19 @@ class ChannelDB(SqlDB):
self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
def get_channel_ids(self):
return set(self._channels.keys())
with self.lock:
return set(self._channels.keys())
def add_recent_peer(self, peer: LNPeerAddr):
now = int(time.time())
node_id = peer.pubkey
self._addresses[node_id].add((peer.host, peer.port, now))
with self.lock:
self._addresses[node_id].add((peer.host, peer.port, now))
self.save_node_address(node_id, peer, now)
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
unshuffled = set(self._nodes.keys()) - node_ids
with self.lock:
unshuffled = set(self._nodes.keys()) - node_ids
return random.sample(unshuffled, min(200, len(unshuffled)))
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
@ -296,8 +304,10 @@ class ChannelDB(SqlDB):
# FIXME this does not reliably return "recent" peers...
# Also, the list() cast over the whole dict (thousands of elements),
# is really inefficient.
with self.lock:
_addresses_keys = list(self._addresses.keys())
r = [self.get_last_good_address(node_id)
for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]]
for node_id in _addresses_keys[-self.NUM_MAX_RECENT_PEERS:]]
return list(reversed(r))
# note: currently channel announcements are trusted by default (trusted=True);
@ -336,9 +346,10 @@ class ChannelDB(SqlDB):
except UnknownEvenFeatureBits:
return
channel_info = channel_info._replace(capacity_sat=capacity_sat)
self._channels[channel_info.short_channel_id] = channel_info
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
with self.lock:
self._channels[channel_info.short_channel_id] = channel_info
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
if 'raw' in msg:
self.save_channel(channel_info.short_channel_id, msg['raw'])
@ -397,7 +408,8 @@ class ChannelDB(SqlDB):
if verify:
self.verify_channel_update(payload)
policy = Policy.from_msg(payload)
self._policies[key] = policy
with self.lock:
self._policies[key] = policy
if 'raw' in payload:
self.save_policy(policy.key, payload['raw'])
#
@ -492,32 +504,38 @@ class ChannelDB(SqlDB):
if node and node.timestamp >= node_info.timestamp:
continue
# save
self._nodes[node_id] = node_info
with self.lock:
self._nodes[node_id] = node_info
if 'raw' in msg_payload:
self.save_node_info(node_id, msg_payload['raw'])
for addr in node_addresses:
self._addresses[node_id].add((addr.host, addr.port, 0))
with self.lock:
for addr in node_addresses:
self._addresses[node_id].add((addr.host, addr.port, 0))
self.save_node_addresses(node_id, node_addresses)
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
def get_old_policies(self, delta):
with self.lock:
_policies = self._policies.copy()
now = int(time.time())
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
return list(k for k, v in _policies.items() if v.timestamp <= now - delta)
def prune_old_policies(self, delta):
l = self.get_old_policies(delta)
if l:
for k in l:
self._policies.pop(k)
with self.lock:
self._policies.pop(k)
self.delete_policy(*k)
self.update_counts()
self.logger.info(f'Deleting {len(l)} old policies')
def get_orphaned_channels(self):
ids = set(x[1] for x in self._policies.keys())
return list(x for x in self._channels.keys() if x not in ids)
with self.lock:
ids = set(x[1] for x in self._policies.keys())
return list(x for x in self._channels.keys() if x not in ids)
def prune_orphaned_channels(self):
l = self.get_orphaned_channels()
@ -535,10 +553,11 @@ class ChannelDB(SqlDB):
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
def remove_channel(self, short_channel_id: ShortChannelID):
channel_info = self._channels.pop(short_channel_id, None)
if channel_info:
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
with self.lock:
channel_info = self._channels.pop(short_channel_id, None)
if channel_info:
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
# delete from database
self.delete_channel(short_channel_id)
@ -571,17 +590,19 @@ class ChannelDB(SqlDB):
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
self.update_counts()
self.count_incomplete_channels()
self.logger.info(f'semi-orphaned channels: {self.get_num_incomplete_channels()}')
self.data_loaded.set()
def count_incomplete_channels(self):
out = set()
for short_channel_id, ci in self._channels.items():
def get_num_incomplete_channels(self) -> int:
found = set()
with self.lock:
_channels = self._channels.copy()
for short_channel_id, ci in _channels.items():
p1 = self.get_policy_for_node(short_channel_id, ci.node1_id)
p2 = self.get_policy_for_node(short_channel_id, ci.node2_id)
if p1 is None or p2 is not None:
out.add(short_channel_id)
self.logger.info(f'semi-orphaned: {len(out)}')
found.add(short_channel_id)
return len(found)
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:

Loading…
Cancel
Save