# -*- coding: utf-8 -*- # # Electrum - lightweight Bitcoin client # Copyright (C) 2018 The Electrum developers # # Permission is hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation files # (the "Software"), to deal in the Software without restriction, # including without limitation the rights to use, copy, modify, merge, # publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, # subject to the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import queue import traceback import sys import binascii import hashlib import hmac from collections import namedtuple, defaultdict from typing import Sequence, Union, Tuple from cryptography.hazmat.primitives.ciphers import Cipher, algorithms from cryptography.hazmat.backends import default_backend from . import bitcoin from . import ecc from . import crypto from .crypto import sha256 from .util import PrintError, bh2u, print_error, bfh, profiler, xor_bytes from . import lnbase class ChannelInfo(PrintError): def __init__(self, channel_announcement_payload): self.channel_id = channel_announcement_payload['short_channel_id'] self.node_id_1 = channel_announcement_payload['node_id_1'] self.node_id_2 = channel_announcement_payload['node_id_2'] assert type(self.node_id_1) is bytes assert type(self.node_id_2) is bytes assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2] self.capacity_sat = None self.policy_node1 = None self.policy_node2 = None def set_capacity(self, capacity): # TODO call this after looking up UTXO for funding txn on chain self.capacity_sat = capacity def on_channel_update(self, msg_payload): assert self.channel_id == msg_payload['short_channel_id'] flags = int.from_bytes(msg_payload['flags'], 'big') direction = flags & 1 if direction == 0: self.policy_node1 = ChannelInfoDirectedPolicy(msg_payload) else: self.policy_node2 = ChannelInfoDirectedPolicy(msg_payload) self.print_error('channel update', binascii.hexlify(self.channel_id).decode("ascii"), flags) def get_policy_for_node(self, node_id): if node_id == self.node_id_1: return self.policy_node1 elif node_id == self.node_id_2: return self.policy_node2 else: raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id)) class ChannelInfoDirectedPolicy: def __init__(self, channel_update_payload): self.cltv_expiry_delta = channel_update_payload['cltv_expiry_delta'] self.htlc_minimum_msat = channel_update_payload['htlc_minimum_msat'] self.fee_base_msat = channel_update_payload['fee_base_msat'] self.fee_proportional_millionths = channel_update_payload['fee_proportional_millionths'] self.cltv_expiry_delta = int.from_bytes(self.cltv_expiry_delta, "big") self.htlc_minimum_msat = int.from_bytes(self.htlc_minimum_msat, "big") self.fee_base_msat = int.from_bytes(self.fee_base_msat, "big") self.fee_proportional_millionths = int.from_bytes(self.fee_proportional_millionths, "big") class ChannelDB(PrintError): def __init__(self): self._id_to_channel_info = {} self._channels_for_node = defaultdict(set) # node -> set(short_channel_id) def get_channel_info(self, channel_id): return self._id_to_channel_info.get(channel_id, None) def get_channels_for_node(self, node_id): """Returns the set of channels that have node_id as one of the endpoints.""" return self._channels_for_node[node_id] def on_channel_announcement(self, msg_payload): short_channel_id = msg_payload['short_channel_id'] self.print_error('channel announcement', binascii.hexlify(short_channel_id).decode("ascii")) channel_info = ChannelInfo(msg_payload) self._id_to_channel_info[short_channel_id] = channel_info self._channels_for_node[channel_info.node_id_1].add(short_channel_id) self._channels_for_node[channel_info.node_id_2].add(short_channel_id) def on_channel_update(self, msg_payload): short_channel_id = msg_payload['short_channel_id'] try: channel_info = self._id_to_channel_info[short_channel_id] except KeyError: self.print_error("could not find", short_channel_id) else: channel_info.on_channel_update(msg_payload) def remove_channel(self, short_channel_id): try: channel_info = self._id_to_channel_info[short_channel_id] except KeyError: self.print_error('cannot find channel {}'.format(short_channel_id)) return self._id_to_channel_info.pop(short_channel_id, None) for node in (channel_info.node_id_1, channel_info.node_id_2): try: self._channels_for_node[node].remove(short_channel_id) except KeyError: pass class RouteEdge: def __init__(self, node_id: bytes, short_channel_id: bytes, channel_policy: ChannelInfoDirectedPolicy): # "if you travel through short_channel_id, you will reach node_id" self.node_id = node_id self.short_channel_id = short_channel_id self.channel_policy = channel_policy class LNPathFinder(PrintError): def __init__(self, channel_db): self.channel_db = channel_db self.blacklist = set() def _edge_cost(self, short_channel_id: bytes, start_node: bytes, payment_amt_msat: int, ignore_cltv=False) -> float: """Heuristic cost of going through a channel. direction: 0 or 1. --- 0 means node_id_1 -> node_id_2 """ channel_info = self.channel_db.get_channel_info(short_channel_id) if channel_info is None: return float('inf') channel_policy = channel_info.get_policy_for_node(start_node) if channel_policy is None: return float('inf') cltv_expiry_delta = channel_policy.cltv_expiry_delta htlc_minimum_msat = channel_policy.htlc_minimum_msat fee_base_msat = channel_policy.fee_base_msat fee_proportional_millionths = channel_policy.fee_proportional_millionths if payment_amt_msat is not None: if payment_amt_msat < htlc_minimum_msat: return float('inf') # payment amount too little if channel_info.capacity_sat is not None and \ payment_amt_msat // 1000 > channel_info.capacity_sat: return float('inf') # payment amount too large amt = payment_amt_msat or 50000 * 1000 # guess for typical payment amount fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000 # TODO revise # paying 10 more satoshis ~ waiting one more block fee_cost = fee_msat / 1000 / 10 cltv_cost = cltv_expiry_delta if not ignore_cltv else 0 return cltv_cost + fee_cost + 1 @profiler def find_path_for_payment(self, from_node_id: bytes, to_node_id: bytes, amount_msat: int=None) -> Sequence[Tuple[bytes, bytes]]: """Return a path between from_node_id and to_node_id. Returns a list of (node_id, short_channel_id) representing a path. To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; i.e. an element reads as, "to get to node_id, travel through short_channel_id" """ if amount_msat is not None: assert type(amount_msat) is int # TODO find multiple paths?? # run Dijkstra distance_from_start = defaultdict(lambda: float('inf')) distance_from_start[from_node_id] = 0 prev_node = {} nodes_to_explore = queue.PriorityQueue() nodes_to_explore.put((0, from_node_id)) while nodes_to_explore.qsize() > 0: dist_to_cur_node, cur_node = nodes_to_explore.get() if cur_node == to_node_id: break if dist_to_cur_node != distance_from_start[cur_node]: # queue.PriorityQueue does not implement decrease_priority, # so instead of decreasing priorities, we add items again into the queue. # so there are duplicates in the queue, that we discard now: continue for edge_channel_id in self.channel_db.get_channels_for_node(cur_node): if edge_channel_id in self.blacklist: continue channel_info = self.channel_db.get_channel_info(edge_channel_id) node1, node2 = channel_info.node_id_1, channel_info.node_id_2 neighbour = node2 if node1 == cur_node else node1 ignore_cltv_delta_in_edge_cost = cur_node == from_node_id edge_cost = self._edge_cost(edge_channel_id, cur_node, amount_msat, ignore_cltv=ignore_cltv_delta_in_edge_cost) alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost if alt_dist_to_neighbour < distance_from_start[neighbour]: distance_from_start[neighbour] = alt_dist_to_neighbour prev_node[neighbour] = cur_node, edge_channel_id nodes_to_explore.put((alt_dist_to_neighbour, neighbour)) else: return None # no path found # backtrack from end to start cur_node = to_node_id path = [] while cur_node != from_node_id: prev_node_id, edge_taken = prev_node[cur_node] path += [(cur_node, edge_taken)] cur_node = prev_node_id path.reverse() return path def create_route_from_path(self, path, from_node_id: bytes) -> Sequence[RouteEdge]: assert type(from_node_id) is bytes if path is None: raise Exception('cannot create route from None path') route = [] prev_node_id = from_node_id for node_id, short_channel_id in path: channel_info = self.channel_db.get_channel_info(short_channel_id) if channel_info is None: raise Exception('cannot find channel info for short_channel_id: {}'.format(bh2u(short_channel_id))) channel_policy = channel_info.get_policy_for_node(prev_node_id) if channel_policy is None: raise Exception('cannot find channel policy for short_channel_id: {}'.format(bh2u(short_channel_id))) route.append(RouteEdge(node_id, short_channel_id, channel_policy)) prev_node_id = node_id return route # bolt 04, "onion" -----> NUM_MAX_HOPS_IN_PATH = 20 HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04 PER_HOP_FULL_SIZE = 65 # HOPS_DATA_SIZE / 20 NUM_STREAM_BYTES = HOPS_DATA_SIZE + PER_HOP_FULL_SIZE PER_HOP_HMAC_SIZE = 32 class UnsupportedOnionPacketVersion(Exception): pass class InvalidOnionMac(Exception): pass class OnionPerHop: def __init__(self, short_channel_id: bytes, amt_to_forward: bytes, outgoing_cltv_value: bytes): self.short_channel_id = short_channel_id self.amt_to_forward = amt_to_forward self.outgoing_cltv_value = outgoing_cltv_value def to_bytes(self) -> bytes: ret = self.short_channel_id ret += self.amt_to_forward ret += self.outgoing_cltv_value ret += bytes(12) # padding if len(ret) != 32: raise Exception('unexpected length {}'.format(len(ret))) return ret @classmethod def from_bytes(cls, b: bytes): if len(b) != 32: raise Exception('unexpected length {}'.format(len(b))) return OnionPerHop( short_channel_id=b[:8], amt_to_forward=b[8:16], outgoing_cltv_value=b[16:20] ) class OnionHopsDataSingle: # called HopData in lnd def __init__(self, per_hop: OnionPerHop = None): self.realm = 0 self.per_hop = per_hop self.hmac = None def to_bytes(self) -> bytes: ret = bytes([self.realm]) ret += self.per_hop.to_bytes() ret += self.hmac if self.hmac is not None else bytes(PER_HOP_HMAC_SIZE) if len(ret) != PER_HOP_FULL_SIZE: raise Exception('unexpected length {}'.format(len(ret))) return ret @classmethod def from_bytes(cls, b: bytes): if len(b) != PER_HOP_FULL_SIZE: raise Exception('unexpected length {}'.format(len(b))) ret = OnionHopsDataSingle() ret.realm = b[0] if ret.realm != 0: raise Exception('only realm 0 is supported') ret.per_hop = OnionPerHop.from_bytes(b[1:33]) ret.hmac = b[33:] return ret class OnionPacket: def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes): self.version = 0 self.public_key = public_key self.hops_data = hops_data # also called RoutingInfo in bolt-04 self.hmac = hmac def to_bytes(self) -> bytes: ret = bytes([self.version]) ret += self.public_key ret += self.hops_data ret += self.hmac if len(ret) != 1366: raise Exception('unexpected length {}'.format(len(ret))) return ret @classmethod def from_bytes(cls, b: bytes): if len(b) != 1366: raise Exception('unexpected length {}'.format(len(b))) version = b[0] if version != 0: raise UnsupportedOnionPacketVersion('version {} is not supported'.format(version)) return OnionPacket( public_key=b[1:34], hops_data=b[34:1334], hmac=b[1334:] ) def get_bolt04_onion_key(key_type: bytes, secret: bytes) -> bytes: if key_type not in (b'rho', b'mu', b'um', b'ammag'): raise Exception('invalid key_type {}'.format(key_type)) key = hmac.new(key_type, msg=secret, digestmod=hashlib.sha256).digest() return key def get_shared_secrets_along_route(payment_path_pubkeys: Sequence[bytes], session_key: bytes) -> Sequence[bytes]: num_hops = len(payment_path_pubkeys) hop_shared_secrets = num_hops * [b''] ephemeral_key = session_key # compute shared key for each hop for i in range(0, num_hops): hop_shared_secrets[i] = lnbase.get_ecdh(ephemeral_key, payment_path_pubkeys[i]) ephemeral_pubkey = ecc.ECPrivkey(ephemeral_key).get_public_key_bytes() blinding_factor = sha256(ephemeral_pubkey + hop_shared_secrets[i]) blinding_factor_int = int.from_bytes(blinding_factor, byteorder="big") ephemeral_key_int = int.from_bytes(ephemeral_key, byteorder="big") ephemeral_key_int = ephemeral_key_int * blinding_factor_int % ecc.CURVE_ORDER ephemeral_key = ephemeral_key_int.to_bytes(32, byteorder="big") return hop_shared_secrets def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes, hops_data: Sequence[OnionHopsDataSingle], associated_data: bytes) -> OnionPacket: num_hops = len(payment_path_pubkeys) hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key) filler = generate_filler(b'rho', num_hops, PER_HOP_FULL_SIZE, hop_shared_secrets) mix_header = bytes(HOPS_DATA_SIZE) next_hmac = bytes(PER_HOP_HMAC_SIZE) # compute routing info and MAC for each hop for i in range(num_hops-1, -1, -1): rho_key = get_bolt04_onion_key(b'rho', hop_shared_secrets[i]) mu_key = get_bolt04_onion_key(b'mu', hop_shared_secrets[i]) hops_data[i].hmac = next_hmac stream_bytes = generate_cipher_stream(rho_key, NUM_STREAM_BYTES) mix_header = mix_header[:-PER_HOP_FULL_SIZE] mix_header = hops_data[i].to_bytes() + mix_header mix_header = xor_bytes(mix_header, stream_bytes) if i == num_hops - 1 and len(filler) != 0: mix_header = mix_header[:-len(filler)] + filler packet = mix_header + associated_data next_hmac = hmac.new(mu_key, msg=packet, digestmod=hashlib.sha256).digest() return OnionPacket( public_key=ecc.ECPrivkey(session_key).get_public_key_bytes(), hops_data=mix_header, hmac=next_hmac) def generate_filler(key_type: bytes, num_hops: int, hop_size: int, shared_secrets: Sequence[bytes]) -> bytes: filler_size = (NUM_MAX_HOPS_IN_PATH + 1) * hop_size filler = bytearray(filler_size) for i in range(0, num_hops-1): # -1, as last hop does not obfuscate filler = filler[hop_size:] filler += bytearray(hop_size) stream_key = get_bolt04_onion_key(key_type, shared_secrets[i]) stream_bytes = generate_cipher_stream(stream_key, filler_size) filler = xor_bytes(filler, stream_bytes) return filler[(NUM_MAX_HOPS_IN_PATH-num_hops+2)*hop_size:] def generate_cipher_stream(stream_key: bytes, num_bytes: int) -> bytes: algo = algorithms.ChaCha20(stream_key, nonce=bytes(16)) cipher = Cipher(algo, mode=None, backend=default_backend()) encryptor = cipher.encryptor() return encryptor.update(bytes(num_bytes)) ProcessedOnionPacket = namedtuple("ProcessedOnionPacket", ["are_we_final", "hop_data", "next_packet"]) # TODO replay protection def process_onion_packet(onion_packet: OnionPacket, associated_data: bytes, our_onion_private_key: bytes) -> ProcessedOnionPacket: shared_secret = lnbase.get_ecdh(our_onion_private_key, onion_packet.public_key) # check message integrity mu_key = get_bolt04_onion_key(b'mu', shared_secret) calculated_mac = hmac.new(mu_key, msg=onion_packet.hops_data+associated_data, digestmod=hashlib.sha256).digest() if onion_packet.hmac != calculated_mac: raise InvalidOnionMac() # peel an onion layer off rho_key = get_bolt04_onion_key(b'rho', shared_secret) stream_bytes = generate_cipher_stream(rho_key, NUM_STREAM_BYTES) padded_header = onion_packet.hops_data + bytes(PER_HOP_FULL_SIZE) next_hops_data = xor_bytes(padded_header, stream_bytes) # calc next ephemeral key blinding_factor = sha256(onion_packet.public_key + shared_secret) blinding_factor_int = int.from_bytes(blinding_factor, byteorder="big") next_public_key_int = ecc.ECPubkey(onion_packet.public_key) * blinding_factor_int next_public_key = next_public_key_int.get_public_key_bytes() hop_data = OnionHopsDataSingle.from_bytes(next_hops_data[:PER_HOP_FULL_SIZE]) next_onion_packet = OnionPacket( public_key=next_public_key, hops_data=next_hops_data[PER_HOP_FULL_SIZE:], hmac=hop_data.hmac ) if hop_data.hmac == bytes(PER_HOP_HMAC_SIZE): # we are the destination / exit node are_we_final = True else: # we are an intermediate node; forwarding are_we_final = False return ProcessedOnionPacket(are_we_final, hop_data, next_onion_packet) class FailedToDecodeOnionError(Exception): pass class OnionRoutingFailureMessage: def __init__(self, code: int, data: bytes): self.code = code self.data = data def __repr__(self): return repr((self.code, self.data)) def _decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[bytes], session_key: bytes) -> (bytes, int): """Returns the decoded error bytes, and the index of the sender of the error.""" num_hops = len(payment_path_pubkeys) hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key) for i in range(num_hops): ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i]) um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i]) stream_bytes = generate_cipher_stream(ammag_key, len(error_packet)) error_packet = xor_bytes(error_packet, stream_bytes) hmac_computed = hmac.new(um_key, msg=error_packet[32:], digestmod=hashlib.sha256).digest() hmac_found = error_packet[:32] if hmac_computed == hmac_found: return error_packet, i raise FailedToDecodeOnionError() def decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[bytes], session_key: bytes) -> (OnionRoutingFailureMessage, int): """Returns the failure message, and the index of the sender of the error.""" decrypted_error, sender_index = _decode_onion_error(error_packet, payment_path_pubkeys, session_key) failure_msg = get_failure_msg_from_onion_error(decrypted_error) return failure_msg, sender_index def get_failure_msg_from_onion_error(decrypted_error_packet: bytes) -> OnionRoutingFailureMessage: # get failure_msg bytes from error packet failure_len = int.from_bytes(decrypted_error_packet[32:34], byteorder='big') failure_msg = decrypted_error_packet[34:34+failure_len] # create failure message object failure_code = int.from_bytes(failure_msg[:2], byteorder='big') failure_data = failure_msg[2:] return OnionRoutingFailureMessage(failure_code, failure_data) # <----- bolt 04, "onion"