diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 62c2256ac..a684bf3e8 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -69,7 +69,6 @@ def validate_features(features : int): raise UnknownEvenFeatureBits() Base = declarative_base() -session_factory = sessionmaker() FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 @@ -228,10 +227,10 @@ class ChannelDB(PrintError): threading.Thread(target=self.sql_thread).start() def sql_thread(self): + self.sql_thread = threading.currentThread() engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) - self.DBSession = scoped_session(session_factory) - self.DBSession.remove() - self.DBSession.configure(bind=engine, autoflush=False) + DBSession = sessionmaker(bind=engine, autoflush=False) + self.DBSession = DBSession() if not os.path.exists(self.path): Base.metadata.create_all(engine) self._update_counts() @@ -248,17 +247,16 @@ class ChannelDB(PrintError): future.set_result(result) # write self.DBSession.commit() - self.DBSession.remove() self.print_error("SQL thread terminated") def sql(func): def wrapper(self, *args, **kwargs): + assert threading.currentThread() != self.sql_thread f = concurrent.futures.Future() self.db_requests.put((f, func, args, kwargs)) return f.result(timeout=10) return wrapper - # not @sql def _update_counts(self): self.num_channels = self.DBSession.query(ChannelInfo).count() self.num_nodes = self.DBSession.query(NodeInfo).count()