|
|
@ -12,7 +12,7 @@ import queue |
|
|
|
import traceback |
|
|
|
import itertools |
|
|
|
import json |
|
|
|
from collections import OrderedDict |
|
|
|
from collections import OrderedDict, defaultdict |
|
|
|
import asyncio |
|
|
|
import sys |
|
|
|
import os |
|
|
@ -30,7 +30,7 @@ from .bitcoin import (public_key_from_private_key, ser_to_point, point_to_ser, |
|
|
|
from . import bitcoin |
|
|
|
from . import constants |
|
|
|
from . import transaction |
|
|
|
from .util import PrintError, bh2u, print_error, bfh |
|
|
|
from .util import PrintError, bh2u, print_error, bfh, profiler |
|
|
|
from .transaction import opcodes, Transaction |
|
|
|
|
|
|
|
# hardcoded nodes |
|
|
@ -341,6 +341,7 @@ class Peer(PrintError): |
|
|
|
self.channels = {} # received channel announcements |
|
|
|
self.channel_u_origin = {} |
|
|
|
self.channel_u_final = {} |
|
|
|
self.graph_of_payment_channels = defaultdict(set) # node -> short_channel_id |
|
|
|
|
|
|
|
def diagnostic_name(self): |
|
|
|
return self.host |
|
|
@ -509,7 +510,7 @@ class Peer(PrintError): |
|
|
|
flags = int.from_bytes(payload['flags'], byteorder="big") |
|
|
|
direction = bool(flags & 1) |
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
|
if direction: |
|
|
|
if direction == 0: |
|
|
|
self.channel_u_origin[short_channel_id] = payload |
|
|
|
else: |
|
|
|
self.channel_u_final[short_channel_id] = payload |
|
|
@ -519,10 +520,86 @@ class Peer(PrintError): |
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
|
self.print_error('channel announcement', binascii.hexlify(short_channel_id)) |
|
|
|
self.channels[short_channel_id] = payload |
|
|
|
self.add_channel_to_graph(payload) |
|
|
|
|
|
|
|
def add_channel_to_graph(self, payload): |
|
|
|
node1 = payload['node_id_1'] |
|
|
|
node2 = payload['node_id_2'] |
|
|
|
channel_id = payload['short_channel_id'] |
|
|
|
self.graph_of_payment_channels[node1].add(channel_id) |
|
|
|
self.graph_of_payment_channels[node2].add(channel_id) |
|
|
|
|
|
|
|
#def open_channel(self, funding_sat, push_msat): |
|
|
|
# self.send_message(gen_msg('open_channel', funding_satoshis=funding_sat, push_msat=push_msat)) |
|
|
|
|
|
|
|
@profiler |
|
|
|
def find_route_for_payment(self, from_node_id, to_node_id, amount_msat=None): |
|
|
|
"""Return a route between from_node_id and to_node_id. |
|
|
|
|
|
|
|
Returns a list of (node_id, short_channel_id) representing a path. |
|
|
|
To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1] |
|
|
|
""" |
|
|
|
# TODO find multiple paths?? |
|
|
|
|
|
|
|
def edge_cost(short_channel_id, direction): |
|
|
|
"""Heuristic cost of going through a channel. |
|
|
|
direction: 0 or 1. --- 0 means node_id_1 -> node_id_2 |
|
|
|
""" |
|
|
|
channel_updates = self.channel_u_origin if direction == 0 else self.channel_u_final |
|
|
|
try: |
|
|
|
cltv_expiry_delta = channel_updates[short_channel_id]['cltv_expiry_delta'] |
|
|
|
htlc_minimum_msat = channel_updates[short_channel_id]['htlc_minimum_msat'] |
|
|
|
fee_base_msat = channel_updates[short_channel_id]['fee_base_msat'] |
|
|
|
fee_proportional_millionths = channel_updates[short_channel_id]['fee_proportional_millionths'] |
|
|
|
except KeyError: |
|
|
|
return float('inf') # can't use this channel |
|
|
|
if amount_msat is not None and amount_msat < htlc_minimum_msat: |
|
|
|
return float('inf') # can't use this channel |
|
|
|
amt = amount_msat or 50000 * 1000 # guess for typical payment amount |
|
|
|
fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000 |
|
|
|
# TODO revise |
|
|
|
# paying 10 more satoshis ~ waiting one more block |
|
|
|
fee_cost = fee_msat / 1000 / 10 |
|
|
|
cltv_cost = cltv_expiry_delta |
|
|
|
return cltv_cost + fee_cost + 1 |
|
|
|
|
|
|
|
# run Dijkstra |
|
|
|
distance_from_start = defaultdict(lambda: float('inf')) |
|
|
|
distance_from_start[from_node_id] = 0 |
|
|
|
prev_node = {} |
|
|
|
nodes_to_explore = queue.PriorityQueue() |
|
|
|
nodes_to_explore.put((0, from_node_id)) |
|
|
|
|
|
|
|
while nodes_to_explore.qsize() > 0: |
|
|
|
dist_to_cur_node, cur_node = nodes_to_explore.get() |
|
|
|
if cur_node == to_node_id: |
|
|
|
break |
|
|
|
if dist_to_cur_node != distance_from_start[cur_node]: |
|
|
|
# queue.PriorityQueue does not implement decrease_priority, |
|
|
|
# so instead of decreasing priorities, we add items again into the queue. |
|
|
|
# so there are duplicates in the queue, that we discard now: |
|
|
|
continue |
|
|
|
for edge in self.graph_of_payment_channels[cur_node]: |
|
|
|
node1 = self.channels[edge]['node_id_1'] |
|
|
|
node2 = self.channels[edge]['node_id_2'] |
|
|
|
neighbour, direction = (node1, 1) if node1 != cur_node else (node2, 0) |
|
|
|
alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost(edge, direction) |
|
|
|
if alt_dist_to_neighbour < distance_from_start[neighbour]: |
|
|
|
distance_from_start[neighbour] = alt_dist_to_neighbour |
|
|
|
prev_node[neighbour] = cur_node, edge |
|
|
|
nodes_to_explore.put((alt_dist_to_neighbour, neighbour)) |
|
|
|
else: |
|
|
|
return None # no path found |
|
|
|
|
|
|
|
# backtrack from end to start |
|
|
|
cur_node = to_node_id |
|
|
|
path = [(cur_node, None)] |
|
|
|
while cur_node != from_node_id: |
|
|
|
cur_node, edge_taken = prev_node[cur_node] |
|
|
|
path += [(cur_node, edge_taken)] |
|
|
|
path.reverse() |
|
|
|
return path |
|
|
|
|
|
|
|
@aiosafe |
|
|
|
async def main_loop(self): |
|
|
|
self.reader, self.writer = await asyncio.open_connection(self.host, self.port) |
|
|
|