|
|
@ -419,10 +419,8 @@ class Peer(PrintError): |
|
|
|
self.localfeatures = (0x08 if request_initial_sync else 0) |
|
|
|
# view of the network |
|
|
|
self.nodes = {} # received node announcements |
|
|
|
self.channels = {} # received channel announcements |
|
|
|
self.channel_u_origin = {} |
|
|
|
self.channel_u_final = {} |
|
|
|
self.graph_of_payment_channels = defaultdict(set) # node -> short_channel_id |
|
|
|
self.channel_db = ChannelDB() |
|
|
|
self.path_finder = LNPathFinder(self.channel_db) |
|
|
|
|
|
|
|
def diagnostic_name(self): |
|
|
|
return self.host |
|
|
@ -541,8 +539,8 @@ class Peer(PrintError): |
|
|
|
def on_funding_signed(self, payload): |
|
|
|
sig = payload['signature'] |
|
|
|
channel_id = payload['channel_id'] |
|
|
|
tx = self.channels[channel_id] |
|
|
|
self.network.broadcast(tx) |
|
|
|
#tx = self.channels[channel_id] # FIXME |
|
|
|
#self.network.broadcast(tx) |
|
|
|
|
|
|
|
def on_funding_signed(self, payload): |
|
|
|
self.funding_signed[payload["temporary_channel_id"]].set_result(payload) |
|
|
@ -588,99 +586,14 @@ class Peer(PrintError): |
|
|
|
pass |
|
|
|
|
|
|
|
def on_channel_update(self, payload): |
|
|
|
flags = int.from_bytes(payload['flags'], byteorder="big") |
|
|
|
direction = bool(flags & 1) |
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
|
if direction == 0: |
|
|
|
self.channel_u_origin[short_channel_id] = payload |
|
|
|
else: |
|
|
|
self.channel_u_final[short_channel_id] = payload |
|
|
|
self.print_error('channel update', binascii.hexlify(short_channel_id), flags) |
|
|
|
self.channel_db.on_channel_update(payload) |
|
|
|
|
|
|
|
def on_channel_announcement(self, payload): |
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
|
self.print_error('channel announcement', binascii.hexlify(short_channel_id)) |
|
|
|
self.channels[short_channel_id] = payload |
|
|
|
self.add_channel_to_graph(payload) |
|
|
|
|
|
|
|
def add_channel_to_graph(self, payload): |
|
|
|
node1 = payload['node_id_1'] |
|
|
|
node2 = payload['node_id_2'] |
|
|
|
channel_id = payload['short_channel_id'] |
|
|
|
self.graph_of_payment_channels[node1].add(channel_id) |
|
|
|
self.graph_of_payment_channels[node2].add(channel_id) |
|
|
|
self.channel_db.on_channel_announcement(payload) |
|
|
|
|
|
|
|
#def open_channel(self, funding_sat, push_msat): |
|
|
|
# self.send_message(gen_msg('open_channel', funding_satoshis=funding_sat, push_msat=push_msat)) |
|
|
|
|
|
|
|
@profiler |
|
|
|
def find_route_for_payment(self, from_node_id, to_node_id, amount_msat=None): |
|
|
|
"""Return a route between from_node_id and to_node_id. |
|
|
|
|
|
|
|
Returns a list of (node_id, short_channel_id) representing a path. |
|
|
|
To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1] |
|
|
|
""" |
|
|
|
# TODO find multiple paths?? |
|
|
|
|
|
|
|
def edge_cost(short_channel_id, direction): |
|
|
|
"""Heuristic cost of going through a channel. |
|
|
|
direction: 0 or 1. --- 0 means node_id_1 -> node_id_2 |
|
|
|
""" |
|
|
|
channel_updates = self.channel_u_origin if direction == 0 else self.channel_u_final |
|
|
|
try: |
|
|
|
cltv_expiry_delta = channel_updates[short_channel_id]['cltv_expiry_delta'] |
|
|
|
htlc_minimum_msat = channel_updates[short_channel_id]['htlc_minimum_msat'] |
|
|
|
fee_base_msat = channel_updates[short_channel_id]['fee_base_msat'] |
|
|
|
fee_proportional_millionths = channel_updates[short_channel_id]['fee_proportional_millionths'] |
|
|
|
except KeyError: |
|
|
|
return float('inf') # can't use this channel |
|
|
|
if amount_msat is not None and amount_msat < htlc_minimum_msat: |
|
|
|
return float('inf') # can't use this channel |
|
|
|
amt = amount_msat or 50000 * 1000 # guess for typical payment amount |
|
|
|
fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000 |
|
|
|
# TODO revise |
|
|
|
# paying 10 more satoshis ~ waiting one more block |
|
|
|
fee_cost = fee_msat / 1000 / 10 |
|
|
|
cltv_cost = cltv_expiry_delta |
|
|
|
return cltv_cost + fee_cost + 1 |
|
|
|
|
|
|
|
# run Dijkstra |
|
|
|
distance_from_start = defaultdict(lambda: float('inf')) |
|
|
|
distance_from_start[from_node_id] = 0 |
|
|
|
prev_node = {} |
|
|
|
nodes_to_explore = queue.PriorityQueue() |
|
|
|
nodes_to_explore.put((0, from_node_id)) |
|
|
|
|
|
|
|
while nodes_to_explore.qsize() > 0: |
|
|
|
dist_to_cur_node, cur_node = nodes_to_explore.get() |
|
|
|
if cur_node == to_node_id: |
|
|
|
break |
|
|
|
if dist_to_cur_node != distance_from_start[cur_node]: |
|
|
|
# queue.PriorityQueue does not implement decrease_priority, |
|
|
|
# so instead of decreasing priorities, we add items again into the queue. |
|
|
|
# so there are duplicates in the queue, that we discard now: |
|
|
|
continue |
|
|
|
for edge in self.graph_of_payment_channels[cur_node]: |
|
|
|
node1 = self.channels[edge]['node_id_1'] |
|
|
|
node2 = self.channels[edge]['node_id_2'] |
|
|
|
neighbour, direction = (node1, 1) if node1 != cur_node else (node2, 0) |
|
|
|
alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost(edge, direction) |
|
|
|
if alt_dist_to_neighbour < distance_from_start[neighbour]: |
|
|
|
distance_from_start[neighbour] = alt_dist_to_neighbour |
|
|
|
prev_node[neighbour] = cur_node, edge |
|
|
|
nodes_to_explore.put((alt_dist_to_neighbour, neighbour)) |
|
|
|
else: |
|
|
|
return None # no path found |
|
|
|
|
|
|
|
# backtrack from end to start |
|
|
|
cur_node = to_node_id |
|
|
|
path = [(cur_node, None)] |
|
|
|
while cur_node != from_node_id: |
|
|
|
cur_node, edge_taken = prev_node[cur_node] |
|
|
|
path += [(cur_node, edge_taken)] |
|
|
|
path.reverse() |
|
|
|
return path |
|
|
|
|
|
|
|
@aiosafe |
|
|
|
async def main_loop(self): |
|
|
|
self.reader, self.writer = await asyncio.open_connection(self.host, self.port) |
|
|
@ -792,3 +705,165 @@ class LNWorker: |
|
|
|
# todo: get utxo from wallet |
|
|
|
# submit coro to asyncio main loop |
|
|
|
self.peer.open_channel() |
|
|
|
|
|
|
|
|
|
|
|
class ChannelInfo(PrintError): |
|
|
|
|
|
|
|
def __init__(self, channel_announcement_payload): |
|
|
|
self.channel_id = channel_announcement_payload['short_channel_id'] |
|
|
|
self.node_id_1 = channel_announcement_payload['node_id_1'] |
|
|
|
self.node_id_2 = channel_announcement_payload['node_id_2'] |
|
|
|
|
|
|
|
self.capacity_sat = None |
|
|
|
self.policy_node1 = None |
|
|
|
self.policy_node2 = None |
|
|
|
|
|
|
|
def set_capacity(self, capacity): |
|
|
|
# TODO call this after looking up UTXO for funding txn on chain |
|
|
|
self.capacity_sat = capacity |
|
|
|
|
|
|
|
def on_channel_update(self, msg_payload): |
|
|
|
assert self.channel_id == msg_payload['short_channel_id'] |
|
|
|
flags = int.from_bytes(msg_payload['flags'], byteorder="big") |
|
|
|
direction = bool(flags & 1) |
|
|
|
if direction == 0: |
|
|
|
self.policy_node1 = ChannelInfoDirectedPolicy(msg_payload) |
|
|
|
else: |
|
|
|
self.policy_node2 = ChannelInfoDirectedPolicy(msg_payload) |
|
|
|
self.print_error('channel update', binascii.hexlify(self.channel_id), flags) |
|
|
|
|
|
|
|
def get_policy_for_node(self, node_id): |
|
|
|
if node_id == self.node_id_1: |
|
|
|
return self.policy_node1 |
|
|
|
elif node_id == self.node_id_2: |
|
|
|
return self.policy_node2 |
|
|
|
else: |
|
|
|
raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id)) |
|
|
|
|
|
|
|
|
|
|
|
class ChannelInfoDirectedPolicy: |
|
|
|
|
|
|
|
def __init__(self, channel_update_payload): |
|
|
|
self.cltv_expiry_delta = channel_update_payload['cltv_expiry_delta'] |
|
|
|
self.htlc_minimum_msat = channel_update_payload['htlc_minimum_msat'] |
|
|
|
self.fee_base_msat = channel_update_payload['fee_base_msat'] |
|
|
|
self.fee_proportional_millionths = channel_update_payload['fee_proportional_millionths'] |
|
|
|
|
|
|
|
|
|
|
|
class ChannelDB(PrintError): |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
self._id_to_channel_info = {} |
|
|
|
self._channels_for_node = defaultdict(set) # node -> set(short_channel_id) |
|
|
|
|
|
|
|
def get_channel_info(self, channel_id): |
|
|
|
return self._id_to_channel_info.get(channel_id, None) |
|
|
|
|
|
|
|
def get_channels_for_node(self, node_id): |
|
|
|
"""Returns the set of channels that have node_id as one of the endpoints.""" |
|
|
|
return self._channels_for_node[node_id] |
|
|
|
|
|
|
|
def on_channel_announcement(self, msg_payload): |
|
|
|
short_channel_id = msg_payload['short_channel_id'] |
|
|
|
self.print_error('channel announcement', binascii.hexlify(short_channel_id)) |
|
|
|
channel_info = ChannelInfo(msg_payload) |
|
|
|
self._id_to_channel_info[short_channel_id] = channel_info |
|
|
|
self._channels_for_node[channel_info.node_id_1].add(short_channel_id) |
|
|
|
self._channels_for_node[channel_info.node_id_2].add(short_channel_id) |
|
|
|
|
|
|
|
def on_channel_update(self, msg_payload): |
|
|
|
short_channel_id = msg_payload['short_channel_id'] |
|
|
|
self._id_to_channel_info[short_channel_id].on_channel_update(msg_payload) |
|
|
|
|
|
|
|
def remove_channel(self, short_channel_id): |
|
|
|
try: |
|
|
|
channel_info = self._id_to_channel_info[short_channel_id] |
|
|
|
except KeyError: |
|
|
|
self.print_error('cannot find channel {}'.format(short_channel_id)) |
|
|
|
return |
|
|
|
self._id_to_channel_info.pop(short_channel_id, None) |
|
|
|
for node in (channel_info.node_id_1, channel_info.node_id_2): |
|
|
|
try: |
|
|
|
self._channels_for_node[node].remove(short_channel_id) |
|
|
|
except KeyError: |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class LNPathFinder(PrintError): |
|
|
|
|
|
|
|
def __init__(self, channel_db): |
|
|
|
self.channel_db = channel_db |
|
|
|
|
|
|
|
def _edge_cost(self, short_channel_id, start_node, payment_amt_msat): |
|
|
|
"""Heuristic cost of going through a channel. |
|
|
|
direction: 0 or 1. --- 0 means node_id_1 -> node_id_2 |
|
|
|
""" |
|
|
|
channel_info = self.channel_db.get_channel_info(short_channel_id) |
|
|
|
if channel_info is None: |
|
|
|
return float('inf') |
|
|
|
|
|
|
|
channel_policy = channel_info.get_policy_for_node(start_node) |
|
|
|
cltv_expiry_delta = channel_policy.cltv_expiry_delta |
|
|
|
htlc_minimum_msat = channel_policy.htlc_minimum_msat |
|
|
|
fee_base_msat = channel_policy.fee_base_msat |
|
|
|
fee_proportional_millionths = channel_policy.fee_proportional_millionths |
|
|
|
if payment_amt_msat is not None: |
|
|
|
if payment_amt_msat < htlc_minimum_msat: |
|
|
|
return float('inf') # payment amount too little |
|
|
|
if channel_info.capacity_sat is not None and \ |
|
|
|
payment_amt_msat // 1000 > channel_info.capacity_sat: |
|
|
|
return float('inf') # payment amount too large |
|
|
|
amt = payment_amt_msat or 50000 * 1000 # guess for typical payment amount |
|
|
|
fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000 |
|
|
|
# TODO revise |
|
|
|
# paying 10 more satoshis ~ waiting one more block |
|
|
|
fee_cost = fee_msat / 1000 / 10 |
|
|
|
cltv_cost = cltv_expiry_delta |
|
|
|
return cltv_cost + fee_cost + 1 |
|
|
|
|
|
|
|
@profiler |
|
|
|
def find_path_for_payment(self, from_node_id, to_node_id, amount_msat=None): |
|
|
|
"""Return a path between from_node_id and to_node_id. |
|
|
|
|
|
|
|
Returns a list of (node_id, short_channel_id) representing a path. |
|
|
|
To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1] |
|
|
|
""" |
|
|
|
# TODO find multiple paths?? |
|
|
|
|
|
|
|
# run Dijkstra |
|
|
|
distance_from_start = defaultdict(lambda: float('inf')) |
|
|
|
distance_from_start[from_node_id] = 0 |
|
|
|
prev_node = {} |
|
|
|
nodes_to_explore = queue.PriorityQueue() |
|
|
|
nodes_to_explore.put((0, from_node_id)) |
|
|
|
|
|
|
|
while nodes_to_explore.qsize() > 0: |
|
|
|
dist_to_cur_node, cur_node = nodes_to_explore.get() |
|
|
|
if cur_node == to_node_id: |
|
|
|
break |
|
|
|
if dist_to_cur_node != distance_from_start[cur_node]: |
|
|
|
# queue.PriorityQueue does not implement decrease_priority, |
|
|
|
# so instead of decreasing priorities, we add items again into the queue. |
|
|
|
# so there are duplicates in the queue, that we discard now: |
|
|
|
continue |
|
|
|
for edge_channel_id in self.channel_db.get_channels_for_node(cur_node): |
|
|
|
channel_info = self.channel_db.get_channel_info(edge_channel_id) |
|
|
|
node1, node2 = channel_info.node_id_1, channel_info.node_id_2 |
|
|
|
neighbour = node2 if node1 == cur_node else node1 |
|
|
|
alt_dist_to_neighbour = distance_from_start[cur_node] \ |
|
|
|
+ self._edge_cost(edge_channel_id, cur_node, amount_msat) |
|
|
|
if alt_dist_to_neighbour < distance_from_start[neighbour]: |
|
|
|
distance_from_start[neighbour] = alt_dist_to_neighbour |
|
|
|
prev_node[neighbour] = cur_node, edge_channel_id |
|
|
|
nodes_to_explore.put((alt_dist_to_neighbour, neighbour)) |
|
|
|
else: |
|
|
|
return None # no path found |
|
|
|
|
|
|
|
# backtrack from end to start |
|
|
|
cur_node = to_node_id |
|
|
|
path = [(cur_node, None)] |
|
|
|
while cur_node != from_node_id: |
|
|
|
cur_node, edge_taken = prev_node[cur_node] |
|
|
|
path += [(cur_node, edge_taken)] |
|
|
|
path.reverse() |
|
|
|
return path |
|
|
|