|
|
@ -31,6 +31,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK |
|
|
|
import binascii |
|
|
|
import base64 |
|
|
|
import asyncio |
|
|
|
import threading |
|
|
|
|
|
|
|
|
|
|
|
from .sql_db import SqlDB, sql |
|
|
@ -247,17 +248,21 @@ class ChannelDB(SqlDB): |
|
|
|
def __init__(self, network: 'Network'): |
|
|
|
path = os.path.join(get_headers_dir(network.config), 'gossip_db') |
|
|
|
super().__init__(network, path, commit_interval=100) |
|
|
|
self.lock = threading.RLock() |
|
|
|
self.num_nodes = 0 |
|
|
|
self.num_channels = 0 |
|
|
|
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] |
|
|
|
self.ca_verifier = LNChannelVerifier(network, self) |
|
|
|
|
|
|
|
# initialized in load_data |
|
|
|
# note: modify/iterate needs self.lock |
|
|
|
self._channels = {} # type: Dict[bytes, ChannelInfo] |
|
|
|
self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy |
|
|
|
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo |
|
|
|
# node_id -> (host, port, ts) |
|
|
|
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]] |
|
|
|
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]] |
|
|
|
|
|
|
|
self.data_loaded = asyncio.Event() |
|
|
|
self.network = network # only for callback |
|
|
|
|
|
|
@ -268,16 +273,19 @@ class ChannelDB(SqlDB): |
|
|
|
self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies) |
|
|
|
|
|
|
|
def get_channel_ids(self): |
|
|
|
return set(self._channels.keys()) |
|
|
|
with self.lock: |
|
|
|
return set(self._channels.keys()) |
|
|
|
|
|
|
|
def add_recent_peer(self, peer: LNPeerAddr): |
|
|
|
now = int(time.time()) |
|
|
|
node_id = peer.pubkey |
|
|
|
self._addresses[node_id].add((peer.host, peer.port, now)) |
|
|
|
with self.lock: |
|
|
|
self._addresses[node_id].add((peer.host, peer.port, now)) |
|
|
|
self.save_node_address(node_id, peer, now) |
|
|
|
|
|
|
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids): |
|
|
|
unshuffled = set(self._nodes.keys()) - node_ids |
|
|
|
with self.lock: |
|
|
|
unshuffled = set(self._nodes.keys()) - node_ids |
|
|
|
return random.sample(unshuffled, min(200, len(unshuffled))) |
|
|
|
|
|
|
|
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: |
|
|
@ -296,8 +304,10 @@ class ChannelDB(SqlDB): |
|
|
|
# FIXME this does not reliably return "recent" peers... |
|
|
|
# Also, the list() cast over the whole dict (thousands of elements), |
|
|
|
# is really inefficient. |
|
|
|
with self.lock: |
|
|
|
_addresses_keys = list(self._addresses.keys()) |
|
|
|
r = [self.get_last_good_address(node_id) |
|
|
|
for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]] |
|
|
|
for node_id in _addresses_keys[-self.NUM_MAX_RECENT_PEERS:]] |
|
|
|
return list(reversed(r)) |
|
|
|
|
|
|
|
# note: currently channel announcements are trusted by default (trusted=True); |
|
|
@ -336,9 +346,10 @@ class ChannelDB(SqlDB): |
|
|
|
except UnknownEvenFeatureBits: |
|
|
|
return |
|
|
|
channel_info = channel_info._replace(capacity_sat=capacity_sat) |
|
|
|
self._channels[channel_info.short_channel_id] = channel_info |
|
|
|
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) |
|
|
|
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) |
|
|
|
with self.lock: |
|
|
|
self._channels[channel_info.short_channel_id] = channel_info |
|
|
|
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) |
|
|
|
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) |
|
|
|
if 'raw' in msg: |
|
|
|
self.save_channel(channel_info.short_channel_id, msg['raw']) |
|
|
|
|
|
|
@ -397,7 +408,8 @@ class ChannelDB(SqlDB): |
|
|
|
if verify: |
|
|
|
self.verify_channel_update(payload) |
|
|
|
policy = Policy.from_msg(payload) |
|
|
|
self._policies[key] = policy |
|
|
|
with self.lock: |
|
|
|
self._policies[key] = policy |
|
|
|
if 'raw' in payload: |
|
|
|
self.save_policy(policy.key, payload['raw']) |
|
|
|
# |
|
|
@ -492,32 +504,38 @@ class ChannelDB(SqlDB): |
|
|
|
if node and node.timestamp >= node_info.timestamp: |
|
|
|
continue |
|
|
|
# save |
|
|
|
self._nodes[node_id] = node_info |
|
|
|
with self.lock: |
|
|
|
self._nodes[node_id] = node_info |
|
|
|
if 'raw' in msg_payload: |
|
|
|
self.save_node_info(node_id, msg_payload['raw']) |
|
|
|
for addr in node_addresses: |
|
|
|
self._addresses[node_id].add((addr.host, addr.port, 0)) |
|
|
|
with self.lock: |
|
|
|
for addr in node_addresses: |
|
|
|
self._addresses[node_id].add((addr.host, addr.port, 0)) |
|
|
|
self.save_node_addresses(node_id, node_addresses) |
|
|
|
|
|
|
|
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) |
|
|
|
self.update_counts() |
|
|
|
|
|
|
|
def get_old_policies(self, delta): |
|
|
|
with self.lock: |
|
|
|
_policies = self._policies.copy() |
|
|
|
now = int(time.time()) |
|
|
|
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta) |
|
|
|
return list(k for k, v in _policies.items() if v.timestamp <= now - delta) |
|
|
|
|
|
|
|
def prune_old_policies(self, delta): |
|
|
|
l = self.get_old_policies(delta) |
|
|
|
if l: |
|
|
|
for k in l: |
|
|
|
self._policies.pop(k) |
|
|
|
with self.lock: |
|
|
|
self._policies.pop(k) |
|
|
|
self.delete_policy(*k) |
|
|
|
self.update_counts() |
|
|
|
self.logger.info(f'Deleting {len(l)} old policies') |
|
|
|
|
|
|
|
def get_orphaned_channels(self): |
|
|
|
ids = set(x[1] for x in self._policies.keys()) |
|
|
|
return list(x for x in self._channels.keys() if x not in ids) |
|
|
|
with self.lock: |
|
|
|
ids = set(x[1] for x in self._policies.keys()) |
|
|
|
return list(x for x in self._channels.keys() if x not in ids) |
|
|
|
|
|
|
|
def prune_orphaned_channels(self): |
|
|
|
l = self.get_orphaned_channels() |
|
|
@ -535,10 +553,11 @@ class ChannelDB(SqlDB): |
|
|
|
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload |
|
|
|
|
|
|
|
def remove_channel(self, short_channel_id: ShortChannelID): |
|
|
|
channel_info = self._channels.pop(short_channel_id, None) |
|
|
|
if channel_info: |
|
|
|
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) |
|
|
|
self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id) |
|
|
|
with self.lock: |
|
|
|
channel_info = self._channels.pop(short_channel_id, None) |
|
|
|
if channel_info: |
|
|
|
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) |
|
|
|
self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id) |
|
|
|
# delete from database |
|
|
|
self.delete_channel(short_channel_id) |
|
|
|
|
|
|
@ -571,17 +590,19 @@ class ChannelDB(SqlDB): |
|
|
|
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) |
|
|
|
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}') |
|
|
|
self.update_counts() |
|
|
|
self.count_incomplete_channels() |
|
|
|
self.logger.info(f'semi-orphaned channels: {self.get_num_incomplete_channels()}') |
|
|
|
self.data_loaded.set() |
|
|
|
|
|
|
|
def count_incomplete_channels(self): |
|
|
|
out = set() |
|
|
|
for short_channel_id, ci in self._channels.items(): |
|
|
|
def get_num_incomplete_channels(self) -> int: |
|
|
|
found = set() |
|
|
|
with self.lock: |
|
|
|
_channels = self._channels.copy() |
|
|
|
for short_channel_id, ci in _channels.items(): |
|
|
|
p1 = self.get_policy_for_node(short_channel_id, ci.node1_id) |
|
|
|
p2 = self.get_policy_for_node(short_channel_id, ci.node2_id) |
|
|
|
if p1 is None or p2 is not None: |
|
|
|
out.add(short_channel_id) |
|
|
|
self.logger.info(f'semi-orphaned: {len(out)}') |
|
|
|
found.add(short_channel_id) |
|
|
|
return len(found) |
|
|
|
|
|
|
|
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *, |
|
|
|
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: |
|
|
|