Browse Source

create parent class for sql databases

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
d8e9a9a49e
  1. 51
      electrum/lnrouter.py
  2. 52
      electrum/lnwatcher.py
  3. 51
      electrum/sql_db.py

51
electrum/lnrouter.py

@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
from .sql_db import SqlDB, sql
from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
@ -212,50 +210,25 @@ class Address(Base):
last_connected_date = Column(DateTime(), nullable=False)
class ChannelDB(PrintError):
class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'):
self.network = network
path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, Base)
print(Base)
self.num_nodes = 0
self.num_channels = 0
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self):
self.sql_thread = threading.currentThread()
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
self.update_counts()
@sql
def update_counts(self):
self._update_counts()
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
def _update_counts(self):
self.num_channels = self.DBSession.query(ChannelInfo).count()

52
electrum/lnwatcher.py

@ -11,9 +11,14 @@ from collections import defaultdict
import asyncio
from enum import IntEnum, auto
from typing import NamedTuple, Dict
import jsonrpclib
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet
from .storage import WalletStorage
@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum):
FREE = auto()
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
Base = declarative_base()
class SweepTx(Base):
@ -60,42 +57,11 @@ class ChannelInfo(Base):
outpoint = Column(String(34))
class SweepStore(PrintError):
def __init__(self, path, network):
PrintError.__init__(self)
self.path = path
self.network = network
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
class SweepStore(SqlDB):
def sql(func):
def wrapper(self, *args, **kwargs):
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
def __init__(self, path, network):
super().__init__(network, path, Base)
@sql
def get_sweep_tx(self, funding_outpoint, prev_txid):

51
electrum/sql_db.py

@ -0,0 +1,51 @@
import os
import concurrent
import queue
import threading
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from .util import PrintError
def sql(func):
"""wrapper for sql methods"""
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
class SqlDB(PrintError):
def __init__(self, network, path, base):
self.base = base
self.network = network
self.path = path
self.db_requests = queue.Queue()
self.sql_thread = threading.Thread(target=self.run_sql)
self.sql_thread.start()
def run_sql(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
self.base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
Loading…
Cancel
Save