Browse Source

lnworker/lnpeer: add some type hints, force some kwargs

patch-4
SomberNight 4 years ago
parent
commit
691ebaf4f8
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 9
      electrum/lnonion.py
  2. 59
      electrum/lnpeer.py
  3. 5
      electrum/lnrater.py
  4. 150
      electrum/lnworker.py
  5. 8
      electrum/tests/test_lnpeer.py

9
electrum/lnonion.py

@ -437,9 +437,12 @@ class OnionRoutingFailure(Exception):
return str(self.code.name)
return f"Unknown error ({self.code!r})"
def construct_onion_error(reason: OnionRoutingFailure,
onion_packet: OnionPacket,
our_onion_private_key: bytes) -> bytes:
def construct_onion_error(
reason: OnionRoutingFailure,
onion_packet: OnionPacket,
our_onion_private_key: bytes,
) -> bytes:
# create payload
failure_msg = reason.to_bytes()
failure_len = len(failure_msg)

59
electrum/lnpeer.py

@ -1373,9 +1373,12 @@ class Peer(Logger):
chan.receive_htlc(htlc, onion_packet)
util.trigger_callback('htlc_added', chan, htlc, RECEIVED)
def maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket
) -> Tuple[Optional[bytes], Optional[int], Optional[OnionRoutingFailure]]:
def maybe_forward_htlc(
self,
*,
htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket,
) -> Tuple[bytes, int]:
# Forward HTLC
# FIXME: there are critical safety checks MISSING here
forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
@ -1662,7 +1665,7 @@ class Peer(Logger):
self.shutdown_received[chan_id] = asyncio.Future()
await self.send_shutdown(chan)
payload = await self.shutdown_received[chan_id]
txid = await self._shutdown(chan, payload, True)
txid = await self._shutdown(chan, payload, is_local=True)
self.logger.info(f'({chan.get_id_for_log()}) Channel closed {txid}')
return txid
@ -1686,10 +1689,10 @@ class Peer(Logger):
else:
chan = self.channels[chan_id]
await self.send_shutdown(chan)
txid = await self._shutdown(chan, payload, False)
txid = await self._shutdown(chan, payload, is_local=False)
self.logger.info(f'({chan.get_id_for_log()}) Channel closed by remote peer {txid}')
def can_send_shutdown(self, chan):
def can_send_shutdown(self, chan: Channel):
if chan.get_state() >= ChannelState.OPENING:
return True
if chan.constraints.is_initiator and chan.channel_id in self.funding_created_sent:
@ -1718,7 +1721,7 @@ class Peer(Logger):
chan.set_can_send_ctx_updates(True)
@log_exceptions
async def _shutdown(self, chan: Channel, payload, is_local):
async def _shutdown(self, chan: Channel, payload, *, is_local: bool):
# wait until no HTLCs remain in either commitment transaction
while len(chan.hm.htlcs(LOCAL)) + len(chan.hm.htlcs(REMOTE)) > 0:
self.logger.info(f'(chan: {chan.short_channel_id}) waiting for htlcs to settle...')
@ -1826,7 +1829,12 @@ class Peer(Logger):
error_reason = e
else:
try:
preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet)
preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(
chan=chan,
htlc=htlc,
forwarding_info=forwarding_info,
onion_packet_bytes=onion_packet_bytes,
onion_packet=onion_packet)
except OnionRoutingFailure as e:
error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey)
if fw_info:
@ -1850,13 +1858,24 @@ class Peer(Logger):
for htlc_id in done:
unfulfilled.pop(htlc_id)
def process_unfulfilled_htlc(self, chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet):
def process_unfulfilled_htlc(
self,
*,
chan: Channel,
htlc: UpdateAddHtlc,
forwarding_info: Tuple[str, int],
onion_packet_bytes: bytes,
onion_packet: OnionPacket,
) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]:
"""
returns either preimage or fw_info or error_bytes or (None, None, None)
raise an OnionRoutingFailure if we need to fail the htlc
"""
payment_hash = htlc.payment_hash
processed_onion = self.process_onion_packet(onion_packet, payment_hash, onion_packet_bytes)
processed_onion = self.process_onion_packet(
onion_packet,
payment_hash=payment_hash,
onion_packet_bytes=onion_packet_bytes)
if processed_onion.are_we_final:
preimage = self.maybe_fulfill_htlc(
chan=chan,
@ -1867,8 +1886,8 @@ class Peer(Logger):
if not forwarding_info:
trampoline_onion = self.process_onion_packet(
processed_onion.trampoline_onion_packet,
htlc.payment_hash,
onion_packet_bytes,
payment_hash=htlc.payment_hash,
onion_packet_bytes=onion_packet_bytes,
is_trampoline=True)
if trampoline_onion.are_we_final:
preimage = self.maybe_fulfill_htlc(
@ -1892,13 +1911,10 @@ class Peer(Logger):
elif not forwarding_info:
next_chan_id, next_htlc_id = self.maybe_forward_htlc(
chan=chan,
htlc=htlc,
onion_packet=onion_packet,
processed_onion=processed_onion)
if next_chan_id:
fw_info = (next_chan_id.hex(), next_htlc_id)
return None, fw_info, None
fw_info = (next_chan_id.hex(), next_htlc_id)
return None, fw_info, None
else:
preimage = self.lnworker.get_preimage(payment_hash)
next_chan_id_hex, htlc_id = forwarding_info
@ -1913,7 +1929,14 @@ class Peer(Logger):
return preimage, None, None
return None, None, None
def process_onion_packet(self, onion_packet, payment_hash, onion_packet_bytes, is_trampoline=False):
def process_onion_packet(
self,
onion_packet: OnionPacket,
*,
payment_hash: bytes,
onion_packet_bytes: bytes,
is_trampoline: bool = False,
) -> ProcessedOnionPacket:
failure_data = sha256(onion_packet_bytes)
try:
processed_onion = process_onion_packet(

5
electrum/lnrater.py

@ -268,7 +268,10 @@ class LNRater(Logger):
return pk, self._node_stats[pk]
def suggest_peer(self):
def suggest_peer(self) -> Optional[bytes]:
"""Suggests a LN node to open a channel with.
Returns a node ID (pubkey).
"""
self.maybe_analyze_graph()
if self._node_ratings:
return self.suggest_node_channel_open()[0]

150
electrum/lnworker.py

@ -7,7 +7,8 @@ import os
from decimal import Decimal
import random
import time
from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
NamedTuple, Union, Mapping, Any, Iterable)
import threading
import socket
import aiohttp
@ -266,10 +267,10 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
with self.lock:
return self._peers.copy()
def channels_for_peer(self, node_id):
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
return {}
def get_node_alias(self, node_id):
def get_node_alias(self, node_id: bytes) -> str:
if self.channel_db:
node_info = self.channel_db.get_node_info_for_node_id(node_id)
node_alias = (node_info.alias if node_info else '') or node_id.hex()
@ -380,7 +381,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self._add_peer(host, int(port), bfh(pubkey)),
self.network.asyncio_loop)
def is_good_peer(self, peer):
def is_good_peer(self, peer: LNPeerAddr) -> bool:
# the purpose of this method is to filter peers that advertise the desired feature bits
# it is disabled for now, because feature bits published in node announcements seem to be unreliable
return True
@ -566,7 +567,7 @@ class LNGossip(LNWorker):
self.channel_db.prune_orphaned_channels()
await asyncio.sleep(120)
async def add_new_ids(self, ids):
async def add_new_ids(self, ids: Iterable[bytes]):
known = self.channel_db.get_channel_ids()
new = set(ids) - set(known)
self.unknown_ids.update(new)
@ -574,7 +575,7 @@ class LNGossip(LNWorker):
util.trigger_callback('gossip_peers', self.num_peers())
util.trigger_callback('ln_gossip_sync_progress')
def get_ids_to_query(self):
def get_ids_to_query(self) -> Sequence[bytes]:
N = 500
l = list(self.unknown_ids)
self.unknown_ids = set(l[N:])
@ -910,7 +911,7 @@ class LNWallet(LNWorker):
if chan.funding_outpoint.to_str() == txo:
return chan
async def on_channel_update(self, chan):
async def on_channel_update(self, chan: Channel):
if chan.get_state() == ChannelState.OPEN and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height()):
self.logger.info(f"force-closing due to expiring htlcs")
@ -938,10 +939,14 @@ class LNWallet(LNWorker):
@log_exceptions
async def _open_channel_coroutine(
self, *, connect_str: str,
self,
*,
connect_str: str,
funding_tx: PartialTransaction,
funding_sat: int, push_sat: int,
password: Optional[str]) -> Tuple[Channel, PartialTransaction]:
funding_sat: int,
push_sat: int,
password: Optional[str],
) -> Tuple[Channel, PartialTransaction]:
peer = await self.add_peer(connect_str)
coro = peer.channel_establishment_flow(
funding_tx=funding_tx,
@ -1006,7 +1011,7 @@ class LNWallet(LNWorker):
if chan.short_channel_id == short_channel_id:
return chan
def create_routes_from_invoice(self, amount_msat, decoded_invoice, *, full_path=None):
def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
return self.create_routes_for_payment(
amount_msat=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
@ -1051,9 +1056,16 @@ class LNWallet(LNWorker):
util.trigger_callback('invoice_status', self.wallet, key)
try:
await self.pay_to_node(
invoice_pubkey, payment_hash, payment_secret, amount_to_pay,
min_cltv_expiry, r_tags, t_tags, invoice_features,
attempts=attempts, full_path=full_path)
node_pubkey=invoice_pubkey,
payment_hash=payment_hash,
payment_secret=payment_secret,
amount_to_pay=amount_to_pay,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
t_tags=t_tags,
invoice_features=invoice_features,
attempts=attempts,
full_path=full_path)
success = True
except PaymentFailure as e:
self.logger.exception('')
@ -1068,12 +1080,23 @@ class LNWallet(LNWorker):
log = self.logs[key]
return success, log
async def pay_to_node(
self, node_pubkey, payment_hash, payment_secret, amount_to_pay,
min_cltv_expiry, r_tags, t_tags, invoice_features, *,
attempts: int = 1, full_path: LNPaymentPath=None,
trampoline_onion=None, trampoline_fee=None, trampoline_cltv_delta=None):
self,
*,
node_pubkey: bytes,
payment_hash: bytes,
payment_secret: Optional[bytes],
amount_to_pay: int, # in msat
min_cltv_expiry: int,
r_tags,
t_tags,
invoice_features: int,
attempts: int = 1,
full_path: LNPaymentPath = None,
trampoline_onion=None,
trampoline_fee=None,
trampoline_cltv_delta=None,
) -> None:
if trampoline_onion:
# todo: compare to the fee of the actual route we found
@ -1095,7 +1118,14 @@ class LNWallet(LNWorker):
min_cltv_expiry, r_tags, t_tags, invoice_features, full_path=full_path))
# 2. send htlcs
for route, amount_msat in routes:
await self.pay_to_route(route, amount_msat, amount_to_pay, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion)
await self.pay_to_route(
route,
amount_msat=amount_msat,
total_msat=amount_to_pay,
payment_hash=payment_hash,
payment_secret=payment_secret,
min_cltv_expiry=min_cltv_expiry,
trampoline_onion=trampoline_onion)
amount_inflight += amount_msat
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex())
# 3. await a queue
@ -1111,9 +1141,17 @@ class LNWallet(LNWorker):
# if we get a channel update, we might retry the same route and amount
self.handle_error_code_from_failed_htlc(htlc_log)
async def pay_to_route(self, route: LNPaymentRoute, amount_msat: int,
total_msat: int, payment_hash: bytes, payment_secret: bytes,
min_cltv_expiry: int, trampoline_onion: bytes=None):
async def pay_to_route(
self,
route: LNPaymentRoute,
*,
amount_msat: int,
total_msat: int,
payment_hash: bytes,
payment_secret: Optional[bytes],
min_cltv_expiry: int,
trampoline_onion: bytes = None,
) -> None:
# send a single htlc
short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
@ -1267,7 +1305,7 @@ class LNWallet(LNWorker):
result.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
return result.tobytes()
def is_trampoline_peer(self, node_id):
def is_trampoline_peer(self, node_id: bytes) -> bool:
# until trampoline is advertised in lnfeatures, check against hardcoded list
if is_hardcoded_trampoline(node_id):
return True
@ -1276,8 +1314,11 @@ class LNWallet(LNWorker):
return True
return False
def suggest_peer(self):
return self.lnrater.suggest_peer() if self.channel_db else random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
def suggest_peer(self) -> Optional[bytes]:
if self.channel_db:
return self.lnrater.suggest_peer()
else:
return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
def create_trampoline_route(
self, amount_msat:int,
@ -1400,8 +1441,10 @@ class LNWallet(LNWorker):
invoice_pubkey,
min_cltv_expiry,
r_tags, t_tags,
invoice_features,
*, full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]:
invoice_features: int,
*,
full_path: LNPaymentPath = None,
) -> Sequence[Tuple[LNPaymentRoute, int]]:
"""Creates multiple routes for splitting a payment over the available
private channels.
@ -1411,13 +1454,14 @@ class LNWallet(LNWorker):
# try to send over a single channel
try:
routes = [self.create_route_for_payment(
amount_msat,
invoice_pubkey,
min_cltv_expiry,
r_tags, t_tags,
invoice_features,
None,
full_path=full_path
amount_msat=amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
t_tags=t_tags,
invoice_features=invoice_features,
outgoing_channel=None,
full_path=full_path,
)]
except NoPathFound:
if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
@ -1439,12 +1483,13 @@ class LNWallet(LNWorker):
# its capacity. This could be dealt with by temporarily
# iteratively blacklisting channels for this mpp attempt.
route, amt = self.create_route_for_payment(
part_amount_msat,
invoice_pubkey,
min_cltv_expiry,
r_tags, t_tags,
invoice_features,
channel,
amount_msat=part_amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
t_tags=t_tags,
invoice_features=invoice_features,
outgoing_channel=channel,
full_path=None)
routes.append((route, amt))
self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}")
@ -1457,13 +1502,16 @@ class LNWallet(LNWorker):
def create_route_for_payment(
self,
*,
amount_msat: int,
invoice_pubkey,
min_cltv_expiry,
r_tags, t_tags,
invoice_features,
invoice_pubkey: bytes,
min_cltv_expiry: int,
r_tags,
t_tags,
invoice_features: int,
outgoing_channel: Channel = None,
*, full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
full_path: Optional[LNPaymentPath],
) -> Tuple[LNPaymentRoute, int]:
channels = [outgoing_channel] if outgoing_channel else list(self.channels.values())
if not self.channel_db:
@ -1554,7 +1602,13 @@ class LNWallet(LNWorker):
raise Exception(_("add invoice timed out"))
@log_exceptions
async def create_invoice(self, *, amount_msat: Optional[int], message, expiry: int):
async def create_invoice(
self,
*,
amount_msat: Optional[int],
message,
expiry: int,
) -> Tuple[LnAddr, str]:
timestamp = int(time.time())
routing_hints = await self._calc_routing_hints_for_invoice(amount_msat)
if not routing_hints:
@ -1628,7 +1682,7 @@ class LNWallet(LNWorker):
self.payments[key] = info.amount_msat, info.direction, info.status
self.wallet.save_db()
def htlc_received(self, short_channel_id, htlc, expected_msat):
def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int):
status = self.get_payment_status(htlc.payment_hash)
if status == PR_PAID:
return True, None

8
electrum/tests/test_lnpeer.py

@ -775,7 +775,13 @@ class TestPeer(ElectrumTestCase):
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
payment_hash = lnaddr.paymenthash
payment_secret = lnaddr.payment_secret
pay = w1.pay_to_route(route, amount_msat, amount_msat, payment_hash, payment_secret, min_cltv_expiry)
pay = w1.pay_to_route(
route,
amount_msat=amount_msat,
total_msat=amount_msat,
payment_hash=payment_hash,
payment_secret=payment_secret,
min_cltv_expiry=min_cltv_expiry)
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(PaymentFailure):
run(f())

Loading…
Cancel
Save