You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

700 lines
32 KiB

# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import datetime
import random
import queue
import os
import json
import threading
import concurrent
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .storage import JsonDB
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .crypto import sha256d
from . import ecc
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
NotFoundChanAnnouncementForUpdate)
from .lnmsg import encode_msg
if TYPE_CHECKING:
from .lnchannel import Channel
from .network import Network
class UnknownEvenFeatureBits(Exception):
    """Signals a feature bitfield that sets an even bit outside the known set."""
class NoChannelPolicy(Exception):
    """Raised when no routing policy can be found for a short channel id."""

    def __init__(self, short_channel_id: bytes):
        # The id is rendered as hex so the message is readable in logs.
        scid_hex = bh2u(short_channel_id)
        message = f'cannot find channel policy for short_channel_id: {scid_hex}'
        super().__init__(message)
def validate_features(features: int):
    """Reject feature bitfields that set an unknown even bit.

    Every enabled bit of ``features`` is inspected; an even-numbered bit
    whose mask is not in ``LN_GLOBAL_FEATURES_KNOWN_SET`` raises
    ``UnknownEvenFeatureBits``. Unknown odd bits are accepted.
    """
    for bit_index in list_enabled_bits(features):
        bit_is_known = (1 << bit_index) in LN_GLOBAL_FEATURES_KNOWN_SET
        if bit_is_known:
            continue
        if bit_index % 2 == 0:
            raise UnknownEvenFeatureBits()
# Declarative base shared by every ORM model defined in this module.
Base = declarative_base()
# Bit masks for the `channel_flags` field of a channel_update message.
# FLAG_DISABLE is the bit tested by Policy.is_disabled().
FLAG_DISABLE = 1 << 1
# FLAG_DIRECTION selects which endpoint issued the update:
# 0 -> node1_id, 1 -> node2_id (see ChannelDB._update_channel_info).
FLAG_DIRECTION = 1 << 0
class ChannelInfo(Base):
    """ORM row of table ``channel_info``: one announced channel.

    Ids are stored hex-encoded; ``msg_payload_hex`` keeps the re-encoded
    channel_announcement message.
    """

    __tablename__ = 'channel_info'

    short_channel_id = Column(String(64), primary_key=True)
    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
    capacity_sat = Column(Integer)
    msg_payload_hex = Column(String(1024), nullable=False)
    trusted = Column(Boolean, nullable=False)

    @staticmethod
    def from_msg(channel_announcement_payload):
        """Build an untrusted ChannelInfo from a channel_announcement payload.

        Raises UnknownEvenFeatureBits (via validate_features) for unsupported
        features. The capacity is left unset and ``trusted`` is False.
        """
        payload = channel_announcement_payload
        feature_bits = int.from_bytes(payload['features'], 'big')
        validate_features(feature_bits)
        scid_hex = payload['short_channel_id'].hex()
        first_node_hex = payload['node_id_1'].hex()
        second_node_hex = payload['node_id_2'].hex()
        # the two endpoints must already be in ascending order
        endpoints = [first_node_hex, second_node_hex]
        assert sorted(endpoints) == endpoints
        encoded_hex = encode_msg('channel_announcement', **payload).hex()
        return ChannelInfo(
            short_channel_id=scid_hex,
            node1_id=first_node_hex,
            node2_id=second_node_hex,
            capacity_sat=None,
            msg_payload_hex=encoded_hex,
            trusted=False,
        )

    @property
    def msg_payload(self):
        """The stored announcement message, decoded from hex to bytes."""
        return bytes.fromhex(self.msg_payload_hex)
class Policy(Base):
    """ORM row of table ``policy``: one direction's routing policy of a channel.

    The primary key is (start_node, short_channel_id), both hex strings.
    """

    __tablename__ = 'policy'

    start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
    short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
    cltv_expiry_delta = Column(Integer, nullable=False)
    htlc_minimum_msat = Column(Integer, nullable=False)
    htlc_maximum_msat = Column(Integer)
    fee_base_msat = Column(Integer, nullable=False)
    fee_proportional_millionths = Column(Integer, nullable=False)
    channel_flags = Column(Integer, nullable=False)
    timestamp = Column(DateTime, nullable=False)

    @staticmethod
    def from_msg(channel_update_payload, start_node, short_channel_id):
        """Create a Policy from a channel_update payload of big-endian byte fields.

        ``htlc_maximum_msat`` is optional in the payload and becomes None when
        absent or empty. The timestamp is converted with
        ``datetime.fromtimestamp`` (naive, local time).
        """
        def to_int(raw):
            return int.from_bytes(raw, "big")

        payload = channel_update_payload
        # read every mandatory field first, then the optional one
        raw_cltv = payload['cltv_expiry_delta']
        raw_htlc_min = payload['htlc_minimum_msat']
        raw_fee_base = payload['fee_base_msat']
        raw_fee_prop = payload['fee_proportional_millionths']
        raw_flags = payload['channel_flags']
        raw_timestamp = payload['timestamp']
        raw_htlc_max = payload.get('htlc_maximum_msat')  # optional

        cltv = to_int(raw_cltv)
        htlc_min = to_int(raw_htlc_min)
        if raw_htlc_max:
            htlc_max = to_int(raw_htlc_max)
        else:
            htlc_max = None
        fee_base = to_int(raw_fee_base)
        fee_prop = to_int(raw_fee_prop)
        flags = to_int(raw_flags)
        when = datetime.datetime.fromtimestamp(to_int(raw_timestamp))

        return Policy(
            start_node=start_node,
            short_channel_id=short_channel_id,
            cltv_expiry_delta=cltv,
            htlc_minimum_msat=htlc_min,
            fee_base_msat=fee_base,
            fee_proportional_millionths=fee_prop,
            channel_flags=flags,
            timestamp=when,
            htlc_maximum_msat=htlc_max,
        )

    def is_disabled(self):
        """Return the masked FLAG_DISABLE bit (non-zero means disabled)."""
        disabled_bit = self.channel_flags & FLAG_DISABLE
        return disabled_bit
class NodeInfo(Base):
    """ORM row of table ``node_info``: one announced node.

    ``node_id`` and ``alias`` are stored hex-encoded.
    """

    __tablename__ = 'node_info'

    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
    features = Column(Integer, nullable=False)
    # NOTE(review): declared Integer, but from_msg() assigns a datetime and
    # ChannelDB.on_node_announcement parses the stored value as a
    # "%Y-%m-%d %H:%M:%S" string -- confirm before changing the column type.
    timestamp = Column(Integer, nullable=False)
    alias = Column(String(64), nullable=False)

    @staticmethod
    def from_msg(node_announcement_payload, addresses_already_parsed=False):
        """Build a NodeInfo and its Address rows from a node_announcement payload.

        :param node_announcement_payload: dict with 'node_id', 'features',
            'addresses', 'alias' and 'timestamp' entries.
        :param addresses_already_parsed: when True, 'addresses' is taken to be
            an iterable of (host, port) pairs instead of the raw byte field.
        :return: tuple ``(NodeInfo, [Address, ...])``.
        :raises UnknownEvenFeatureBits: via validate_features.
        """
        payload = node_announcement_payload
        node_id = payload['node_id'].hex()
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        if not addresses_already_parsed:
            addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        else:
            addresses = payload['addresses']
        # trailing zero bytes of the alias are dropped before hex-encoding
        alias = payload['alias'].rstrip(b'\x00').hex()
        timestamp = datetime.datetime.fromtimestamp(int.from_bytes(payload['timestamp'], "big"))
        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
        address_rows = [
            Address(host=host, port=port, node_id=node_id,
                    last_connected_date=datetime.datetime.now())
            for host, port in addresses
        ]
        return node_info, address_rows

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Decode the raw 'addresses' field into a list of (host, port) tuples.

        Each descriptor is a type byte, a host of type-dependent length and
        a 2-byte big-endian port. Type 0 carries no data and is skipped.
        Parsing stops at the first unknown type (its length is unknown) and
        at the first truncated descriptor, so a cut-off field can no longer
        produce a bogus host or a port decoded from missing bytes.
        IPv4/IPv6 entries are kept only if they validate and the port is
        non-zero.
        """
        # address type -> number of host bytes following the type byte
        host_len_by_type = {
            1: 4,    # IPv4
            2: 16,   # IPv6
            3: 10,   # onion v2
            4: 35,   # onion v3
        }
        buf = addresses_field

        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data

        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                continue
            host_len = host_len_by_type.get(atype)
            if host_len is None:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
            if len(buf) < host_len + 2:
                # truncated descriptor: host or port bytes are missing
                break
            raw_host = read(host_len)
            port = int.from_bytes(read(2), 'big')
            if atype == 1:  # IPv4
                ipv4_addr = '.'.join('%d' % octet for octet in raw_host)
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                groups = [raw_host[i:i + 2].hex() for i in range(0, 16, 2)]
                ipv6_addr = ':'.join(groups)
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            else:  # onion v2 / onion v3
                host = base64.b32encode(raw_host) + b'.onion'
                host = host.decode('ascii').lower()
                addresses.append((host, port))
        return addresses
class Address(Base):
    """ORM row of table ``address``: a network address known for a node.

    The primary key is the triple (node_id, host, port), so one node may
    have several address rows.
    """

    __tablename__ = 'address'

    node_id = Column(
        String(66),
        ForeignKey('node_info.node_id'),
        primary_key=True,
    )
    host = Column(String(256), primary_key=True)
    port = Column(Integer, primary_key=True)
    last_connected_date = Column(DateTime(), nullable=False)
class ChannelDB(SqlDB):
    """SQL-backed store for gossiped channel, policy, node and address rows.

    Public methods that touch ``self.DBSession`` are decorated with ``@sql``;
    the undecorated ``_``-prefixed helpers are only called from inside those
    methods.
    NOTE(review): ``@sql`` presumably serialises the call onto the DB thread
    -- confirm in ``sql_db``.
    """

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        path = os.path.join(get_headers_dir(network.config), 'channel_db')
        super().__init__(network, path, Base)
        self.num_nodes = 0
        self.num_channels = 0
        # channel_update payloads for channels that are not in the DB,
        # keyed by (start_node_id, short_channel_id)
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)
        self.update_counts()

    @sql
    def update_counts(self):
        """Refresh ``num_channels`` and ``num_nodes`` from the database."""
        self._update_counts()

    def _update_counts(self):
        self.num_channels = self.DBSession.query(ChannelInfo).count()
        self.num_nodes = self.DBSession.query(NodeInfo).count()

    @sql
    def add_recent_peer(self, peer: LNPeerAddr):
        """Record that we connected to ``peer`` just now.

        The row is looked up by the full primary key (node_id, host, port):
        a node can have several addresses, and matching on node_id alone
        would either raise on multiple rows or refresh an unrelated address.
        """
        node_id_hex = peer.pubkey.hex()
        now = datetime.datetime.now()
        addr = self.DBSession.query(Address) \
            .filter_by(node_id=node_id_hex, host=peer.host, port=peer.port) \
            .one_or_none()
        if addr is None:
            addr = Address(node_id=node_id_hex, host=peer.host, port=peer.port,
                           last_connected_date=now)
        else:
            addr.last_connected_date = now
        self.DBSession.add(addr)
        self.DBSession.commit()

    @sql
    def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
        """Return up to 200 NodeInfo rows not in ``node_ids_bytes``, shuffled.

        Only the order is randomised; the rows are whatever the LIMIT 200
        query yields.
        """
        excluded_hex = [x.hex() for x in node_ids_bytes]
        unshuffled = self.DBSession \
            .query(NodeInfo) \
            .filter(not_(NodeInfo.node_id.in_(excluded_hex))) \
            .limit(200) \
            .all()
        return random.sample(unshuffled, len(unshuffled))

    @sql
    def nodes_get(self, node_id):
        """Return the NodeInfo row for ``node_id`` (bytes), or None."""
        return self.DBSession \
            .query(NodeInfo) \
            .filter_by(node_id=node_id.hex()) \
            .one_or_none()

    @sql
    def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
        """Return the most recently connected address of ``node_id``, or None.

        ``first()`` is used because a node may have several address rows;
        ``one_or_none()`` would raise in that case.
        """
        adr_db = self.DBSession \
            .query(Address) \
            .filter_by(node_id=node_id.hex()) \
            .order_by(Address.last_connected_date.desc()) \
            .first()
        if not adr_db:
            return None
        return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id))

    @sql
    def get_recent_peers(self):
        """Return up to NUM_MAX_RECENT_PEERS addresses, most recent first.

        The query selects from the address table only; an extra FROM on
        node_info without a join condition would cross-join the two tables
        and let the LIMIT be used up by repeated rows.
        """
        rows = self.DBSession \
            .query(Address) \
            .order_by(Address.last_connected_date.desc()) \
            .limit(self.NUM_MAX_RECENT_PEERS)
        return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in rows]

    @sql
    def get_channel_info(self, channel_id: bytes):
        """Return the ChannelInfo row for ``channel_id``, or None."""
        return self._chan_query_for_id(channel_id).one_or_none()

    @sql
    def get_channels_for_node(self, node_id):
        """Returns the short channel ids (bytes) of channels that have node_id as one of the endpoints."""
        node_id_hex = node_id.hex()
        condition = or_(
            ChannelInfo.node1_id == node_id_hex,
            ChannelInfo.node2_id == node_id_hex)
        rows = self.DBSession.query(ChannelInfo).filter(condition).all()
        return [bytes.fromhex(x.short_channel_id) for x in rows]

    @sql
    def missing_short_chan_ids(self) -> Set[str]:
        """Return hex short channel ids for which information is missing.

        Checked in order, returning the first non-empty set:
        policies without a channel_info row, then channels whose node1 has
        no node_info row, then channels whose node2 has none.
        """
        expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
        chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
        if chan_ids_from_policy:
            return chan_ids_from_policy
        # fetch channels for node_ids missing in node_info. that will also give us node_announcement
        expr = not_(ChannelInfo.node1_id.in_(self.DBSession.query(NodeInfo.node_id)))
        chan_ids_from_id1 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
        if chan_ids_from_id1:
            return chan_ids_from_id1
        expr = not_(ChannelInfo.node2_id.in_(self.DBSession.query(NodeInfo.node_id)))
        chan_ids_from_id2 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
        if chan_ids_from_id2:
            return chan_ids_from_id2
        return set()

    @sql
    def add_verified_channel_info(self, short_id, capacity):
        """Mark a channel as trusted and store its capacity.

        Does nothing if the channel row no longer exists.
        """
        # called from lnchannelverifier
        channel_info = self._chan_query_for_id(short_id).one_or_none()
        if channel_info is None:
            return
        channel_info.trusted = True
        # the mapped column is `capacity_sat`; assigning to any other
        # attribute name would not be persisted
        channel_info.capacity_sat = capacity
        self.DBSession.commit()

    @sql
    @profiler
    def on_channel_announcement(self, msg_payloads, trusted=False):
        """Store new channel_announcement payload(s); accepts a dict or a list of dicts.

        Already-known channels, other chains and unsupported feature bits are
        skipped. Untrusted channels are also handed to ``self.ca_verifier``.
        """
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        for msg in msg_payloads:
            short_channel_id = msg['short_channel_id']
            if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=bh2u(short_channel_id)).count():
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
                continue
            try:
                channel_info = ChannelInfo.from_msg(msg)
            except UnknownEvenFeatureBits:
                continue
            channel_info.trusted = trusted
            self.DBSession.add(channel_info)
            if not trusted:
                self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
        self.DBSession.commit()
        self._update_counts()
        self.network.trigger_callback('ln_status')

    @sql
    @profiler
    def on_channel_update(self, msg_payloads, trusted=False):
        """Apply channel_update payload(s); accepts a dict or a list of dicts.

        Updates for other chains or for channels not in the DB are ignored.
        """
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
        channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
        channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
        for msg_payload in msg_payloads:
            short_channel_id = msg_payload['short_channel_id']
            if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
                continue
            channel_info = channel_infos.get(short_channel_id)
            if not channel_info:
                continue
            self._update_channel_info(channel_info, msg_payload, trusted=trusted)
        self.DBSession.commit()

    @sql
    @profiler
    def on_node_announcement(self, msg_payloads):
        """Store node_announcement payload(s); accepts a dict or a list of dicts.

        Payloads with a bad signature, unsupported feature bits, or a
        timestamp not newer than the stored one are ignored.
        """
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        addresses = self.DBSession.query(Address).all()
        have_addr = {}
        for addr in addresses:
            have_addr[(addr.node_id, addr.host, addr.port)] = addr

        nodes = self.DBSession.query(NodeInfo).all()
        timestamps = {}
        for node in nodes:
            # the stored timestamp is read back as a string; cut off fractional seconds
            no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
            timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
        old_addr = None
        for msg_payload in msg_payloads:
            pubkey = msg_payload['node_id']
            signature = msg_payload['signature']
            # bytes after the first 66 of the raw message are what is hashed
            h = sha256d(msg_payload['raw'][66:])
            if not ecc.verify_signature(pubkey, signature, h):
                continue
            try:
                new_node_info, addresses = NodeInfo.from_msg(msg_payload)
            except UnknownEvenFeatureBits:
                continue
            if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
                continue  # ignore
            self.DBSession.add(new_node_info)
            for new_addr in addresses:
                key = (new_addr.node_id, new_addr.host, new_addr.port)
                old_addr = have_addr.get(key)
                if old_addr:
                    # since old_addr is embedded in have_addr,
                    # it will still live when commit is called
                    old_addr.last_connected_date = new_addr.last_connected_date
                    del new_addr
                else:
                    self.DBSession.add(new_addr)
                    have_addr[key] = new_addr
            # TODO if this message is for a new node, and if we have no associated
            # channels for this node, we should ignore the message and return here,
            # to mitigate DOS. but race condition: the channels we have for this
            # node, might be under verification in self.ca_verifier, what then?
        # NOTE(review): these explicit dels appear intended to drop references
        # to loaded ORM objects before the commit -- confirm before removing.
        del nodes, addresses
        if old_addr:
            del old_addr
        self.DBSession.commit()
        self._update_counts()
        self.network.trigger_callback('ln_status')

    def get_routing_policy_for_channel(self, start_node_id: bytes,
                                       short_channel_id: bytes) -> Optional['Policy']:
        """Return the policy ``start_node_id`` applies on ``short_channel_id``.

        Channels in the DB are answered from the policy table; otherwise a
        stored private channel_update is converted to a transient Policy.
        Returns None if either argument is empty or nothing is known.
        """
        if not start_node_id or not short_channel_id:
            return None
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:
            return self.get_policy_for_node(channel_info, start_node_id)
        msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
        if not msg:
            return None
        return Policy.from_msg(msg, None, short_channel_id)  # won't actually be written to DB

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
        """Remember a channel_update in memory, if its signature verifies."""
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return  # ignore
        short_channel_id = msg_payload['short_channel_id']
        self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload

    @sql
    def remove_channel(self, short_channel_id):
        """Delete the channel_info row for ``short_channel_id``."""
        self._chan_query_for_id(short_channel_id).delete('evaluate')
        self.DBSession.commit()

    def _chan_query_for_id(self, short_channel_id) -> Query:
        return self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id.hex())

    def print_graph(self, full_ids=False):
        """Print all nodes and channels with the directions that have a policy.

        Node ids are shortened to their last 8 hex characters unless
        ``full_ids`` is set.
        """
        # used for debugging.
        # FIXME there is a race here - iterables could change size from another thread
        def shown(node_id_hex):
            # ids in the DB are already hex strings
            return node_id_hex if full_ids else node_id_hex[-8:]

        self.print_msg('nodes')
        for node in self.DBSession.query(NodeInfo).all():
            self.print_msg(node)

        self.print_msg('channels')
        for channel_info in self.DBSession.query(ChannelInfo).all():
            node1 = channel_info.node1_id
            node2 = channel_info.node2_id
            # get_policy_for_node expects the node id as bytes
            direction1 = self.get_policy_for_node(channel_info, bytes.fromhex(node1)) is not None
            direction2 = self.get_policy_for_node(channel_info, bytes.fromhex(node2)) is not None
            if direction1 and direction2:
                direction = 'both'
            elif direction1:
                direction = 'forward'
            elif direction2:
                direction = 'backward'
            else:
                direction = 'none'
            self.print_msg('{}: {}, {}, {}'
                           .format(channel_info.short_channel_id,
                                   shown(node1),
                                   shown(node2),
                                   direction))

    def _update_channel_info(self, channel_info, msg: dict, trusted=False):
        """Insert or refresh the policy described by a channel_update.

        The direction bit of ``channel_flags`` selects which endpoint issued
        the update. Stale updates (timestamp not newer than the stored one)
        are ignored. Unless ``trusted``, the signature is verified before the
        policy is stored -- for a first-time insert as well as for a refresh.
        """
        assert channel_info.short_channel_id == msg['short_channel_id'].hex()
        flags = int.from_bytes(msg['channel_flags'], 'big')
        direction = flags & FLAG_DIRECTION
        node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
        new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id)
        old_policy = self.DBSession.query(Policy) \
            .filter_by(short_channel_id=channel_info.short_channel_id, start_node=node_id) \
            .one_or_none()
        # cheap staleness check first, the signature check is more expensive
        if old_policy and old_policy.timestamp >= new_policy.timestamp:
            return  # ignore
        if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
            return  # ignore
        if not old_policy:
            self.DBSession.add(new_policy)
            return
        old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta
        old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat
        old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat
        old_policy.fee_base_msat = new_policy.fee_base_msat
        old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
        old_policy.channel_flags = new_policy.channel_flags
        old_policy.timestamp = new_policy.timestamp

    @sql
    def get_policy_for_node(self, channel_info, node) -> Optional['Policy']:
        """Return the policy that ``node`` (bytes) published for this channel.

        Returns None if that direction has no stored policy.
        Raises Exception when ``node`` is neither endpoint of the channel.
        """
        node_hex = node.hex()
        if node_hex not in (channel_info.node1_id, channel_info.node2_id):
            raise Exception("the given node is not a party in this channel")
        # select by the requested node, so the caller never gets the
        # opposite direction's policy
        return self.DBSession.query(Policy) \
            .filter_by(short_channel_id=channel_info.short_channel_id, start_node=node_hex) \
            .one_or_none()

    @sql
    def get_node_addresses(self, node_info):
        """Return all Address rows belonging to ``node_info``."""
        return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id=node_info.node_id).all()
class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
                                         ('short_channel_id', bytes),
                                         ('fee_base_msat', int),
                                         ('fee_proportional_millionths', int),
                                         ('cltv_expiry_delta', int)])):
    """if you travel through short_channel_id, you will reach node_id"""

    def fee_for_edge(self, amount_msat: int) -> int:
        """Fee in msat for forwarding ``amount_msat``: base plus proportional part."""
        proportional_part = amount_msat * self.fee_proportional_millionths // 1_000_000
        return self.fee_base_msat + proportional_part

    @classmethod
    def from_channel_policy(cls, channel_policy: 'Policy',
                            short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
        """Build an edge towards ``end_node`` using the fee/cltv fields of a policy."""
        assert type(short_channel_id) is bytes
        assert type(end_node) is bytes
        return RouteEdge(
            node_id=end_node,
            short_channel_id=short_channel_id,
            fee_base_msat=channel_policy.fee_base_msat,
            fee_proportional_millionths=channel_policy.fee_proportional_millionths,
            cltv_expiry_delta=channel_policy.cltv_expiry_delta,
        )

    def is_sane_to_use(self, amount_msat: int) -> bool:
        """Heuristic check that this edge's cltv delta and fee are acceptable."""
        # TODO revise ad-hoc heuristics
        # cltv cannot be more than 2 weeks
        max_cltv_delta = 14 * 144
        if self.cltv_expiry_delta > max_cltv_delta:
            return False
        edge_fee = self.fee_for_edge(amount_msat)
        # fees below 50 sat are fine
        if not edge_fee > 50_000:
            return True
        # fee cannot be higher than amt
        if edge_fee > amount_msat:
            return False
        # fee cannot be higher than 5000 sat
        if edge_fee > 5_000_000:
            return False
        # unless amt is tiny, fee cannot be more than 10%
        if amount_msat > 1_000_000 and edge_fee > amount_msat/10:
            return False
        return True
def is_route_sane_to_use(route: List[RouteEdge], invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
    """Run some sanity checks on the whole route, before attempting to use it.
    called when we are paying; so e.g. lower cltv is better

    Fees and cltv deltas are accumulated from the last edge back to the
    second one; the first edge of the route is not included.
    """
    if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
        return False
    running_amt = invoice_amount_msat
    running_cltv = min_final_cltv_expiry
    # walk indices len-1 .. 1, i.e. every edge but the first, back to front
    for idx in range(len(route) - 1, 0, -1):
        edge = route[idx]
        if not edge.is_sane_to_use(running_amt):
            return False
        running_amt += edge.fee_for_edge(running_amt)
        running_cltv += edge.cltv_expiry_delta
    fee_total = running_amt - invoice_amount_msat
    # TODO revise ad-hoc heuristics
    # cltv cannot be more than 2 months
    if running_cltv > 60 * 144:
        return False
    # fees below 50 sat are fine
    if not fee_total > 50_000:
        return True
    # fee cannot be higher than amt
    if fee_total > invoice_amount_msat:
        return False
    # fee cannot be higher than 5000 sat
    if fee_total > 5_000_000:
        return False
    # unless amt is tiny, fee cannot be more than 10%
    if invoice_amount_msat > 1_000_000 and fee_total > invoice_amount_msat/10:
        return False
    return True
class LNPathFinder(PrintError):
    """Finds payment paths through the channel graph stored in a ChannelDB."""

    def __init__(self, channel_db: ChannelDB):
        self.channel_db = channel_db
        # short channel ids (bytes) that must not be used by the search
        self.blacklist = set()

    def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
                   payment_amt_msat: int, ignore_costs=False) -> Tuple[float, int]:
        """Heuristic cost of going through a channel.
        Returns (heuristic_cost, fee_for_edge_msat).

        An unusable edge (unknown channel, no or disabled policy, amount out
        of bounds, insane fee/cltv) is reported as ``(inf, 0)``.
        With ``ignore_costs`` the fee and cltv parts are zero, leaving only
        the constant per-edge cost of 1.
        """
        unusable = (float('inf'), 0)
        channel_info = self.channel_db.get_channel_info(short_channel_id)  # type: ChannelInfo
        if channel_info is None:
            return unusable
        channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node)
        if channel_policy is None:
            return unusable
        if channel_policy.is_disabled():
            return unusable
        route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
        if payment_amt_msat < channel_policy.htlc_minimum_msat:
            return unusable  # payment amount too little
        if channel_info.capacity_sat is not None and \
                payment_amt_msat // 1000 > channel_info.capacity_sat:
            return unusable  # payment amount too large
        if channel_policy.htlc_maximum_msat is not None and \
                payment_amt_msat > channel_policy.htlc_maximum_msat:
            return unusable  # payment amount too large
        if not route_edge.is_sane_to_use(payment_amt_msat):
            return unusable  # thanks but no thanks
        fee_msat = route_edge.fee_for_edge(payment_amt_msat) if not ignore_costs else 0
        # TODO revise
        # paying 10 more satoshis ~ waiting one more block
        fee_cost = fee_msat / 1000 / 10
        cltv_cost = route_edge.cltv_expiry_delta if not ignore_costs else 0
        return cltv_cost + fee_cost + 1, fee_msat

    @profiler
    def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
                              invoice_amount_msat: int,
                              my_channels: List['Channel']=None) -> Optional[Sequence[Tuple[bytes, bytes]]]:
        """Return a path from nodeA to nodeB.

        Returns a list of (node_id, short_channel_id) representing a path.
        To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
        i.e. an element reads as, "to get to node_id, travel through short_channel_id"

        Returns None when no path is found.
        """
        assert type(nodeA) is bytes
        assert type(nodeB) is bytes
        assert type(invoice_amount_msat) is int
        if my_channels is None:
            my_channels = []
        my_channels = {chan.short_channel_id: chan for chan in my_channels}

        # FIXME paths cannot be longer than 21 edges (onion packet)...

        # run Dijkstra
        # The search is run in the REVERSE direction, from nodeB to nodeA,
        # to properly calculate compound routing fees.
        distance_from_start = defaultdict(lambda: float('inf'))
        distance_from_start[nodeB] = 0
        prev_node = {}
        nodes_to_explore = queue.PriorityQueue()
        nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!

        def inspect_edge(edge_channel_id, edge_startnode, edge_endnode, amount_msat):
            """Relax one edge; payment flows edge_startnode -> edge_endnode."""
            if edge_channel_id in my_channels:
                if edge_startnode == nodeA:  # payment outgoing, on our channel
                    if not my_channels[edge_channel_id].can_pay(amount_msat):
                        return
                else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                    assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                    pass  # TODO?
            edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
                                                           start_node=edge_startnode,
                                                           end_node=edge_endnode,
                                                           payment_amt_msat=amount_msat,
                                                           ignore_costs=(edge_startnode == nodeA))
            alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
            if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                distance_from_start[edge_startnode] = alt_dist_to_neighbour
                prev_node[edge_startnode] = edge_endnode, edge_channel_id
                amount_to_forward_msat = amount_msat + fee_for_edge_msat
                nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))

        # main loop of search
        while nodes_to_explore.qsize() > 0:
            dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
            if edge_endnode == nodeA:
                break
            if dist_to_edge_endnode != distance_from_start[edge_endnode]:
                # queue.PriorityQueue does not implement decrease_priority,
                # so instead of decreasing priorities, we add items again into the queue.
                # so there are duplicates in the queue, that we discard now:
                continue
            for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
                assert type(edge_channel_id) is bytes
                if edge_channel_id in self.blacklist:
                    continue
                channel_info = self.channel_db.get_channel_info(edge_channel_id)
                if channel_info is None:
                    # the channel vanished between the two DB calls; skip it
                    # (same guard as in _edge_cost)
                    continue
                node1 = bfh(channel_info.node1_id)
                # the edge starts at whichever endpoint is not edge_endnode
                edge_startnode = bfh(channel_info.node2_id) if node1 == edge_endnode else node1
                inspect_edge(edge_channel_id, edge_startnode, edge_endnode, amount_msat)
        else:
            return None  # no path found

        # backtrack from search_end (nodeA) to search_start (nodeB)
        edge_startnode = nodeA
        path = []
        while edge_startnode != nodeB:
            edge_endnode, edge_taken = prev_node[edge_startnode]
            path += [(edge_endnode, edge_taken)]
            edge_startnode = edge_endnode
        return path

    def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:
        """Turn a path of (node_id, short_channel_id) pairs into RouteEdges.

        Raises Exception if ``path`` is None and NoChannelPolicy if a hop has
        no known routing policy.
        """
        assert type(from_node_id) is bytes
        if path is None:
            raise Exception('cannot create route from None path')
        route = []
        prev_node_id = from_node_id
        for node_id, short_channel_id in path:
            channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
            if channel_policy is None:
                raise NoChannelPolicy(short_channel_id)
            route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
            prev_node_id = node_id
        return route