Browse Source

lnrouter: perform SQL requests in a separate thread. persist database.

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
46aa5c1958
  1. 81
      electrum/lnrouter.py

81
electrum/lnrouter.py

@ -29,11 +29,11 @@ import queue
import os import os
import json import json
import threading import threading
import concurrent
from collections import defaultdict from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii import binascii
import base64 import base64
import asyncio
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
@ -212,43 +212,59 @@ class Address(Base):
port = Column(Integer, primary_key=True) port = Column(Integer, primary_key=True)
last_connected_date = Column(DateTime(), nullable=False) last_connected_date = Column(DateTime(), nullable=False)
class ChannelDB:
class ChannelDB(PrintError):
NUM_MAX_RECENT_PEERS = 20 NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
self.network = network self.network = network
self.num_nodes = 0 self.num_nodes = 0
self.num_channels = 0 self.num_channels = 0
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3') self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
# (intentionally not persisted)
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self) self.ca_verifier = LNChannelVerifier(network, self)
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
self.network.run_from_another_thread(self.sqlinit()) def sql_thread(self):
async def sqlinit(self):
"""
this has to run on the async thread since that is where
the lnpeer loop is running from, which will do call in here
"""
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
self.DBSession = scoped_session(session_factory) self.DBSession = scoped_session(session_factory)
self.DBSession.remove() self.DBSession.remove()
self.DBSession.configure(bind=engine, autoflush=False) self.DBSession.configure(bind=engine, autoflush=False)
if not os.path.exists(self.path):
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
self._update_counts()
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.DBSession.remove()
self.print_error("SQL thread terminated")
def update_counts(self): def sql(func):
def wrapper(self, *args, **kwargs):
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
# not @sql
def _update_counts(self):
self.num_channels = self.DBSession.query(ChannelInfo).count() self.num_channels = self.DBSession.query(ChannelInfo).count()
self.num_nodes = self.DBSession.query(NodeInfo).count() self.num_nodes = self.DBSession.query(NodeInfo).count()
def add_recent_peer(self, peer : LNPeerAddr): @sql
def add_recent_peer(self, peer: LNPeerAddr):
addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
if addr is None: if addr is None:
addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
@ -257,6 +273,7 @@ class ChannelDB:
self.DBSession.add(addr) self.DBSession.add(addr)
self.DBSession.commit() self.DBSession.commit()
@sql
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
unshuffled = self.DBSession \ unshuffled = self.DBSession \
.query(NodeInfo) \ .query(NodeInfo) \
@ -265,15 +282,14 @@ class ChannelDB:
.all() .all()
return random.sample(unshuffled, len(unshuffled)) return random.sample(unshuffled, len(unshuffled))
@sql
def nodes_get(self, node_id): def nodes_get(self, node_id):
return self.network.run_from_another_thread(self._nodes_get(node_id))
async def _nodes_get(self, node_id):
return self.DBSession \ return self.DBSession \
.query(NodeInfo) \ .query(NodeInfo) \
.filter_by(node_id = node_id.hex()) \ .filter_by(node_id = node_id.hex()) \
.one_or_none() .one_or_none()
@sql
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
adr_db = self.DBSession \ adr_db = self.DBSession \
.query(Address) \ .query(Address) \
@ -284,6 +300,7 @@ class ChannelDB:
return None return None
return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id)) return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id))
@sql
def get_recent_peers(self): def get_recent_peers(self):
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \ return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \
.query(Address) \ .query(Address) \
@ -291,9 +308,11 @@ class ChannelDB:
.order_by(Address.last_connected_date.desc()) \ .order_by(Address.last_connected_date.desc()) \
.limit(self.NUM_MAX_RECENT_PEERS)] .limit(self.NUM_MAX_RECENT_PEERS)]
@sql
def get_channel_info(self, channel_id: bytes): def get_channel_info(self, channel_id: bytes):
return self.chan_query_for_id(channel_id).one_or_none() return self._chan_query_for_id(channel_id).one_or_none()
@sql
def get_channels_for_node(self, node_id): def get_channels_for_node(self, node_id):
"""Returns the set of channels that have node_id as one of the endpoints.""" """Returns the set of channels that have node_id as one of the endpoints."""
condition = or_( condition = or_(
@ -302,6 +321,7 @@ class ChannelDB:
rows = self.DBSession.query(ChannelInfo).filter(condition).all() rows = self.DBSession.query(ChannelInfo).filter(condition).all()
return [bytes.fromhex(x.short_channel_id) for x in rows] return [bytes.fromhex(x.short_channel_id) for x in rows]
@sql
def missing_short_chan_ids(self) -> Set[int]: def missing_short_chan_ids(self) -> Set[int]:
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
@ -318,13 +338,15 @@ class ChannelDB:
return chan_ids_from_id2 return chan_ids_from_id2
return set() return set()
@sql
def add_verified_channel_info(self, short_id, capacity): def add_verified_channel_info(self, short_id, capacity):
# called from lnchannelverifier # called from lnchannelverifier
channel_info = self.get_channel_info(short_id) channel_info = self._chan_query_for_id(short_id).one_or_none()
channel_info.trusted = True channel_info.trusted = True
channel_info.capacity = capacity channel_info.capacity = capacity
self.DBSession.commit() self.DBSession.commit()
@sql
@profiler @profiler
def on_channel_announcement(self, msg_payloads, trusted=False): def on_channel_announcement(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
@ -344,9 +366,10 @@ class ChannelDB:
self.DBSession.add(channel_info) self.DBSession.add(channel_info)
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
self.DBSession.commit() self.DBSession.commit()
self._update_counts()
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')
self.update_counts()
@sql
@profiler @profiler
def on_channel_update(self, msg_payloads, trusted=False): def on_channel_update(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
@ -364,6 +387,7 @@ class ChannelDB:
self._update_channel_info(channel_info, msg_payload, trusted=trusted) self._update_channel_info(channel_info, msg_payload, trusted=trusted)
self.DBSession.commit() self.DBSession.commit()
@sql
@profiler @profiler
def on_node_announcement(self, msg_payloads): def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
@ -411,8 +435,8 @@ class ChannelDB:
if old_addr: if old_addr:
del old_addr del old_addr
self.DBSession.commit() self.DBSession.commit()
self._update_counts()
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')
self.update_counts()
def get_routing_policy_for_channel(self, start_node_id: bytes, def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[bytes]: short_channel_id: bytes) -> Optional[bytes]:
@ -431,11 +455,12 @@ class ChannelDB:
short_channel_id = msg_payload['short_channel_id'] short_channel_id = msg_payload['short_channel_id']
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
@sql
def remove_channel(self, short_channel_id): def remove_channel(self, short_channel_id):
self.chan_query_for_id(short_channel_id).delete('evaluate') self._chan_query_for_id(short_channel_id).delete('evaluate')
self.DBSession.commit() self.DBSession.commit()
def chan_query_for_id(self, short_channel_id) -> Query: def _chan_query_for_id(self, short_channel_id) -> Query:
return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
def print_graph(self, full_ids=False): def print_graph(self, full_ids=False):
@ -495,6 +520,7 @@ class ChannelDB:
old_policy.channel_flags = new_policy.channel_flags old_policy.channel_flags = new_policy.channel_flags
old_policy.timestamp = new_policy.timestamp old_policy.timestamp = new_policy.timestamp
@sql
def get_policy_for_node(self, node) -> Optional['Policy']: def get_policy_for_node(self, node) -> Optional['Policy']:
""" """
raises when initiator/non-initiator both unequal node raises when initiator/non-initiator both unequal node
@ -507,6 +533,7 @@ class ChannelDB:
n2 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none() n2 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
return n2 return n2
@sql
def get_node_addresses(self, node_info): def get_node_addresses(self, node_info):
return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()

Loading…
Cancel
Save