Browse Source

get rid of sql_alchemy

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
238f3c949c
  1. 1
      contrib/requirements/requirements.txt
  2. 174
      electrum/channel_db.py
  3. 100
      electrum/lnwatcher.py
  4. 25
      electrum/sql_db.py

1
contrib/requirements/requirements.txt

@ -11,4 +11,3 @@ aiohttp_socks
certifi certifi
bitstring bitstring
pycryptodomex>=3.7 pycryptodomex>=3.7
sqlalchemy>=1.3.0b3

174
electrum/channel_db.py

@ -36,10 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii import binascii
import base64 import base64
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql from .sql_db import SqlDB, sql
from . import constants from . import constants
@ -66,7 +62,6 @@ def validate_features(features : int):
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits() raise UnknownEvenFeatureBits()
Base = declarative_base()
FLAG_DISABLE = 1 << 1 FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0 FLAG_DIRECTION = 1 << 0
@ -193,57 +188,45 @@ class Address(NamedTuple):
port: int port: int
last_connected_date: int last_connected_date: int
create_channel_info = """
class ChannelInfoBase(Base): CREATE TABLE IF NOT EXISTS channel_info (
__tablename__ = 'channel_info' short_channel_id VARCHAR(64),
short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') node1_id VARCHAR(66),
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) node2_id VARCHAR(66),
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) capacity_sat INTEGER,
capacity_sat = Column(Integer) PRIMARY KEY(short_channel_id)
def to_nametuple(self): )"""
return ChannelInfo(
short_channel_id=self.short_channel_id, create_policy = """
node1_id=self.node1_id, CREATE TABLE IF NOT EXISTS policy (
node2_id=self.node2_id, key VARCHAR(66),
capacity_sat=self.capacity_sat cltv_expiry_delta INTEGER NOT NULL,
) htlc_minimum_msat INTEGER NOT NULL,
htlc_maximum_msat INTEGER,
class PolicyBase(Base): fee_base_msat INTEGER NOT NULL,
__tablename__ = 'policy' fee_proportional_millionths INTEGER NOT NULL,
key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') channel_flags INTEGER NOT NULL,
cltv_expiry_delta = Column(Integer, nullable=False) timestamp INTEGER NOT NULL,
htlc_minimum_msat = Column(Integer, nullable=False) PRIMARY KEY(key)
htlc_maximum_msat = Column(Integer) )"""
fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False) create_address = """
channel_flags = Column(Integer, nullable=False) CREATE TABLE IF NOT EXISTS address (
timestamp = Column(Integer, nullable=False) node_id VARCHAR(66),
host STRING(256),
def to_nametuple(self): port INTEGER NOT NULL,
return Policy( timestamp INTEGER,
key=self.key, PRIMARY KEY(node_id, host, port)
cltv_expiry_delta=self.cltv_expiry_delta, )"""
htlc_minimum_msat=self.htlc_minimum_msat,
htlc_maximum_msat=self.htlc_maximum_msat, create_node_info = """
fee_base_msat= self.fee_base_msat, CREATE TABLE IF NOT EXISTS node_info (
fee_proportional_millionths = self.fee_proportional_millionths, node_id VARCHAR(66),
channel_flags=self.channel_flags, features INTEGER NOT NULL,
timestamp=self.timestamp timestamp INTEGER NOT NULL,
) alias STRING(64),
PRIMARY KEY(node_id)
class NodeInfoBase(Base): )"""
__tablename__ = 'node_info'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
features = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
alias = Column(String(64), nullable=False)
class AddressBase(Base):
__tablename__ = 'address'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
host = Column(String(256))
port = Column(Integer)
last_connected_date = Column(Integer(), nullable=True)
class ChannelDB(SqlDB): class ChannelDB(SqlDB):
@ -252,7 +235,7 @@ class ChannelDB(SqlDB):
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'channel_db') path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, Base, commit_interval=100) super().__init__(network, path, commit_interval=100)
self.num_nodes = 0 self.num_nodes = 0
self.num_channels = 0 self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
@ -276,16 +259,7 @@ class ChannelDB(SqlDB):
now = int(time.time()) now = int(time.time())
node_id = peer.pubkey node_id = peer.pubkey
self._addresses[node_id].add((peer.host, peer.port, now)) self._addresses[node_id].add((peer.host, peer.port, now))
self.save_address(node_id, peer, now) self.save_node_address(node_id, peer, now)
@sql
def save_address(self, node_id, peer, now):
addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
if addr:
addr.last_connected_date = now
else:
addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
self.DBSession.add(addr)
def get_200_randomly_sorted_nodes_not_in(self, node_ids): def get_200_randomly_sorted_nodes_not_in(self, node_ids):
unshuffled = set(self._nodes.keys()) - node_ids unshuffled = set(self._nodes.keys()) - node_ids
@ -394,17 +368,47 @@ class ChannelDB(SqlDB):
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False) orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
assert len(good) == 1 assert len(good) == 1
def create_database(self):
c = self.conn.cursor()
c.execute(create_node_info)
c.execute(create_address)
c.execute(create_policy)
c.execute(create_channel_info)
self.conn.commit()
@sql @sql
def save_policy(self, policy): def save_policy(self, policy):
self.DBSession.execute(PolicyBase.__table__.insert().values(policy)) c = self.conn.cursor()
c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy))
@sql @sql
def delete_policy(self, short_channel_id, node_id): def delete_policy(self, short_channel_id, node_id):
self.DBSession.execute(PolicyBase.__table__.delete().values(policy)) c = self.conn.cursor()
c.execute("""DELETE FROM policy WHERE key=?""", (key,))
@sql @sql
def save_channel(self, channel_info): def save_channel(self, channel_info):
self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info)) c = self.conn.cursor()
c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))
@sql
def save_node(self, node_info):
c = self.conn.cursor()
c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))
@sql
def save_node_address(self, node_id, peer, now):
c = self.conn.cursor()
c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))
@sql
def save_node_addresses(self, node_id, node_addresses):
c = self.conn.cursor()
for addr in node_addresses:
c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
r = c.fetchall()
if r == []:
c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))
def verify_channel_update(self, payload): def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id'] short_channel_id = payload['short_channel_id']
@ -418,7 +422,6 @@ class ChannelDB(SqlDB):
msg_payloads = [msg_payloads] msg_payloads = [msg_payloads]
old_addr = None old_addr = None
new_nodes = {} new_nodes = {}
new_addresses = {}
for msg_payload in msg_payloads: for msg_payload in msg_payloads:
try: try:
node_info, node_addresses = NodeInfo.from_msg(msg_payload) node_info, node_addresses = NodeInfo.from_msg(msg_payload)
@ -445,17 +448,6 @@ class ChannelDB(SqlDB):
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts() self.update_counts()
@sql
def save_node_addresses(self, node_if, node_addresses):
for new_addr in node_addresses:
old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
if not old_addr:
self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
@sql
def save_node(self, node_info):
self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
def get_routing_policy_for_channel(self, start_node_id: bytes, def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[bytes]: short_channel_id: bytes) -> Optional[bytes]:
if not start_node_id or not short_channel_id: return None if not start_node_id or not short_channel_id: return None
@ -506,12 +498,18 @@ class ChannelDB(SqlDB):
@sql @sql
@profiler @profiler
def load_data(self): def load_data(self):
for x in self.DBSession.query(AddressBase).all(): c = self.conn.cursor()
self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0))) c.execute("""SELECT * FROM address""")
for x in self.DBSession.query(ChannelInfoBase).all(): for x in c:
self._channels[x.short_channel_id] = x.to_nametuple() node_id, host, port, timestamp = x
for x in self.DBSession.query(PolicyBase).filter_by().all(): self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
p = x.to_nametuple() c.execute("""SELECT * FROM channel_info""")
for x in c:
ci = ChannelInfo(*x)
self._channels[ci.short_channel_id] = ci
c.execute("""SELECT * FROM policy""")
for x in c:
p = Policy(*x)
self._policies[(p.start_node, p.short_channel_id)] = p self._policies[(p.start_node, p.short_channel_id)] = p
for channel_info in self._channels.values(): for channel_info in self._channels.values():
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)

100
electrum/lnwatcher.py

@ -13,12 +13,7 @@ from enum import IntEnum, auto
from typing import NamedTuple, Dict from typing import NamedTuple, Dict
import jsonrpclib import jsonrpclib
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql from .sql_db import SqlDB, sql
from .util import bh2u, bfh, log_exceptions, ignore_exceptions from .util import bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet from . import wallet
from .storage import WalletStorage from .storage import WalletStorage
@ -42,80 +37,105 @@ class TxMinedDepth(IntEnum):
FREE = auto() FREE = auto()
Base = declarative_base() create_sweep_txs="""
CREATE TABLE IF NOT EXISTS sweep_txs (
class SweepTx(Base): funding_outpoint VARCHAR(34) NOT NULL,
__tablename__ = 'sweep_txs' "index" INTEGER NOT NULL,
funding_outpoint = Column(String(34), primary_key=True) prevout VARCHAR(34),
index = Column(Integer(), primary_key=True) tx VARCHAR,
prevout = Column(String(34)) PRIMARY KEY(funding_outpoint, "index")
tx = Column(String()) )"""
class ChannelInfo(Base):
__tablename__ = 'channel_info'
outpoint = Column(String(34), primary_key=True)
address = Column(String(32))
create_channel_info="""
CREATE TABLE IF NOT EXISTS channel_info (
outpoint VARCHAR(34) NOT NULL,
address VARCHAR(32),
PRIMARY KEY(outpoint)
)"""
class SweepStore(SqlDB): class SweepStore(SqlDB):
def __init__(self, path, network): def __init__(self, path, network):
super().__init__(network, path, Base) super().__init__(network, path)
def create_database(self):
c = self.conn.cursor()
c.execute(create_channel_info)
c.execute(create_sweep_txs)
self.conn.commit()
@sql @sql
def get_sweep_tx(self, funding_outpoint, prevout): def get_sweep_tx(self, funding_outpoint, prevout):
return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prevout==prevout).all()] c = self.conn.cursor()
c.execute("SELECT tx FROM sweep_txs WHERE funding_outpoint=? AND prevout=?", (funding_outpoint, prevout))
return [Transaction(bh2u(r[0])) for r in c.fetchall()]
@sql @sql
def get_tx_by_index(self, funding_outpoint, index): def get_tx_by_index(self, funding_outpoint, index):
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none() c = self.conn.cursor()
return str(r.prevout), bh2u(r.tx) c.execute("""SELECT prevout, tx FROM sweep_txs WHERE funding_outpoint=? AND "index"=?""", (funding_outpoint, index))
r = c.fetchone()[0]
return str(r[0]), bh2u(r[1])
@sql @sql
def list_sweep_tx(self): def list_sweep_tx(self):
return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all()) c = self.conn.cursor()
c.execute("SELECT funding_outpoint FROM sweep_txs")
return set([r[0] for r in c.fetchall()])
@sql @sql
def add_sweep_tx(self, funding_outpoint, prevout, tx): def add_sweep_tx(self, funding_outpoint, prevout, tx):
n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() c = self.conn.cursor()
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prevout=prevout, tx=bfh(tx))) c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
self.DBSession.commit() n = int(c.fetchone()[0])
c.execute("""INSERT INTO sweep_txs (funding_outpoint, "index", prevout, tx) VALUES (?,?,?,?)""", (funding_outpoint, n, prevout, bfh(str(tx))))
self.conn.commit()
@sql @sql
def get_num_tx(self, funding_outpoint): def get_num_tx(self, funding_outpoint):
return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()) c = self.conn.cursor()
c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
return int(c.fetchone()[0])
@sql @sql
def remove_sweep_tx(self, funding_outpoint): def remove_sweep_tx(self, funding_outpoint):
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all() c = self.conn.cursor()
for x in r: c.execute("DELETE FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
self.DBSession.delete(x) self.conn.commit()
self.DBSession.commit()
@sql @sql
def add_channel(self, outpoint, address): def add_channel(self, outpoint, address):
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint)) c = self.conn.cursor()
self.DBSession.commit() c.execute("INSERT INTO channel_info (address, outpoint) VALUES (?,?)", (address, outpoint))
self.conn.commit()
@sql @sql
def remove_channel(self, outpoint): def remove_channel(self, outpoint):
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() c = self.conn.cursor()
self.DBSession.delete(v) c.execute("DELETE FROM channel_info WHERE outpoint=?", (outpoint,))
self.DBSession.commit() self.conn.commit()
@sql @sql
def has_channel(self, outpoint): def has_channel(self, outpoint):
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()) c = self.conn.cursor()
c.execute("SELECT * FROM channel_info WHERE outpoint=?", (outpoint,))
r = c.fetchone()
return r is not None
@sql @sql
def get_address(self, outpoint): def get_address(self, outpoint):
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() c = self.conn.cursor()
return str(r.address) if r else None c.execute("SELECT address FROM channel_info WHERE outpoint=?", (outpoint,))
r = c.fetchone()
return r[0] if r else None
@sql @sql
def list_channel_info(self): def list_channel_info(self):
return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()] c = self.conn.cursor()
c.execute("SELECT address, outpoint FROM channel_info")
return [(r[0], r[1]) for r in c.fetchall()]
class LNWatcher(AddressSynchronizer): class LNWatcher(AddressSynchronizer):

25
electrum/sql_db.py

@ -3,18 +3,11 @@ import concurrent
import queue import queue
import threading import threading
import asyncio import asyncio
import sqlite3
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from .logging import Logger from .logging import Logger
# https://stackoverflow.com/questions/26971050/sqlalchemy-sqlite-too-many-sql-variables
SQLITE_LIMIT_VARIABLE_NUMBER = 999
def sql(func): def sql(func):
"""wrapper for sql methods""" """wrapper for sql methods"""
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
@ -26,9 +19,8 @@ def sql(func):
class SqlDB(Logger): class SqlDB(Logger):
def __init__(self, network, path, base, commit_interval=None): def __init__(self, network, path, commit_interval=None):
Logger.__init__(self) Logger.__init__(self)
self.base = base
self.network = network self.network = network
self.path = path self.path = path
self.commit_interval = commit_interval self.commit_interval = commit_interval
@ -37,13 +29,10 @@ class SqlDB(Logger):
self.sql_thread.start() self.sql_thread.start()
def run_sql(self): def run_sql(self):
#return
self.logger.info("SQL thread started") self.logger.info("SQL thread started")
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) self.conn = sqlite3.connect(self.path)
DBSession = sessionmaker(bind=engine, autoflush=False) self.logger.info("Creating database")
if not os.path.exists(self.path): self.create_database()
self.base.metadata.create_all(engine)
self.DBSession = DBSession()
i = 0 i = 0
while self.network.asyncio_loop.is_running(): while self.network.asyncio_loop.is_running():
try: try:
@ -62,7 +51,7 @@ class SqlDB(Logger):
if self.commit_interval: if self.commit_interval:
i = (i + 1) % self.commit_interval i = (i + 1) % self.commit_interval
if i == 0: if i == 0:
self.DBSession.commit() self.conn.commit()
# write # write
self.DBSession.commit() self.conn.commit()
self.logger.info("SQL thread terminated") self.logger.info("SQL thread terminated")

Loading…
Cancel
Save