You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

568 lines
22 KiB

# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from datetime import datetime
import time
import random
import queue
import os
import json
import threading
import concurrent
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
from .sql_db import SqlDB, sql
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
from .logging import Logger
from .storage import JsonDB
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .crypto import sha256d
from . import ecc
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
NotFoundChanAnnouncementForUpdate)
from .lnverifier import verify_sig_for_channel_update
from .lnmsg import encode_msg
if TYPE_CHECKING:
from .lnchannel import Channel
from .network import Network
class UnknownEvenFeatureBits(Exception): pass
def validate_features(features : int):
enabled_features = list_enabled_bits(features)
for fbit in enabled_features:
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
class ChannelInfo(NamedTuple):
short_channel_id: bytes
node1_id: bytes
node2_id: bytes
capacity_sat: int
@staticmethod
def from_msg(payload):
features = int.from_bytes(payload['features'], 'big')
validate_features(features)
channel_id = payload['short_channel_id']
node_id_1 = payload['node_id_1']
node_id_2 = payload['node_id_2']
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
capacity_sat = None
return ChannelInfo(
short_channel_id = channel_id,
node1_id = node_id_1,
node2_id = node_id_2,
capacity_sat = capacity_sat)
class Policy(NamedTuple):
key: bytes
cltv_expiry_delta: int
htlc_minimum_msat: int
htlc_maximum_msat: int
fee_base_msat: int
fee_proportional_millionths: int
channel_flags: int
timestamp: int
@staticmethod
def from_msg(payload):
return Policy(
key = payload['short_channel_id'] + payload['start_node'],
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big"),
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None,
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big"),
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big"),
channel_flags = int.from_bytes(payload['channel_flags'], "big"),
timestamp = int.from_bytes(payload['timestamp'], "big")
)
def is_disabled(self):
return self.channel_flags & FLAG_DISABLE
@property
def short_channel_id(self):
return self.key[0:8]
@property
def start_node(self):
return self.key[8:]
class NodeInfo(NamedTuple):
node_id: bytes
features: int
timestamp: int
alias: str
@staticmethod
def from_msg(payload):
node_id = payload['node_id']
features = int.from_bytes(payload['features'], "big")
validate_features(features)
addresses = NodeInfo.parse_addresses_field(payload['addresses'])
alias = payload['alias'].rstrip(b'\x00')
timestamp = int.from_bytes(payload['timestamp'], "big")
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses]
@staticmethod
def parse_addresses_field(addresses_field):
buf = addresses_field
def read(n):
nonlocal buf
data, buf = buf[0:n], buf[n:]
return data
addresses = []
while buf:
atype = ord(read(1))
if atype == 0:
pass
elif atype == 1: # IPv4
ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
port = int.from_bytes(read(2), 'big')
if is_ip_address(ipv4_addr) and port != 0:
addresses.append((ipv4_addr, port))
elif atype == 2: # IPv6
ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
ipv6_addr = ipv6_addr.decode('ascii')
port = int.from_bytes(read(2), 'big')
if is_ip_address(ipv6_addr) and port != 0:
addresses.append((ipv6_addr, port))
elif atype == 3: # onion v2
host = base64.b32encode(read(10)) + b'.onion'
host = host.decode('ascii').lower()
port = int.from_bytes(read(2), 'big')
addresses.append((host, port))
elif atype == 4: # onion v3
host = base64.b32encode(read(35)) + b'.onion'
host = host.decode('ascii').lower()
port = int.from_bytes(read(2), 'big')
addresses.append((host, port))
else:
# unknown address type
# we don't know how long it is -> have to escape
# if there are other addresses we could have parsed later, they are lost.
break
return addresses
class Address(NamedTuple):
node_id: bytes
host: str
port: int
last_connected_date: int
class CategorizedChannelUpdates(NamedTuple):
orphaned: List # no channel announcement for channel update
expired: List # update older than two weeks
deprecated: List # update older than database entry
good: List # good updates
to_delete: List # database entries to delete
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id VARCHAR(64),
node1_id VARCHAR(66),
node2_id VARCHAR(66),
capacity_sat INTEGER,
PRIMARY KEY(short_channel_id)
)"""
create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key VARCHAR(66),
cltv_expiry_delta INTEGER NOT NULL,
htlc_minimum_msat INTEGER NOT NULL,
htlc_maximum_msat INTEGER,
fee_base_msat INTEGER NOT NULL,
fee_proportional_millionths INTEGER NOT NULL,
channel_flags INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
PRIMARY KEY(key)
)"""
create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id VARCHAR(66),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""
create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id VARCHAR(66),
features INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
alias STRING(64),
PRIMARY KEY(node_id)
)"""
class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, commit_interval=100)
self.num_nodes = 0
self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
# initialized in load_data
self._channels = {} # type: Dict[bytes, ChannelInfo]
self._policies = {}
self._nodes = {}
self._addresses = defaultdict(set)
self._channels_for_node = defaultdict(set)
def update_counts(self):
self.num_channels = len(self._channels)
self.num_policies = len(self._policies)
self.num_nodes = len(self._nodes)
def get_channel_ids(self):
return set(self._channels.keys())
def add_recent_peer(self, peer: LNPeerAddr):
now = int(time.time())
node_id = peer.pubkey
self._addresses[node_id].add((peer.host, peer.port, now))
self.save_node_address(node_id, peer, now)
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
unshuffled = set(self._nodes.keys()) - node_ids
return random.sample(unshuffled, min(200, len(unshuffled)))
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
r = self._addresses.get(node_id)
if not r:
return None
addr = sorted(list(r), key=lambda x: x[2])[0]
host, port, timestamp = addr
return LNPeerAddr(host, port, node_id)
def get_recent_peers(self):
r = [self.get_last_good_address(x) for x in self._addresses.keys()]
r = r[-self.NUM_MAX_RECENT_PEERS:]
return r
def add_channel_announcement(self, msg_payloads, trusted=True):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
added = 0
for msg in msg_payloads:
short_channel_id = msg['short_channel_id']
if short_channel_id in self._channels:
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
continue
try:
channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits:
self.logger.info("unknown feature bits")
continue
added += 1
self._channels[short_channel_id] = channel_info
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
self.save_channel(channel_info)
if not trusted:
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, msg)
self.update_counts()
self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))
def print_change(self, old_policy, new_policy):
# print what changed between policies
if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
if old_policy.fee_base_msat != new_policy.fee_base_msat:
self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
if old_policy.channel_flags != new_policy.channel_flags:
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
def add_channel_updates(self, payloads, max_age=None, verify=True) -> CategorizedChannelUpdates:
orphaned = []
expired = []
deprecated = []
good = []
to_delete = []
# filter orphaned and expired first
known = []
now = int(time.time())
for payload in payloads:
short_channel_id = payload['short_channel_id']
timestamp = int.from_bytes(payload['timestamp'], "big")
if max_age and now - timestamp > max_age:
expired.append(payload)
continue
channel_info = self._channels.get(short_channel_id)
if not channel_info:
orphaned.append(payload)
continue
flags = int.from_bytes(payload['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
payload['start_node'] = start_node
known.append(payload)
# compare updates to existing database entries
for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node']
short_channel_id = payload['short_channel_id']
key = (start_node, short_channel_id)
old_policy = self._policies.get(key)
if old_policy and timestamp <= old_policy.timestamp:
deprecated.append(payload)
continue
good.append(payload)
if verify:
self.verify_channel_update(payload)
policy = Policy.from_msg(payload)
self._policies[key] = policy
self.save_policy(policy)
#
self.update_counts()
return CategorizedChannelUpdates(
orphaned=orphaned,
expired=expired,
deprecated=deprecated,
good=good,
to_delete=to_delete,
)
def add_channel_update(self, payload):
categorized_chan_upds = self.add_channel_updates([payload], verify=False)
assert len(categorized_chan_upds.good) == 1
def create_database(self):
c = self.conn.cursor()
c.execute(create_node_info)
c.execute(create_address)
c.execute(create_policy)
c.execute(create_channel_info)
self.conn.commit()
@sql
def save_policy(self, policy):
c = self.conn.cursor()
c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy))
@sql
def delete_policy(self, node_id, short_channel_id):
key = short_channel_id + node_id
c = self.conn.cursor()
c.execute("""DELETE FROM policy WHERE key=?""", (key,))
@sql
def save_channel(self, channel_info):
c = self.conn.cursor()
c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))
@sql
def delete_channel(self, short_channel_id):
c = self.conn.cursor()
c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))
@sql
def save_node(self, node_info):
c = self.conn.cursor()
c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))
@sql
def save_node_address(self, node_id, peer, now):
c = self.conn.cursor()
c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))
@sql
def save_node_addresses(self, node_id, node_addresses):
c = self.conn.cursor()
for addr in node_addresses:
c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
r = c.fetchall()
if r == []:
c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))
def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id']
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['start_node']):
raise BaseException('verify error')
def add_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
old_addr = None
new_nodes = {}
for msg_payload in msg_payloads:
try:
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
except UnknownEvenFeatureBits:
continue
node_id = node_info.node_id
# Ignore node if it has no associated channel (DoS protection)
if node_id not in self._channels_for_node:
#self.logger.info('ignoring orphan node_announcement')
continue
node = self._nodes.get(node_id)
if node and node.timestamp >= node_info.timestamp:
continue
node = new_nodes.get(node_id)
if node and node.timestamp >= node_info.timestamp:
continue
# save
self._nodes[node_id] = node_info
self.save_node(node_info)
for addr in node_addresses:
self._addresses[node_id].add((addr.host, addr.port, 0))
self.save_node_addresses(node_id, node_addresses)
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[bytes]:
if not start_node_id or not short_channel_id: return None
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None:
return self.get_policy_for_node(short_channel_id, start_node_id)
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if not msg:
return None
return Policy.from_msg(msg) # won't actually be written to DB
def get_old_policies(self, delta):
now = int(time.time())
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
def prune_old_policies(self, delta):
l = self.get_old_policies(delta)
for k in l:
self._policies.pop(k)
self.delete_policy(*k)
if l:
self.logger.info(f'Deleting {len(l)} old policies')
def get_orphaned_channels(self):
ids = set(x[1] for x in self._policies.keys())
return list(x for x in self._channels.keys() if x not in ids)
def prune_orphaned_channels(self):
l = self.get_orphaned_channels()
for short_channel_id in l:
self.remove_channel(short_channel_id)
self.delete_channel(short_channel_id)
self.update_counts()
if l:
self.logger.info(f'Deleting {len(l)} orphaned channels')
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore
short_channel_id = msg_payload['short_channel_id']
msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
def remove_channel(self, short_channel_id):
channel_info = self._channels.pop(short_channel_id, None)
if channel_info:
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
def get_node_addresses(self, node_id):
return self._addresses.get(node_id)
@sql
@profiler
def load_data(self):
c = self.conn.cursor()
c.execute("""SELECT * FROM address""")
for x in c:
node_id, host, port, timestamp = x
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
c.execute("""SELECT * FROM channel_info""")
for x in c:
ci = ChannelInfo(*x)
self._channels[ci.short_channel_id] = ci
c.execute("""SELECT * FROM node_info""")
for x in c:
ni = NodeInfo(*x)
self._nodes[ni.node_id] = ni
c.execute("""SELECT * FROM policy""")
for x in c:
p = Policy(*x)
self._policies[(p.start_node, p.short_channel_id)] = p
for channel_info in self._channels.values():
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
self.update_counts()
self.count_incomplete_channels()
def count_incomplete_channels(self):
out = set()
for short_channel_id, ci in self._channels.items():
p1 = self.get_policy_for_node(short_channel_id, ci.node1_id)
p2 = self.get_policy_for_node(short_channel_id, ci.node2_id)
if p1 is None or p2 is not None:
out.add(short_channel_id)
self.logger.info(f'semi-orphaned: {len(out)}')
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
return self._policies.get((node_id, short_channel_id))
def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
return self._channels.get(channel_id)
def get_channels_for_node(self, node_id) -> Set[bytes]:
"""Returns the set of channels that have node_id as one of the endpoints."""
return self._channels_for_node.get(node_id) or set()