Browse Source

get rid of sql_alchemy

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
238f3c949c
  1. 1
      contrib/requirements/requirements.txt
  2. 174
      electrum/channel_db.py
  3. 100
      electrum/lnwatcher.py
  4. 25
      electrum/sql_db.py

1
contrib/requirements/requirements.txt

@ -11,4 +11,3 @@ aiohttp_socks
certifi
bitstring
pycryptodomex>=3.7
sqlalchemy>=1.3.0b3

174
electrum/channel_db.py

@ -36,10 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from . import constants
@ -66,7 +62,6 @@ def validate_features(features : int):
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
Base = declarative_base()
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
@ -193,57 +188,45 @@ class Address(NamedTuple):
port: int
last_connected_date: int
class ChannelInfoBase(Base):
__tablename__ = 'channel_info'
short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
capacity_sat = Column(Integer)
def to_nametuple(self):
return ChannelInfo(
short_channel_id=self.short_channel_id,
node1_id=self.node1_id,
node2_id=self.node2_id,
capacity_sat=self.capacity_sat
)
class PolicyBase(Base):
__tablename__ = 'policy'
key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
cltv_expiry_delta = Column(Integer, nullable=False)
htlc_minimum_msat = Column(Integer, nullable=False)
htlc_maximum_msat = Column(Integer)
fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False)
channel_flags = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
def to_nametuple(self):
return Policy(
key=self.key,
cltv_expiry_delta=self.cltv_expiry_delta,
htlc_minimum_msat=self.htlc_minimum_msat,
htlc_maximum_msat=self.htlc_maximum_msat,
fee_base_msat= self.fee_base_msat,
fee_proportional_millionths = self.fee_proportional_millionths,
channel_flags=self.channel_flags,
timestamp=self.timestamp
)
class NodeInfoBase(Base):
__tablename__ = 'node_info'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
features = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
alias = Column(String(64), nullable=False)
class AddressBase(Base):
__tablename__ = 'address'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
host = Column(String(256))
port = Column(Integer)
last_connected_date = Column(Integer(), nullable=True)
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id VARCHAR(64),
node1_id VARCHAR(66),
node2_id VARCHAR(66),
capacity_sat INTEGER,
PRIMARY KEY(short_channel_id)
)"""
create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key VARCHAR(66),
cltv_expiry_delta INTEGER NOT NULL,
htlc_minimum_msat INTEGER NOT NULL,
htlc_maximum_msat INTEGER,
fee_base_msat INTEGER NOT NULL,
fee_proportional_millionths INTEGER NOT NULL,
channel_flags INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
PRIMARY KEY(key)
)"""
create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id VARCHAR(66),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""
create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id VARCHAR(66),
features INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
alias STRING(64),
PRIMARY KEY(node_id)
)"""
class ChannelDB(SqlDB):
@ -252,7 +235,7 @@ class ChannelDB(SqlDB):
def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, Base, commit_interval=100)
super().__init__(network, path, commit_interval=100)
self.num_nodes = 0
self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
@ -276,16 +259,7 @@ class ChannelDB(SqlDB):
now = int(time.time())
node_id = peer.pubkey
self._addresses[node_id].add((peer.host, peer.port, now))
self.save_address(node_id, peer, now)
@sql
def save_address(self, node_id, peer, now):
addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
if addr:
addr.last_connected_date = now
else:
addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
self.DBSession.add(addr)
self.save_node_address(node_id, peer, now)
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
unshuffled = set(self._nodes.keys()) - node_ids
@ -394,17 +368,47 @@ class ChannelDB(SqlDB):
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
assert len(good) == 1
def create_database(self):
c = self.conn.cursor()
c.execute(create_node_info)
c.execute(create_address)
c.execute(create_policy)
c.execute(create_channel_info)
self.conn.commit()
@sql
def save_policy(self, policy):
self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
c = self.conn.cursor()
c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy))
@sql
def delete_policy(self, short_channel_id, node_id):
self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
c = self.conn.cursor()
c.execute("""DELETE FROM policy WHERE key=?""", (key,))
@sql
def save_channel(self, channel_info):
self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
c = self.conn.cursor()
c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))
@sql
def save_node(self, node_info):
c = self.conn.cursor()
c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))
@sql
def save_node_address(self, node_id, peer, now):
c = self.conn.cursor()
c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))
@sql
def save_node_addresses(self, node_id, node_addresses):
c = self.conn.cursor()
for addr in node_addresses:
c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
r = c.fetchall()
if r == []:
c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))
def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id']
@ -418,7 +422,6 @@ class ChannelDB(SqlDB):
msg_payloads = [msg_payloads]
old_addr = None
new_nodes = {}
new_addresses = {}
for msg_payload in msg_payloads:
try:
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
@ -445,17 +448,6 @@ class ChannelDB(SqlDB):
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
@sql
def save_node_addresses(self, node_if, node_addresses):
for new_addr in node_addresses:
old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
if not old_addr:
self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
@sql
def save_node(self, node_info):
self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[bytes]:
if not start_node_id or not short_channel_id: return None
@ -506,12 +498,18 @@ class ChannelDB(SqlDB):
@sql
@profiler
def load_data(self):
for x in self.DBSession.query(AddressBase).all():
self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
for x in self.DBSession.query(ChannelInfoBase).all():
self._channels[x.short_channel_id] = x.to_nametuple()
for x in self.DBSession.query(PolicyBase).filter_by().all():
p = x.to_nametuple()
c = self.conn.cursor()
c.execute("""SELECT * FROM address""")
for x in c:
node_id, host, port, timestamp = x
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
c.execute("""SELECT * FROM channel_info""")
for x in c:
ci = ChannelInfo(*x)
self._channels[ci.short_channel_id] = ci
c.execute("""SELECT * FROM policy""")
for x in c:
p = Policy(*x)
self._policies[(p.start_node, p.short_channel_id)] = p
for channel_info in self._channels.values():
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)

100
electrum/lnwatcher.py

@ -13,12 +13,7 @@ from enum import IntEnum, auto
from typing import NamedTuple, Dict
import jsonrpclib
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from .util import bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet
from .storage import WalletStorage
@ -42,80 +37,105 @@ class TxMinedDepth(IntEnum):
FREE = auto()
Base = declarative_base()
class SweepTx(Base):
__tablename__ = 'sweep_txs'
funding_outpoint = Column(String(34), primary_key=True)
index = Column(Integer(), primary_key=True)
prevout = Column(String(34))
tx = Column(String())
class ChannelInfo(Base):
__tablename__ = 'channel_info'
outpoint = Column(String(34), primary_key=True)
address = Column(String(32))
create_sweep_txs="""
CREATE TABLE IF NOT EXISTS sweep_txs (
funding_outpoint VARCHAR(34) NOT NULL,
"index" INTEGER NOT NULL,
prevout VARCHAR(34),
tx VARCHAR,
PRIMARY KEY(funding_outpoint, "index")
)"""
create_channel_info="""
CREATE TABLE IF NOT EXISTS channel_info (
outpoint VARCHAR(34) NOT NULL,
address VARCHAR(32),
PRIMARY KEY(outpoint)
)"""
class SweepStore(SqlDB):
def __init__(self, path, network):
super().__init__(network, path, Base)
super().__init__(network, path)
def create_database(self):
c = self.conn.cursor()
c.execute(create_channel_info)
c.execute(create_sweep_txs)
self.conn.commit()
@sql
def get_sweep_tx(self, funding_outpoint, prevout):
return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prevout==prevout).all()]
c = self.conn.cursor()
c.execute("SELECT tx FROM sweep_txs WHERE funding_outpoint=? AND prevout=?", (funding_outpoint, prevout))
return [Transaction(bh2u(r[0])) for r in c.fetchall()]
@sql
def get_tx_by_index(self, funding_outpoint, index):
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
return str(r.prevout), bh2u(r.tx)
c = self.conn.cursor()
c.execute("""SELECT prevout, tx FROM sweep_txs WHERE funding_outpoint=? AND "index"=?""", (funding_outpoint, index))
r = c.fetchone()[0]
return str(r[0]), bh2u(r[1])
@sql
def list_sweep_tx(self):
return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
c = self.conn.cursor()
c.execute("SELECT funding_outpoint FROM sweep_txs")
return set([r[0] for r in c.fetchall()])
@sql
def add_sweep_tx(self, funding_outpoint, prevout, tx):
n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prevout=prevout, tx=bfh(tx)))
self.DBSession.commit()
c = self.conn.cursor()
c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
n = int(c.fetchone()[0])
c.execute("""INSERT INTO sweep_txs (funding_outpoint, "index", prevout, tx) VALUES (?,?,?,?)""", (funding_outpoint, n, prevout, bfh(str(tx))))
self.conn.commit()
@sql
def get_num_tx(self, funding_outpoint):
return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count())
c = self.conn.cursor()
c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
return int(c.fetchone()[0])
@sql
def remove_sweep_tx(self, funding_outpoint):
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
for x in r:
self.DBSession.delete(x)
self.DBSession.commit()
c = self.conn.cursor()
c.execute("DELETE FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
self.conn.commit()
@sql
def add_channel(self, outpoint, address):
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
self.DBSession.commit()
c = self.conn.cursor()
c.execute("INSERT INTO channel_info (address, outpoint) VALUES (?,?)", (address, outpoint))
self.conn.commit()
@sql
def remove_channel(self, outpoint):
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
self.DBSession.delete(v)
self.DBSession.commit()
c = self.conn.cursor()
c.execute("DELETE FROM channel_info WHERE outpoint=?", (outpoint,))
self.conn.commit()
@sql
def has_channel(self, outpoint):
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none())
c = self.conn.cursor()
c.execute("SELECT * FROM channel_info WHERE outpoint=?", (outpoint,))
r = c.fetchone()
return r is not None
@sql
def get_address(self, outpoint):
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
return str(r.address) if r else None
c = self.conn.cursor()
c.execute("SELECT address FROM channel_info WHERE outpoint=?", (outpoint,))
r = c.fetchone()
return r[0] if r else None
@sql
def list_channel_info(self):
return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
c = self.conn.cursor()
c.execute("SELECT address, outpoint FROM channel_info")
return [(r[0], r[1]) for r in c.fetchall()]
class LNWatcher(AddressSynchronizer):

25
electrum/sql_db.py

@ -3,18 +3,11 @@ import concurrent
import queue
import threading
import asyncio
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
import sqlite3
from .logging import Logger
# https://stackoverflow.com/questions/26971050/sqlalchemy-sqlite-too-many-sql-variables
SQLITE_LIMIT_VARIABLE_NUMBER = 999
def sql(func):
"""wrapper for sql methods"""
def wrapper(self, *args, **kwargs):
@ -26,9 +19,8 @@ def sql(func):
class SqlDB(Logger):
def __init__(self, network, path, base, commit_interval=None):
def __init__(self, network, path, commit_interval=None):
Logger.__init__(self)
self.base = base
self.network = network
self.path = path
self.commit_interval = commit_interval
@ -37,13 +29,10 @@ class SqlDB(Logger):
self.sql_thread.start()
def run_sql(self):
#return
self.logger.info("SQL thread started")
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
if not os.path.exists(self.path):
self.base.metadata.create_all(engine)
self.DBSession = DBSession()
self.conn = sqlite3.connect(self.path)
self.logger.info("Creating database")
self.create_database()
i = 0
while self.network.asyncio_loop.is_running():
try:
@ -62,7 +51,7 @@ class SqlDB(Logger):
if self.commit_interval:
i = (i + 1) % self.commit_interval
if i == 0:
self.DBSession.commit()
self.conn.commit()
# write
self.DBSession.commit()
self.conn.commit()
self.logger.info("SQL thread terminated")

Loading…
Cancel
Save