Browse Source

sqlite in lnrouter: avoid exceptions on shutdown

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
d2d67f1fe1
  1. 26
      electrum/lnrouter.py

26
electrum/lnrouter.py

@ -36,7 +36,7 @@ import base64
import asyncio import asyncio
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.engine import Engine from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -71,7 +71,6 @@ def validate_features(features : int):
Base = declarative_base() Base = declarative_base()
session_factory = sessionmaker() session_factory = sessionmaker()
DBSession = scoped_session(session_factory) DBSession = scoped_session(session_factory)
engine = None
FLAG_DISABLE = 1 << 1 FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0 FLAG_DIRECTION = 1 << 0
@ -262,27 +261,32 @@ class ChannelDB:
NUM_MAX_RECENT_PEERS = 20 NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
global engine
self.network = network self.network = network
self.num_nodes = 0 self.num_nodes = 0
self.num_channels = 0 self.num_channels = 0
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3') self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
engine = create_engine('sqlite:///' + self.path)#, echo=True)
DBSession.remove()
DBSession.configure(bind=engine, autoflush=False)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
self.lock = threading.RLock()
# (intentionally not persisted) # (intentionally not persisted)
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self) self.ca_verifier = LNChannelVerifier(network, self)
self.network.run_from_another_thread(self.sqlinit())
async def sqlinit(self):
"""
this has to run on the async thread since that is where
the lnpeer loop is running from, which will do call in here
"""
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession.remove()
DBSession.configure(bind=engine, autoflush=False)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
def update_counts(self): def update_counts(self):
self.num_channels = DBSession.query(ChannelInfoInDB).count() self.num_channels = DBSession.query(ChannelInfoInDB).count()
self.num_nodes = DBSession.query(NodeInfoInDB).count() self.num_nodes = DBSession.query(NodeInfoInDB).count()

Loading…
Cancel
Save