diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index a684bf3e8..670762b66 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK import binascii import base64 -from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean -from sqlalchemy.pool import StaticPool -from sqlalchemy.orm import sessionmaker +from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean from sqlalchemy.orm.query import Query from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import not_, or_ -from sqlalchemy.orm import scoped_session +from .sql_db import SqlDB, sql from . import constants from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits @@ -212,50 +210,25 @@ class Address(Base): last_connected_date = Column(DateTime(), nullable=False) -class ChannelDB(PrintError): + + +class ChannelDB(SqlDB): NUM_MAX_RECENT_PEERS = 20 def __init__(self, network: 'Network'): - self.network = network + path = os.path.join(get_headers_dir(network.config), 'channel_db') + super().__init__(network, path, Base) + print(Base) self.num_nodes = 0 self.num_channels = 0 - self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3') self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self.ca_verifier = LNChannelVerifier(network, self) - self.db_requests = queue.Queue() - threading.Thread(target=self.sql_thread).start() - - def sql_thread(self): - self.sql_thread = threading.currentThread() - engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) - DBSession = sessionmaker(bind=engine, autoflush=False) - self.DBSession = DBSession() - if not os.path.exists(self.path): - Base.metadata.create_all(engine) + self.update_counts() + + @sql + def update_counts(self): self._update_counts() - while self.network.asyncio_loop.is_running(): - try: - future, func, args, kwargs = self.db_requests.get(timeout=0.1) - except queue.Empty: - continue - try: - result = func(self, *args, **kwargs) - except BaseException as e: - future.set_exception(e) - continue - future.set_result(result) - # write - self.DBSession.commit() - self.print_error("SQL thread terminated") - - def sql(func): - def wrapper(self, *args, **kwargs): - assert threading.currentThread() != self.sql_thread - f = concurrent.futures.Future() - self.db_requests.put((f, func, args, kwargs)) - return f.result(timeout=10) - return wrapper def _update_counts(self): self.num_channels = self.DBSession.query(ChannelInfo).count() diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index 968bba0eb..6bbad480c 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -11,9 +11,14 @@ from collections import defaultdict import asyncio from enum import IntEnum, auto from typing import NamedTuple, Dict - import jsonrpclib +from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean +from sqlalchemy.orm.query import Query +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import not_, or_ +from .sql_db import SqlDB, sql + from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions from . import wallet from .storage import WalletStorage @@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum): FREE = auto() -from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean -from sqlalchemy.pool import StaticPool -from sqlalchemy.orm import sessionmaker -from sqlalchemy.orm.query import Query -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.sql import not_, or_ -from sqlalchemy.orm import scoped_session - Base = declarative_base() class SweepTx(Base): @@ -60,42 +57,11 @@ class ChannelInfo(Base): outpoint = Column(String(34)) -class SweepStore(PrintError): - def __init__(self, path, network): - PrintError.__init__(self) - self.path = path - self.network = network - self.db_requests = queue.Queue() - threading.Thread(target=self.sql_thread).start() - - def sql_thread(self): - engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool) - DBSession = sessionmaker(bind=engine, autoflush=False) - self.DBSession = DBSession() - if not os.path.exists(self.path): - Base.metadata.create_all(engine) - while self.network.asyncio_loop.is_running(): - try: - future, func, args, kwargs = self.db_requests.get(timeout=0.1) - except queue.Empty: - continue - try: - result = func(self, *args, **kwargs) - except BaseException as e: - future.set_exception(e) - continue - future.set_result(result) - # write - self.DBSession.commit() - self.print_error("SQL thread terminated") +class SweepStore(SqlDB): - def sql(func): - def wrapper(self, *args, **kwargs): - f = concurrent.futures.Future() - self.db_requests.put((f, func, args, kwargs)) - return f.result(timeout=10) - return wrapper + def __init__(self, path, network): + super().__init__(network, path, Base) @sql def get_sweep_tx(self, funding_outpoint, prev_txid): diff --git a/electrum/sql_db.py b/electrum/sql_db.py new file mode 100644 index 000000000..d62460f21 --- /dev/null +++ b/electrum/sql_db.py @@ -0,0 +1,51 @@ +import os +import concurrent +import queue +import threading + +from sqlalchemy import create_engine +from sqlalchemy.pool import StaticPool +from sqlalchemy.orm import sessionmaker + +from .util import PrintError + + +def sql(func): + """wrapper for sql methods""" + def wrapper(self, *args, **kwargs): + assert threading.currentThread() != self.sql_thread + f = concurrent.futures.Future() + self.db_requests.put((f, func, args, kwargs)) + return f.result(timeout=10) + return wrapper + +class SqlDB(PrintError): + + def __init__(self, network, path, base): + self.base = base + self.network = network + self.path = path + self.db_requests = queue.Queue() + self.sql_thread = threading.Thread(target=self.run_sql) + self.sql_thread.start() + + def run_sql(self): + engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) + DBSession = sessionmaker(bind=engine, autoflush=False) + self.DBSession = DBSession() + if not os.path.exists(self.path): + self.base.metadata.create_all(engine) + while self.network.asyncio_loop.is_running(): + try: + future, func, args, kwargs = self.db_requests.get(timeout=0.1) + except queue.Empty: + continue + try: + result = func(self, *args, **kwargs) + except BaseException as e: + future.set_exception(e) + continue + future.set_result(result) + # write + self.DBSession.commit() + self.print_error("SQL thread terminated")