|
|
@ -36,10 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK |
|
|
|
import binascii |
|
|
|
import base64 |
|
|
|
|
|
|
|
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean |
|
|
|
from sqlalchemy.orm.query import Query |
|
|
|
from sqlalchemy.ext.declarative import declarative_base |
|
|
|
from sqlalchemy.sql import not_, or_ |
|
|
|
|
|
|
|
from .sql_db import SqlDB, sql |
|
|
|
from . import constants |
|
|
@ -66,7 +62,6 @@ def validate_features(features : int): |
|
|
|
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: |
|
|
|
raise UnknownEvenFeatureBits() |
|
|
|
|
|
|
|
Base = declarative_base() |
|
|
|
|
|
|
|
FLAG_DISABLE = 1 << 1 |
|
|
|
FLAG_DIRECTION = 1 << 0 |
|
|
@ -193,57 +188,45 @@ class Address(NamedTuple): |
|
|
|
port: int |
|
|
|
last_connected_date: int |
|
|
|
|
|
|
|
|
|
|
|
class ChannelInfoBase(Base): |
|
|
|
__tablename__ = 'channel_info' |
|
|
|
short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') |
|
|
|
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) |
|
|
|
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) |
|
|
|
capacity_sat = Column(Integer) |
|
|
|
def to_nametuple(self): |
|
|
|
return ChannelInfo( |
|
|
|
short_channel_id=self.short_channel_id, |
|
|
|
node1_id=self.node1_id, |
|
|
|
node2_id=self.node2_id, |
|
|
|
capacity_sat=self.capacity_sat |
|
|
|
) |
|
|
|
|
|
|
|
class PolicyBase(Base): |
|
|
|
__tablename__ = 'policy' |
|
|
|
key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') |
|
|
|
cltv_expiry_delta = Column(Integer, nullable=False) |
|
|
|
htlc_minimum_msat = Column(Integer, nullable=False) |
|
|
|
htlc_maximum_msat = Column(Integer) |
|
|
|
fee_base_msat = Column(Integer, nullable=False) |
|
|
|
fee_proportional_millionths = Column(Integer, nullable=False) |
|
|
|
channel_flags = Column(Integer, nullable=False) |
|
|
|
timestamp = Column(Integer, nullable=False) |
|
|
|
|
|
|
|
def to_nametuple(self): |
|
|
|
return Policy( |
|
|
|
key=self.key, |
|
|
|
cltv_expiry_delta=self.cltv_expiry_delta, |
|
|
|
htlc_minimum_msat=self.htlc_minimum_msat, |
|
|
|
htlc_maximum_msat=self.htlc_maximum_msat, |
|
|
|
fee_base_msat= self.fee_base_msat, |
|
|
|
fee_proportional_millionths = self.fee_proportional_millionths, |
|
|
|
channel_flags=self.channel_flags, |
|
|
|
timestamp=self.timestamp |
|
|
|
) |
|
|
|
|
|
|
|
class NodeInfoBase(Base): |
|
|
|
__tablename__ = 'node_info' |
|
|
|
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') |
|
|
|
features = Column(Integer, nullable=False) |
|
|
|
timestamp = Column(Integer, nullable=False) |
|
|
|
alias = Column(String(64), nullable=False) |
|
|
|
|
|
|
|
class AddressBase(Base): |
|
|
|
__tablename__ = 'address' |
|
|
|
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') |
|
|
|
host = Column(String(256)) |
|
|
|
port = Column(Integer) |
|
|
|
last_connected_date = Column(Integer(), nullable=True) |
|
|
|
create_channel_info = """ |
|
|
|
CREATE TABLE IF NOT EXISTS channel_info ( |
|
|
|
short_channel_id VARCHAR(64), |
|
|
|
node1_id VARCHAR(66), |
|
|
|
node2_id VARCHAR(66), |
|
|
|
capacity_sat INTEGER, |
|
|
|
PRIMARY KEY(short_channel_id) |
|
|
|
)""" |
|
|
|
|
|
|
|
create_policy = """ |
|
|
|
CREATE TABLE IF NOT EXISTS policy ( |
|
|
|
key VARCHAR(66), |
|
|
|
cltv_expiry_delta INTEGER NOT NULL, |
|
|
|
htlc_minimum_msat INTEGER NOT NULL, |
|
|
|
htlc_maximum_msat INTEGER, |
|
|
|
fee_base_msat INTEGER NOT NULL, |
|
|
|
fee_proportional_millionths INTEGER NOT NULL, |
|
|
|
channel_flags INTEGER NOT NULL, |
|
|
|
timestamp INTEGER NOT NULL, |
|
|
|
PRIMARY KEY(key) |
|
|
|
)""" |
|
|
|
|
|
|
|
create_address = """ |
|
|
|
CREATE TABLE IF NOT EXISTS address ( |
|
|
|
node_id VARCHAR(66), |
|
|
|
host STRING(256), |
|
|
|
port INTEGER NOT NULL, |
|
|
|
timestamp INTEGER, |
|
|
|
PRIMARY KEY(node_id, host, port) |
|
|
|
)""" |
|
|
|
|
|
|
|
create_node_info = """ |
|
|
|
CREATE TABLE IF NOT EXISTS node_info ( |
|
|
|
node_id VARCHAR(66), |
|
|
|
features INTEGER NOT NULL, |
|
|
|
timestamp INTEGER NOT NULL, |
|
|
|
alias STRING(64), |
|
|
|
PRIMARY KEY(node_id) |
|
|
|
)""" |
|
|
|
|
|
|
|
|
|
|
|
class ChannelDB(SqlDB): |
|
|
@ -252,7 +235,7 @@ class ChannelDB(SqlDB): |
|
|
|
|
|
|
|
def __init__(self, network: 'Network'): |
|
|
|
path = os.path.join(get_headers_dir(network.config), 'channel_db') |
|
|
|
super().__init__(network, path, Base, commit_interval=100) |
|
|
|
super().__init__(network, path, commit_interval=100) |
|
|
|
self.num_nodes = 0 |
|
|
|
self.num_channels = 0 |
|
|
|
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] |
|
|
@ -276,16 +259,7 @@ class ChannelDB(SqlDB): |
|
|
|
now = int(time.time()) |
|
|
|
node_id = peer.pubkey |
|
|
|
self._addresses[node_id].add((peer.host, peer.port, now)) |
|
|
|
self.save_address(node_id, peer, now) |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_address(self, node_id, peer, now): |
|
|
|
addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() |
|
|
|
if addr: |
|
|
|
addr.last_connected_date = now |
|
|
|
else: |
|
|
|
addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) |
|
|
|
self.DBSession.add(addr) |
|
|
|
self.save_node_address(node_id, peer, now) |
|
|
|
|
|
|
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids): |
|
|
|
unshuffled = set(self._nodes.keys()) - node_ids |
|
|
@ -394,17 +368,47 @@ class ChannelDB(SqlDB): |
|
|
|
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False) |
|
|
|
assert len(good) == 1 |
|
|
|
|
|
|
|
def create_database(self): |
|
|
|
c = self.conn.cursor() |
|
|
|
c.execute(create_node_info) |
|
|
|
c.execute(create_address) |
|
|
|
c.execute(create_policy) |
|
|
|
c.execute(create_channel_info) |
|
|
|
self.conn.commit() |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_policy(self, policy): |
|
|
|
self.DBSession.execute(PolicyBase.__table__.insert().values(policy)) |
|
|
|
c = self.conn.cursor() |
|
|
|
c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy)) |
|
|
|
|
|
|
|
@sql |
|
|
|
def delete_policy(self, short_channel_id, node_id): |
|
|
|
self.DBSession.execute(PolicyBase.__table__.delete().values(policy)) |
|
|
|
c = self.conn.cursor() |
|
|
|
c.execute("""DELETE FROM policy WHERE key=?""", (key,)) |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_channel(self, channel_info): |
|
|
|
self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info)) |
|
|
|
c = self.conn.cursor() |
|
|
|
c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info)) |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_node(self, node_info): |
|
|
|
c = self.conn.cursor() |
|
|
|
c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info)) |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_node_address(self, node_id, peer, now): |
|
|
|
c = self.conn.cursor() |
|
|
|
c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now)) |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_node_addresses(self, node_id, node_addresses): |
|
|
|
c = self.conn.cursor() |
|
|
|
for addr in node_addresses: |
|
|
|
c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port)) |
|
|
|
r = c.fetchall() |
|
|
|
if r == []: |
|
|
|
c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0)) |
|
|
|
|
|
|
|
def verify_channel_update(self, payload): |
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
@ -418,7 +422,6 @@ class ChannelDB(SqlDB): |
|
|
|
msg_payloads = [msg_payloads] |
|
|
|
old_addr = None |
|
|
|
new_nodes = {} |
|
|
|
new_addresses = {} |
|
|
|
for msg_payload in msg_payloads: |
|
|
|
try: |
|
|
|
node_info, node_addresses = NodeInfo.from_msg(msg_payload) |
|
|
@ -445,17 +448,6 @@ class ChannelDB(SqlDB): |
|
|
|
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) |
|
|
|
self.update_counts() |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_node_addresses(self, node_if, node_addresses): |
|
|
|
for new_addr in node_addresses: |
|
|
|
old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() |
|
|
|
if not old_addr: |
|
|
|
self.DBSession.execute(AddressBase.__table__.insert().values(new_addr)) |
|
|
|
|
|
|
|
@sql |
|
|
|
def save_node(self, node_info): |
|
|
|
self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info)) |
|
|
|
|
|
|
|
def get_routing_policy_for_channel(self, start_node_id: bytes, |
|
|
|
short_channel_id: bytes) -> Optional[bytes]: |
|
|
|
if not start_node_id or not short_channel_id: return None |
|
|
@ -506,12 +498,18 @@ class ChannelDB(SqlDB): |
|
|
|
@sql |
|
|
|
@profiler |
|
|
|
def load_data(self): |
|
|
|
for x in self.DBSession.query(AddressBase).all(): |
|
|
|
self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0))) |
|
|
|
for x in self.DBSession.query(ChannelInfoBase).all(): |
|
|
|
self._channels[x.short_channel_id] = x.to_nametuple() |
|
|
|
for x in self.DBSession.query(PolicyBase).filter_by().all(): |
|
|
|
p = x.to_nametuple() |
|
|
|
c = self.conn.cursor() |
|
|
|
c.execute("""SELECT * FROM address""") |
|
|
|
for x in c: |
|
|
|
node_id, host, port, timestamp = x |
|
|
|
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0))) |
|
|
|
c.execute("""SELECT * FROM channel_info""") |
|
|
|
for x in c: |
|
|
|
ci = ChannelInfo(*x) |
|
|
|
self._channels[ci.short_channel_id] = ci |
|
|
|
c.execute("""SELECT * FROM policy""") |
|
|
|
for x in c: |
|
|
|
p = Policy(*x) |
|
|
|
self._policies[(p.start_node, p.short_channel_id)] = p |
|
|
|
for channel_info in self._channels.values(): |
|
|
|
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) |
|
|
|