Browse Source

some more type annotations that needed conditional imports

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
SomberNight 6 years ago
committed by ThomasV
parent
commit
f70e679aba
  1. 16
      electrum/lnbase.py
  2. 2
      electrum/lnchan.py
  3. 9
      electrum/lnchannelverifier.py
  4. 8
      electrum/lnonion.py
  5. 12
      electrum/lnrouter.py
  6. 35
      electrum/lntransport.py
  7. 7
      electrum/lnwatcher.py
  8. 9
      electrum/lnworker.py

16
electrum/lnbase.py

@ -10,7 +10,7 @@ import asyncio
import os
import time
from functools import partial
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, TYPE_CHECKING
import traceback
import sys
@ -31,10 +31,13 @@ from .lnutil import (Outpoint, LocalConfig, ChannelConfig,
funding_output_script, get_per_commitment_secret_from_seed,
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily,
get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED)
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED,
LightningPeerConnectionClosed, HandshakeFailed, LNPeerAddr)
from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge
from .lntransport import LNTransport
from .lntransport import LNTransport, LNTransportBase
if TYPE_CHECKING:
from .lnworker import LNWorker
def channel_id_from_funding_tx(funding_txid, funding_index):
@ -191,7 +194,8 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
class Peer(PrintError):
def __init__(self, lnworker, peer_addr, request_initial_sync=False, transport=None):
def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr,
request_initial_sync=False, transport: LNTransportBase=None):
self.initialized = asyncio.Future()
self.transport = transport
self.peer_addr = peer_addr
@ -357,7 +361,7 @@ class Peer(PrintError):
def close_and_cleanup(self):
try:
if self.transport:
self.transport.writer.close()
self.transport.close()
except:
pass
for chan in self.channels.values():

2
electrum/lnchan.py

@ -3,7 +3,7 @@ from collections import namedtuple, defaultdict
import binascii
import json
from enum import Enum, auto
from typing import Optional
from typing import Optional, Mapping, List
from .util import bfh, PrintError, bh2u
from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS

9
electrum/lnchannelverifier.py

@ -25,6 +25,7 @@
import asyncio
import threading
from typing import TYPE_CHECKING
import aiorpcx
@ -38,6 +39,10 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
from .transaction import Transaction
from .interface import GracefulDisconnect
if TYPE_CHECKING:
from .network import Network
from .lnrouter import ChannelDB
class LNChannelVerifier(NetworkJobOnDefaultServer):
""" Verify channel announcements for the Channel DB """
@ -46,7 +51,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
# will start throttling us, making it even slower. one option would be to
# spread it over multiple servers.
def __init__(self, network, channel_db):
def __init__(self, network: 'Network', channel_db: 'ChannelDB'):
NetworkJobOnDefaultServer.__init__(self, network)
self.channel_db = channel_db
self.lock = threading.Lock()
@ -105,7 +110,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
#self.print_error('requested short_channel_id', bh2u(short_channel_id))
async def verify_channel(self, block_height, tx_pos, short_channel_id):
async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes):
# we are verifying channel announcements as they are from untrusted ln peers.
# we use electrum servers to do this. however we don't trust electrum servers either...
try:

8
electrum/lnonion.py

@ -24,7 +24,7 @@
# SOFTWARE.
import hashlib
from typing import Sequence, List, Tuple, NamedTuple
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING
from enum import IntEnum, IntFlag
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
@ -34,7 +34,9 @@ from . import ecc
from .crypto import sha256, hmac_oneshot
from .util import bh2u, profiler, xor_bytes, bfh
from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH
from .lnrouter import RouteEdge
if TYPE_CHECKING:
from .lnrouter import RouteEdge
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
@ -186,7 +188,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes,
hmac=next_hmac)
def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_cltv: int) \
def calc_hops_data_for_payment(route: List['RouteEdge'], amount_msat: int, final_cltv: int) \
-> Tuple[List[OnionHopsDataSingle], int, int]:
"""Returns the hops_data to be used for constructing an onion packet,
and the amount_msat and cltv to be used on our immediate channel.

12
electrum/lnrouter.py

@ -27,8 +27,8 @@ import queue
import os
import json
import threading
from collections import namedtuple, defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING
import binascii
import base64
import asyncio
@ -41,6 +41,10 @@ from .crypto import Hash
from . import ecc
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH
if TYPE_CHECKING:
from .lnchan import Channel
from .network import Network
class UnknownEvenFeatureBits(Exception): pass
@ -272,7 +276,7 @@ class ChannelDB(JsonDB):
NUM_MAX_RECENT_PEERS = 20
def __init__(self, network):
def __init__(self, network: 'Network'):
self.network = network
path = os.path.join(get_headers_dir(network.config), 'channel_db')
@ -597,7 +601,7 @@ class LNPathFinder(PrintError):
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int,
my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
"""Return a path from nodeA to nodeB.
Returns a list of (node_id, short_channel_id) representing a path.

35
electrum/lntransport.py

@ -1,10 +1,11 @@
import hmac
import hashlib
from asyncio import StreamReader, StreamWriter
import cryptography.hazmat.primitives.ciphers.aead as AEAD
from .crypto import sha256
from .lnutil import get_ecdh, privkey_to_pubkey
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
from .crypto import sha256, hmac_oneshot
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
HandshakeFailed)
from . import ecc
from .util import bh2u
@ -49,13 +50,13 @@ def get_bolt8_hkdf(salt, ikm):
Return as two 32 byte fields.
"""
#Extract
prk = hmac.new(salt, msg=ikm, digestmod=hashlib.sha256).digest()
prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
assert len(prk) == 32
#Expand
info = b""
T0 = b""
T1 = hmac.new(prk, T0 + info + b"\x01", digestmod=hashlib.sha256).digest()
T2 = hmac.new(prk, T1 + info + b"\x02", digestmod=hashlib.sha256).digest()
T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
assert len(T1 + T2) == 64
return T1, T2
@ -76,6 +77,11 @@ def create_ephemeral_key() -> (bytes, bytes):
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
class LNTransportBase:
def __init__(self, reader: StreamReader, writer: StreamWriter):
self.reader = reader
self.writer = writer
def send_bytes(self, msg):
l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l)
@ -132,11 +138,14 @@ class LNTransportBase:
self.r_ck = ck
self.s_ck = ck
def close(self):
self.writer.close()
class LNResponderTransport(LNTransportBase):
def __init__(self, privkey, reader, writer):
def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self, reader, writer)
self.privkey = privkey
self.reader = reader
self.writer = writer
async def handshake(self, **kwargs):
hs = HandshakeState(privkey_to_pubkey(self.privkey))
@ -187,12 +196,12 @@ class LNResponderTransport(LNTransportBase):
return rs
class LNTransport(LNTransportBase):
def __init__(self, privkey, remote_pubkey, reader, writer):
def __init__(self, privkey: bytes, remote_pubkey: bytes,
reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self, reader, writer)
assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey
self.remote_pubkey = remote_pubkey
self.reader = reader
self.writer = writer
async def handshake(self):
hs = HandshakeState(self.remote_pubkey)

7
electrum/lnwatcher.py

@ -1,5 +1,5 @@
import threading
from typing import NamedTuple, Iterable
from typing import NamedTuple, Iterable, TYPE_CHECKING
import os
from collections import defaultdict
import asyncio
@ -11,6 +11,9 @@ from . import wallet
from .storage import WalletStorage
from .address_synchronizer import AddressSynchronizer
if TYPE_CHECKING:
from .network import Network
TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4)
@ -21,7 +24,7 @@ class LNWatcher(PrintError):
# maybe we should disconnect from server in these cases
verbosity_filter = 'W'
def __init__(self, network):
def __init__(self, network: 'Network'):
self.network = network
self.config = network.config
path = os.path.join(network.config.path, "watcher_db")

9
electrum/lnworker.py

@ -3,7 +3,7 @@ import os
from decimal import Decimal
import random
import time
from typing import Optional, Sequence, Tuple, List, Dict
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
import threading
import socket
@ -31,6 +31,11 @@ from .lnaddr import lndecode
from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use
if TYPE_CHECKING:
from .network import Network
from .wallet import Abstract_Wallet
NUM_PEERS_TARGET = 4
PEER_RETRY_INTERVAL = 600 # seconds
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
@ -45,7 +50,7 @@ FALLBACK_NODE_LIST_MAINNET = (
class LNWorker(PrintError):
def __init__(self, wallet, network):
def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'):
self.wallet = wallet
self.sweep_address = wallet.get_receiving_address()
self.network = network

Loading…
Cancel
Save