Browse Source

create parent class for sql databases

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
d8e9a9a49e
  1. 51
      electrum/lnrouter.py
  2. 52
      electrum/lnwatcher.py
  3. 51
      electrum/sql_db.py

51
electrum/lnrouter.py

@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii import binascii
import base64 import base64
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_ from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session from .sql_db import SqlDB, sql
from . import constants from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
@ -212,50 +210,25 @@ class Address(Base):
last_connected_date = Column(DateTime(), nullable=False) last_connected_date = Column(DateTime(), nullable=False)
class ChannelDB(PrintError):
class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20 NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
self.network = network path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, Base)
print(Base)
self.num_nodes = 0 self.num_nodes = 0
self.num_channels = 0 self.num_channels = 0
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self) self.ca_verifier = LNChannelVerifier(network, self)
self.db_requests = queue.Queue() self.update_counts()
threading.Thread(target=self.sql_thread).start()
@sql
def sql_thread(self): def update_counts(self):
self.sql_thread = threading.currentThread()
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
self._update_counts() self._update_counts()
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
def _update_counts(self): def _update_counts(self):
self.num_channels = self.DBSession.query(ChannelInfo).count() self.num_channels = self.DBSession.query(ChannelInfo).count()

52
electrum/lnwatcher.py

@ -11,9 +11,14 @@ from collections import defaultdict
import asyncio import asyncio
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import NamedTuple, Dict from typing import NamedTuple, Dict
import jsonrpclib import jsonrpclib
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet from . import wallet
from .storage import WalletStorage from .storage import WalletStorage
@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum):
FREE = auto() FREE = auto()
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
Base = declarative_base() Base = declarative_base()
class SweepTx(Base): class SweepTx(Base):
@ -60,42 +57,11 @@ class ChannelInfo(Base):
outpoint = Column(String(34)) outpoint = Column(String(34))
class SweepStore(PrintError):
def __init__(self, path, network): class SweepStore(SqlDB):
PrintError.__init__(self)
self.path = path
self.network = network
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func): def __init__(self, path, network):
def wrapper(self, *args, **kwargs): super().__init__(network, path, Base)
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
@sql @sql
def get_sweep_tx(self, funding_outpoint, prev_txid): def get_sweep_tx(self, funding_outpoint, prev_txid):

51
electrum/sql_db.py

@ -0,0 +1,51 @@
import os
import concurrent
import queue
import threading
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from .util import PrintError
def sql(func):
"""wrapper for sql methods"""
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
class SqlDB(PrintError):
def __init__(self, network, path, base):
self.base = base
self.network = network
self.path = path
self.db_requests = queue.Queue()
self.sql_thread = threading.Thread(target=self.run_sql)
self.sql_thread.start()
def run_sql(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
self.base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
Loading…
Cancel
Save