From 46aa5c19584a1118deabb9dc2a6eba6012d38509 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Tue, 5 Mar 2019 12:20:56 +0100 Subject: [PATCH] lnrouter: perform SQL requests in a separate thread. persist database. --- electrum/lnrouter.py | 81 +++++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 27 deletions(-) diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 821b746c3..833da9534 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -29,11 +29,11 @@ import queue import os import json import threading +import concurrent from collections import defaultdict from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set import binascii import base64 -import asyncio from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean from sqlalchemy.pool import StaticPool @@ -212,43 +212,59 @@ class Address(Base): port = Column(Integer, primary_key=True) last_connected_date = Column(DateTime(), nullable=False) -class ChannelDB: + +class ChannelDB(PrintError): NUM_MAX_RECENT_PEERS = 20 def __init__(self, network: 'Network'): self.network = network - self.num_nodes = 0 self.num_channels = 0 - self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3') - - # (intentionally not persisted) self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] - self.ca_verifier = LNChannelVerifier(network, self) + self.db_requests = queue.Queue() + threading.Thread(target=self.sql_thread).start() - self.network.run_from_another_thread(self.sqlinit()) - - async def sqlinit(self): - """ - this has to run on the async thread since that is where - the lnpeer loop is running from, which will do call in here - """ + def sql_thread(self): engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) self.DBSession = scoped_session(session_factory) self.DBSession.remove() self.DBSession.configure(bind=engine, autoflush=False) + if not os.path.exists(self.path): + Base.metadata.create_all(engine) + self._update_counts() + while self.network.asyncio_loop.is_running(): + try: + future, func, args, kwargs = self.db_requests.get(timeout=0.1) + except queue.Empty: + continue + try: + result = func(self, *args, **kwargs) + except BaseException as e: + future.set_exception(e) + continue + future.set_result(result) + # write + self.DBSession.commit() + self.DBSession.remove() + self.print_error("SQL thread terminated") - Base.metadata.drop_all(engine) - Base.metadata.create_all(engine) + def sql(func): + def wrapper(self, *args, **kwargs): + f = concurrent.futures.Future() + self.db_requests.put((f, func, args, kwargs)) + return f.result(timeout=10) + return wrapper - def update_counts(self): + # not @sql + def _update_counts(self): self.num_channels = self.DBSession.query(ChannelInfo).count() self.num_nodes = self.DBSession.query(NodeInfo).count() - def add_recent_peer(self, peer : LNPeerAddr): + @sql + def add_recent_peer(self, peer: LNPeerAddr): addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() if addr is None: addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) @@ -257,6 +273,7 @@ class ChannelDB: self.DBSession.add(addr) self.DBSession.commit() + @sql def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): unshuffled = self.DBSession \ .query(NodeInfo) \ @@ -265,15 +282,14 @@ class ChannelDB: .all() return random.sample(unshuffled, len(unshuffled)) + @sql def nodes_get(self, node_id): - return self.network.run_from_another_thread(self._nodes_get(node_id)) - - async def _nodes_get(self, node_id): return self.DBSession \ .query(NodeInfo) \ .filter_by(node_id = node_id.hex()) \ .one_or_none() + @sql def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: adr_db = self.DBSession \ .query(Address) \ @@ -284,6 +300,7 @@ class ChannelDB: return None return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id)) + @sql def get_recent_peers(self): return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \ .query(Address) \ @@ -291,9 +308,11 @@ class ChannelDB: .order_by(Address.last_connected_date.desc()) \ .limit(self.NUM_MAX_RECENT_PEERS)] + @sql def get_channel_info(self, channel_id: bytes): - return self.chan_query_for_id(channel_id).one_or_none() + return self._chan_query_for_id(channel_id).one_or_none() + @sql def get_channels_for_node(self, node_id): """Returns the set of channels that have node_id as one of the endpoints.""" condition = or_( @@ -302,6 +321,7 @@ class ChannelDB: rows = self.DBSession.query(ChannelInfo).filter(condition).all() return [bytes.fromhex(x.short_channel_id) for x in rows] + @sql def missing_short_chan_ids(self) -> Set[int]: expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) @@ -318,13 +338,15 @@ class ChannelDB: return chan_ids_from_id2 return set() + @sql def add_verified_channel_info(self, short_id, capacity): # called from lnchannelverifier - channel_info = self.get_channel_info(short_id) + channel_info = self._chan_query_for_id(short_id).one_or_none() channel_info.trusted = True channel_info.capacity = capacity self.DBSession.commit() + @sql @profiler def on_channel_announcement(self, msg_payloads, trusted=False): if type(msg_payloads) is dict: @@ -344,9 +366,10 @@ class ChannelDB: self.DBSession.add(channel_info) if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) self.DBSession.commit() + self._update_counts() self.network.trigger_callback('ln_status') - self.update_counts() + @sql @profiler def on_channel_update(self, msg_payloads, trusted=False): if type(msg_payloads) is dict: @@ -364,6 +387,7 @@ class ChannelDB: self._update_channel_info(channel_info, msg_payload, trusted=trusted) self.DBSession.commit() + @sql @profiler def on_node_announcement(self, msg_payloads): if type(msg_payloads) is dict: @@ -411,8 +435,8 @@ class ChannelDB: if old_addr: del old_addr self.DBSession.commit() + self._update_counts() self.network.trigger_callback('ln_status') - self.update_counts() def get_routing_policy_for_channel(self, start_node_id: bytes, short_channel_id: bytes) -> Optional[bytes]: @@ -431,11 +455,12 @@ class ChannelDB: short_channel_id = msg_payload['short_channel_id'] self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload + @sql def remove_channel(self, short_channel_id): - self.chan_query_for_id(short_channel_id).delete('evaluate') + self._chan_query_for_id(short_channel_id).delete('evaluate') self.DBSession.commit() - def chan_query_for_id(self, short_channel_id) -> Query: + def _chan_query_for_id(self, short_channel_id) -> Query: return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) def print_graph(self, full_ids=False): @@ -495,6 +520,7 @@ class ChannelDB: old_policy.channel_flags = new_policy.channel_flags old_policy.timestamp = new_policy.timestamp + @sql def get_policy_for_node(self, node) -> Optional['Policy']: """ raises when initiator/non-initiator both unequal node @@ -507,6 +533,7 @@ class ChannelDB: n2 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none() return n2 + @sql def get_node_addresses(self, node_info): return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()