diff --git a/electrum/lnbase.py b/electrum/lnbase.py index d28775988..881cf0483 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -10,7 +10,7 @@ import asyncio import os import time from functools import partial -from typing import List, Tuple, Dict +from typing import List, Tuple, Dict, TYPE_CHECKING import traceback import sys @@ -31,10 +31,13 @@ from .lnutil import (Outpoint, LocalConfig, ChannelConfig, funding_output_script, get_per_commitment_secret_from_seed, secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures, LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily, - get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED) -from .lnutil import LightningPeerConnectionClosed, HandshakeFailed + get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED, + LightningPeerConnectionClosed, HandshakeFailed, LNPeerAddr) from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge -from .lntransport import LNTransport +from .lntransport import LNTransport, LNTransportBase + +if TYPE_CHECKING: + from .lnworker import LNWorker def channel_id_from_funding_tx(funding_txid, funding_index): @@ -191,7 +194,8 @@ def gen_msg(msg_type: str, **kwargs) -> bytes: class Peer(PrintError): - def __init__(self, lnworker, peer_addr, request_initial_sync=False, transport=None): + def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr, + request_initial_sync=False, transport: LNTransportBase=None): self.initialized = asyncio.Future() self.transport = transport self.peer_addr = peer_addr @@ -357,7 +361,7 @@ class Peer(PrintError): def close_and_cleanup(self): try: if self.transport: - self.transport.writer.close() + self.transport.close() except: pass for chan in self.channels.values(): diff --git a/electrum/lnchan.py b/electrum/lnchan.py index 798d2485b..f9e2160dd 100644 --- a/electrum/lnchan.py +++ b/electrum/lnchan.py @@ -3,7 +3,7 @@ from collections import namedtuple, defaultdict import binascii import json from enum import Enum, auto -from typing import Optional +from typing import Optional, Mapping, List from .util import bfh, PrintError, bh2u from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS diff --git a/electrum/lnchannelverifier.py b/electrum/lnchannelverifier.py index 1397fc072..778174663 100644 --- a/electrum/lnchannelverifier.py +++ b/electrum/lnchannelverifier.py @@ -25,6 +25,7 @@ import asyncio import threading +from typing import TYPE_CHECKING import aiorpcx @@ -38,6 +39,10 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure from .transaction import Transaction from .interface import GracefulDisconnect +if TYPE_CHECKING: + from .network import Network + from .lnrouter import ChannelDB + class LNChannelVerifier(NetworkJobOnDefaultServer): """ Verify channel announcements for the Channel DB """ @@ -46,7 +51,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): # will start throttling us, making it even slower. one option would be to # spread it over multiple servers. - def __init__(self, network, channel_db): + def __init__(self, network: 'Network', channel_db: 'ChannelDB'): NetworkJobOnDefaultServer.__init__(self, network) self.channel_db = channel_db self.lock = threading.Lock() @@ -105,7 +110,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id)) #self.print_error('requested short_channel_id', bh2u(short_channel_id)) - async def verify_channel(self, block_height, tx_pos, short_channel_id): + async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes): # we are verifying channel announcements as they are from untrusted ln peers. # we use electrum servers to do this. however we don't trust electrum servers either... try: diff --git a/electrum/lnonion.py b/electrum/lnonion.py index d87a1da4d..66d8975f5 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -24,7 +24,7 @@ # SOFTWARE. import hashlib -from typing import Sequence, List, Tuple, NamedTuple +from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING from enum import IntEnum, IntFlag from cryptography.hazmat.primitives.ciphers import Cipher, algorithms @@ -34,7 +34,9 @@ from . import ecc from .crypto import sha256, hmac_oneshot from .util import bh2u, profiler, xor_bytes, bfh from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH -from .lnrouter import RouteEdge + +if TYPE_CHECKING: + from .lnrouter import RouteEdge HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04 @@ -186,7 +188,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes, hmac=next_hmac) -def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_cltv: int) \ +def calc_hops_data_for_payment(route: List['RouteEdge'], amount_msat: int, final_cltv: int) \ -> Tuple[List[OnionHopsDataSingle], int, int]: """Returns the hops_data to be used for constructing an onion packet, and the amount_msat and cltv to be used on our immediate channel. diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 36b6134cd..b9e388a1e 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -27,8 +27,8 @@ import queue import os import json import threading -from collections import namedtuple, defaultdict -from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple +from collections import defaultdict +from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING import binascii import base64 import asyncio @@ -41,6 +41,10 @@ from .crypto import Hash from . import ecc from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH +if TYPE_CHECKING: + from .lnchan import Channel + from .network import Network + class UnknownEvenFeatureBits(Exception): pass @@ -272,7 +276,7 @@ class ChannelDB(JsonDB): NUM_MAX_RECENT_PEERS = 20 - def __init__(self, network): + def __init__(self, network: 'Network'): self.network = network path = os.path.join(get_headers_dir(network.config), 'channel_db') @@ -597,7 +601,7 @@ class LNPathFinder(PrintError): @profiler def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, - my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]: + my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]: """Return a path from nodeA to nodeB. Returns a list of (node_id, short_channel_id) representing a path. diff --git a/electrum/lntransport.py b/electrum/lntransport.py index 0386ca2fe..4b291a4da 100644 --- a/electrum/lntransport.py +++ b/electrum/lntransport.py @@ -1,10 +1,11 @@ -import hmac import hashlib +from asyncio import StreamReader, StreamWriter + import cryptography.hazmat.primitives.ciphers.aead as AEAD -from .crypto import sha256 -from .lnutil import get_ecdh, privkey_to_pubkey -from .lnutil import LightningPeerConnectionClosed, HandshakeFailed +from .crypto import sha256, hmac_oneshot +from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed, + HandshakeFailed) from . import ecc from .util import bh2u @@ -49,13 +50,13 @@ def get_bolt8_hkdf(salt, ikm): Return as two 32 byte fields. """ #Extract - prk = hmac.new(salt, msg=ikm, digestmod=hashlib.sha256).digest() + prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256) assert len(prk) == 32 #Expand info = b"" T0 = b"" - T1 = hmac.new(prk, T0 + info + b"\x01", digestmod=hashlib.sha256).digest() - T2 = hmac.new(prk, T1 + info + b"\x02", digestmod=hashlib.sha256).digest() + T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256) + T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256) assert len(T1 + T2) == 64 return T1, T2 @@ -76,6 +77,11 @@ def create_ephemeral_key() -> (bytes, bytes): return privkey.get_secret_bytes(), privkey.get_public_key_bytes() class LNTransportBase: + + def __init__(self, reader: StreamReader, writer: StreamWriter): + self.reader = reader + self.writer = writer + def send_bytes(self, msg): l = len(msg).to_bytes(2, 'big') lc = aead_encrypt(self.sk, self.sn(), b'', l) @@ -132,11 +138,14 @@ class LNTransportBase: self.r_ck = ck self.s_ck = ck + def close(self): + self.writer.close() + + class LNResponderTransport(LNTransportBase): - def __init__(self, privkey, reader, writer): + def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter): + LNTransportBase.__init__(self, reader, writer) self.privkey = privkey - self.reader = reader - self.writer = writer async def handshake(self, **kwargs): hs = HandshakeState(privkey_to_pubkey(self.privkey)) @@ -187,12 +196,12 @@ class LNResponderTransport(LNTransportBase): return rs class LNTransport(LNTransportBase): - def __init__(self, privkey, remote_pubkey, reader, writer): + def __init__(self, privkey: bytes, remote_pubkey: bytes, + reader: StreamReader, writer: StreamWriter): + LNTransportBase.__init__(self, reader, writer) assert type(privkey) is bytes and len(privkey) == 32 self.privkey = privkey self.remote_pubkey = remote_pubkey - self.reader = reader - self.writer = writer async def handshake(self): hs = HandshakeState(self.remote_pubkey) diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index 49226e46c..4f31dd6cb 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -1,5 +1,5 @@ import threading -from typing import NamedTuple, Iterable +from typing import NamedTuple, Iterable, TYPE_CHECKING import os from collections import defaultdict import asyncio @@ -11,6 +11,9 @@ from . import wallet from .storage import WalletStorage from .address_synchronizer import AddressSynchronizer +if TYPE_CHECKING: + from .network import Network + TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4) @@ -21,7 +24,7 @@ class LNWatcher(PrintError): # maybe we should disconnect from server in these cases verbosity_filter = 'W' - def __init__(self, network): + def __init__(self, network: 'Network'): self.network = network self.config = network.config path = os.path.join(network.config.path, "watcher_db") diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 754457096..472508062 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -3,7 +3,7 @@ import os from decimal import Decimal import random import time -from typing import Optional, Sequence, Tuple, List, Dict +from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING import threading import socket @@ -31,6 +31,11 @@ from .lnaddr import lndecode from .i18n import _ from .lnrouter import RouteEdge, is_route_sane_to_use +if TYPE_CHECKING: + from .network import Network + from .wallet import Abstract_Wallet + + NUM_PEERS_TARGET = 4 PEER_RETRY_INTERVAL = 600 # seconds PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds @@ -45,7 +50,7 @@ FALLBACK_NODE_LIST_MAINNET = ( class LNWorker(PrintError): - def __init__(self, wallet, network): + def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'): self.wallet = wallet self.sweep_address = wallet.get_receiving_address() self.network = network