Browse Source

protect against getting robbed through routing fees

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
SomberNight 6 years ago
committed by ThomasV
parent
commit
2fafd01945
  1. 10
      electrum/lnonion.py
  2. 91
      electrum/lnrouter.py
  3. 9
      electrum/lnutil.py
  4. 18
      electrum/lnworker.py

10
electrum/lnonion.py

@ -33,11 +33,10 @@ from cryptography.hazmat.backends import default_backend
from . import ecc from . import ecc
from .crypto import sha256, hmac_oneshot from .crypto import sha256, hmac_oneshot
from .util import bh2u, profiler, xor_bytes, bfh from .util import bh2u, profiler, xor_bytes, bfh
from .lnutil import get_ecdh from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH
from .lnrouter import RouteEdge from .lnrouter import RouteEdge
NUM_MAX_HOPS_IN_PATH = 20
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04 HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
PER_HOP_FULL_SIZE = 65 # HOPS_DATA_SIZE / 20 PER_HOP_FULL_SIZE = 65 # HOPS_DATA_SIZE / 20
NUM_STREAM_BYTES = HOPS_DATA_SIZE + PER_HOP_FULL_SIZE NUM_STREAM_BYTES = HOPS_DATA_SIZE + PER_HOP_FULL_SIZE
@ -192,6 +191,9 @@ def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_c
"""Returns the hops_data to be used for constructing an onion packet, """Returns the hops_data to be used for constructing an onion packet,
and the amount_msat and cltv to be used on our immediate channel. and the amount_msat and cltv to be used on our immediate channel.
""" """
if len(route) > NUM_MAX_HOPS_IN_PAYMENT_PATH:
raise PaymentFailure(f"too long route ({len(route)} hops)")
amt = amount_msat amt = amount_msat
cltv = final_cltv cltv = final_cltv
hops_data = [OnionHopsDataSingle(OnionPerHop(b"\x00" * 8, hops_data = [OnionHopsDataSingle(OnionPerHop(b"\x00" * 8,
@ -209,7 +211,7 @@ def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_c
def generate_filler(key_type: bytes, num_hops: int, hop_size: int, def generate_filler(key_type: bytes, num_hops: int, hop_size: int,
shared_secrets: Sequence[bytes]) -> bytes: shared_secrets: Sequence[bytes]) -> bytes:
filler_size = (NUM_MAX_HOPS_IN_PATH + 1) * hop_size filler_size = (NUM_MAX_HOPS_IN_PAYMENT_PATH + 1) * hop_size
filler = bytearray(filler_size) filler = bytearray(filler_size)
for i in range(0, num_hops-1): # -1, as last hop does not obfuscate for i in range(0, num_hops-1): # -1, as last hop does not obfuscate
@ -219,7 +221,7 @@ def generate_filler(key_type: bytes, num_hops: int, hop_size: int,
stream_bytes = generate_cipher_stream(stream_key, filler_size) stream_bytes = generate_cipher_stream(stream_key, filler_size)
filler = xor_bytes(filler, stream_bytes) filler = xor_bytes(filler, stream_bytes)
return filler[(NUM_MAX_HOPS_IN_PATH-num_hops+2)*hop_size:] return filler[(NUM_MAX_HOPS_IN_PAYMENT_PATH-num_hops+2)*hop_size:]
def generate_cipher_stream(stream_key: bytes, num_bytes: int) -> bytes: def generate_cipher_stream(stream_key: bytes, num_bytes: int) -> bytes:

91
electrum/lnrouter.py

@ -39,7 +39,7 @@ from .storage import JsonDB
from .lnchannelverifier import LNChannelVerifier, verify_sig_for_channel_update from .lnchannelverifier import LNChannelVerifier, verify_sig_for_channel_update
from .crypto import Hash from .crypto import Hash
from . import ecc from . import ecc
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_HOPS_IN_PAYMENT_PATH
class UnknownEvenFeatureBits(Exception): pass class UnknownEvenFeatureBits(Exception): pass
@ -502,10 +502,61 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
('cltv_expiry_delta', int)])): ('cltv_expiry_delta', int)])):
"""if you travel through short_channel_id, you will reach node_id""" """if you travel through short_channel_id, you will reach node_id"""
def fee_for_edge(self, amount_msat): def fee_for_edge(self, amount_msat: int) -> int:
return self.fee_base_msat \ return self.fee_base_msat \
+ (amount_msat * self.fee_proportional_millionths // 1_000_000) + (amount_msat * self.fee_proportional_millionths // 1_000_000)
@classmethod
def from_channel_policy(cls, channel_policy: ChannelInfoDirectedPolicy,
short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
return RouteEdge(end_node,
short_channel_id,
channel_policy.fee_base_msat,
channel_policy.fee_proportional_millionths,
channel_policy.cltv_expiry_delta)
def is_sane_to_use(self, amount_msat: int) -> bool:
# TODO revise ad-hoc heuristics
# cltv cannot be more than 2 weeks
if self.cltv_expiry_delta > 14 * 144: return False
total_fee = self.fee_for_edge(amount_msat)
# fees below 50 sat are fine
if total_fee > 50_000:
# fee cannot be higher than amt
if total_fee > amount_msat: return False
# fee cannot be higher than 5000 sat
if total_fee > 5_000_000: return False
# unless amt is tiny, fee cannot be more than 10%
if amount_msat > 1_000_000 and total_fee > amount_msat/10: return False
return True
def is_route_sane_to_use(route: List[RouteEdge], invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
"""Run some sanity checks on the whole route, before attempting to use it.
called when we are paying; so e.g. lower cltv is better
"""
if len(route) > NUM_MAX_HOPS_IN_PAYMENT_PATH:
return False
amt = invoice_amount_msat
cltv = min_final_cltv_expiry
for route_edge in reversed(route[1:]):
if not route_edge.is_sane_to_use(amt): return False
amt += route_edge.fee_for_edge(amt)
cltv += route_edge.cltv_expiry_delta
total_fee = amt - invoice_amount_msat
# TODO revise ad-hoc heuristics
# cltv cannot be more than 2 months
if cltv > 60 * 144: return False
# fees below 50 sat are fine
if total_fee > 50_000:
# fee cannot be higher than amt
if total_fee > invoice_amount_msat: return False
# fee cannot be higher than 5000 sat
if total_fee > 5_000_000: return False
# unless amt is tiny, fee cannot be more than 10%
if invoice_amount_msat > 1_000_000 and total_fee > invoice_amount_msat/10: return False
return True
class LNPathFinder(PrintError): class LNPathFinder(PrintError):
@ -513,11 +564,9 @@ class LNPathFinder(PrintError):
self.channel_db = channel_db self.channel_db = channel_db
self.blacklist = set() self.blacklist = set()
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, payment_amt_msat: int, def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
ignore_cltv=False) -> float: payment_amt_msat: int, ignore_cltv=False) -> float:
"""Heuristic cost of going through a channel. """Heuristic cost of going through a channel."""
direction: 0 or 1. --- 0 means node_id_1 -> node_id_2
"""
channel_info = self.channel_db.get_channel_info(short_channel_id) # type: ChannelInfo channel_info = self.channel_db.get_channel_info(short_channel_id) # type: ChannelInfo
if channel_info is None: if channel_info is None:
return float('inf') return float('inf')
@ -525,12 +574,8 @@ class LNPathFinder(PrintError):
channel_policy = channel_info.get_policy_for_node(start_node) channel_policy = channel_info.get_policy_for_node(start_node)
if channel_policy is None: return float('inf') if channel_policy is None: return float('inf')
if channel_policy.disabled: return float('inf') if channel_policy.disabled: return float('inf')
cltv_expiry_delta = channel_policy.cltv_expiry_delta route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
htlc_minimum_msat = channel_policy.htlc_minimum_msat if payment_amt_msat < channel_policy.htlc_minimum_msat:
fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths
if payment_amt_msat is not None:
if payment_amt_msat < htlc_minimum_msat:
return float('inf') # payment amount too little return float('inf') # payment amount too little
if channel_info.capacity_sat is not None and \ if channel_info.capacity_sat is not None and \
payment_amt_msat // 1000 > channel_info.capacity_sat: payment_amt_msat // 1000 > channel_info.capacity_sat:
@ -538,28 +583,30 @@ class LNPathFinder(PrintError):
if channel_policy.htlc_maximum_msat is not None and \ if channel_policy.htlc_maximum_msat is not None and \
payment_amt_msat > channel_policy.htlc_maximum_msat: payment_amt_msat > channel_policy.htlc_maximum_msat:
return float('inf') # payment amount too large return float('inf') # payment amount too large
amt = payment_amt_msat or 50000 * 1000 # guess for typical payment amount if not route_edge.is_sane_to_use(payment_amt_msat):
fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1_000_000 return float('inf') # thanks but no thanks
fee_msat = route_edge.fee_for_edge(payment_amt_msat)
# TODO revise # TODO revise
# paying 10 more satoshis ~ waiting one more block # paying 10 more satoshis ~ waiting one more block
fee_cost = fee_msat / 1000 / 10 fee_cost = fee_msat / 1000 / 10
cltv_cost = cltv_expiry_delta if not ignore_cltv else 0 cltv_cost = route_edge.cltv_expiry_delta if not ignore_cltv else 0
return cltv_cost + fee_cost + 1 return cltv_cost + fee_cost + 1
@profiler @profiler
def find_path_for_payment(self, from_node_id: bytes, to_node_id: bytes, def find_path_for_payment(self, from_node_id: bytes, to_node_id: bytes,
amount_msat: int=None, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]: amount_msat: int, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
"""Return a path between from_node_id and to_node_id. """Return a path between from_node_id and to_node_id.
Returns a list of (node_id, short_channel_id) representing a path. Returns a list of (node_id, short_channel_id) representing a path.
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
i.e. an element reads as, "to get to node_id, travel through short_channel_id" i.e. an element reads as, "to get to node_id, travel through short_channel_id"
""" """
if amount_msat is not None: assert type(amount_msat) is int assert type(amount_msat) is int
if my_channels is None: my_channels = [] if my_channels is None: my_channels = []
unable_channels = set(map(lambda x: x.short_channel_id, filter(lambda x: not x.can_pay(amount_msat), my_channels))) unable_channels = set(map(lambda x: x.short_channel_id, filter(lambda x: not x.can_pay(amount_msat), my_channels)))
# TODO find multiple paths?? # TODO find multiple paths??
# FIXME paths cannot be longer than 20 (onion packet)...
# run Dijkstra # run Dijkstra
distance_from_start = defaultdict(lambda: float('inf')) distance_from_start = defaultdict(lambda: float('inf'))
@ -584,7 +631,7 @@ class LNPathFinder(PrintError):
node1, node2 = channel_info.node_id_1, channel_info.node_id_2 node1, node2 = channel_info.node_id_1, channel_info.node_id_2
neighbour = node2 if node1 == cur_node else node1 neighbour = node2 if node1 == cur_node else node1
ignore_cltv_delta_in_edge_cost = cur_node == from_node_id ignore_cltv_delta_in_edge_cost = cur_node == from_node_id
edge_cost = self._edge_cost(edge_channel_id, cur_node, amount_msat, edge_cost = self._edge_cost(edge_channel_id, cur_node, neighbour, amount_msat,
ignore_cltv=ignore_cltv_delta_in_edge_cost) ignore_cltv=ignore_cltv_delta_in_edge_cost)
alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost
if alt_dist_to_neighbour < distance_from_start[neighbour]: if alt_dist_to_neighbour < distance_from_start[neighbour]:
@ -614,10 +661,6 @@ class LNPathFinder(PrintError):
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
if channel_policy is None: if channel_policy is None:
raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}') raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
route.append(RouteEdge(node_id, route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
short_channel_id,
channel_policy.fee_base_msat,
channel_policy.fee_proportional_millionths,
channel_policy.cltv_expiry_delta))
prev_node_id = node_id prev_node_id = node_id
return route return route

9
electrum/lnutil.py

@ -1,7 +1,7 @@
from enum import IntFlag, IntEnum from enum import IntFlag, IntEnum
import json import json
from collections import namedtuple from collections import namedtuple
from typing import NamedTuple, List, Tuple, Mapping from typing import NamedTuple, List, Tuple, Mapping, Optional
import re import re
from .util import bfh, bh2u, inv_dict from .util import bfh, bh2u, inv_dict
@ -16,6 +16,7 @@ from .i18n import _
from .lnaddr import lndecode from .lnaddr import lndecode
from .keystore import BIP32_KeyStore from .keystore import BIP32_KeyStore
HTLC_TIMEOUT_WEIGHT = 663 HTLC_TIMEOUT_WEIGHT = 663
HTLC_SUCCESS_WEIGHT = 703 HTLC_SUCCESS_WEIGHT = 703
@ -597,8 +598,6 @@ def generate_keypair(ln_keystore: BIP32_KeyStore, key_family: LnKeyFamily, index
return Keypair(*ln_keystore.get_keypair([key_family, 0, index], None)) return Keypair(*ln_keystore.get_keypair([key_family, 0, index], None))
from typing import Optional
class EncumberedTransaction(NamedTuple("EncumberedTransaction", [('tx', Transaction), class EncumberedTransaction(NamedTuple("EncumberedTransaction", [('tx', Transaction),
('csv_delay', Optional[int])])): ('csv_delay', Optional[int])])):
def to_json(self) -> dict: def to_json(self) -> dict:
@ -612,3 +611,7 @@ class EncumberedTransaction(NamedTuple("EncumberedTransaction", [('tx', Transact
d2 = dict(d) d2 = dict(d)
d2['tx'] = Transaction(d['tx']) d2['tx'] = Transaction(d['tx'])
return EncumberedTransaction(**d2) return EncumberedTransaction(**d2)
NUM_MAX_HOPS_IN_PAYMENT_PATH = 20

18
electrum/lnworker.py

@ -25,10 +25,11 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
get_compressed_pubkey_from_bech32, extract_nodeid, get_compressed_pubkey_from_bech32, extract_nodeid,
PaymentFailure, split_host_port, ConnStringFormatError, PaymentFailure, split_host_port, ConnStringFormatError,
generate_keypair, LnKeyFamily, LOCAL, REMOTE, generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE) UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
NUM_MAX_HOPS_IN_PAYMENT_PATH)
from .lnaddr import lndecode from .lnaddr import lndecode
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge from .lnrouter import RouteEdge, is_route_sane_to_use
NUM_PEERS_TARGET = 4 NUM_PEERS_TARGET = 4
PEER_RETRY_INTERVAL = 600 # seconds PEER_RETRY_INTERVAL = 600 # seconds
@ -253,6 +254,10 @@ class LNWorker(PrintError):
if amount_sat is None: if amount_sat is None:
raise InvoiceError(_("Missing amount")) raise InvoiceError(_("Missing amount"))
amount_msat = int(amount_sat * 1000) amount_msat = int(amount_sat * 1000)
if addr.get_min_final_cltv_expiry() > 60 * 144:
raise InvoiceError("{}\n{}".format(
_("Invoice wants us to risk locking funds for unreasonably long."),
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat) route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
node_id, short_channel_id = route[0].node_id, route[0].short_channel_id node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
peer = self.peers[node_id] peer = self.peers[node_id]
@ -281,6 +286,7 @@ class LNWorker(PrintError):
channels = list(self.channels.values()) channels = list(self.channels.values())
for private_route in r_tags: for private_route in r_tags:
if len(private_route) == 0: continue if len(private_route) == 0: continue
if len(private_route) > NUM_MAX_HOPS_IN_PAYMENT_PATH: continue
border_node_pubkey = private_route[0][0] border_node_pubkey = private_route[0][0]
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels) path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
if not path: continue if not path: continue
@ -301,6 +307,11 @@ class LNWorker(PrintError):
route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths, route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths,
cltv_expiry_delta)) cltv_expiry_delta))
prev_node_id = node_pubkey prev_node_id = node_pubkey
# test sanity
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
self.print_error(f"rejecting insane route {route}")
route = None
continue
break break
# if could not find route using any hint; try without hint now # if could not find route using any hint; try without hint now
if route is None: if route is None:
@ -308,6 +319,9 @@ class LNWorker(PrintError):
if not path: if not path:
raise PaymentFailure(_("No path found")) raise PaymentFailure(_("No path found"))
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
self.print_error(f"rejecting insane route {route}")
raise PaymentFailure(_("No path found"))
return route return route
def add_invoice(self, amount_sat, message): def add_invoice(self, amount_sat, message):

Loading…
Cancel
Save