Browse Source

some more type annotations that needed conditional imports

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
SomberNight 6 years ago
committed by ThomasV
parent
commit
f70e679aba
  1. 16
      electrum/lnbase.py
  2. 2
      electrum/lnchan.py
  3. 9
      electrum/lnchannelverifier.py
  4. 8
      electrum/lnonion.py
  5. 12
      electrum/lnrouter.py
  6. 35
      electrum/lntransport.py
  7. 7
      electrum/lnwatcher.py
  8. 9
      electrum/lnworker.py

16
electrum/lnbase.py

@ -10,7 +10,7 @@ import asyncio
import os import os
import time import time
from functools import partial from functools import partial
from typing import List, Tuple, Dict from typing import List, Tuple, Dict, TYPE_CHECKING
import traceback import traceback
import sys import sys
@ -31,10 +31,13 @@ from .lnutil import (Outpoint, LocalConfig, ChannelConfig,
funding_output_script, get_per_commitment_secret_from_seed, funding_output_script, get_per_commitment_secret_from_seed,
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures, secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily, LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily,
get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED) get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED,
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed LightningPeerConnectionClosed, HandshakeFailed, LNPeerAddr)
from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge
from .lntransport import LNTransport from .lntransport import LNTransport, LNTransportBase
if TYPE_CHECKING:
from .lnworker import LNWorker
def channel_id_from_funding_tx(funding_txid, funding_index): def channel_id_from_funding_tx(funding_txid, funding_index):
@ -191,7 +194,8 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
class Peer(PrintError): class Peer(PrintError):
def __init__(self, lnworker, peer_addr, request_initial_sync=False, transport=None): def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr,
request_initial_sync=False, transport: LNTransportBase=None):
self.initialized = asyncio.Future() self.initialized = asyncio.Future()
self.transport = transport self.transport = transport
self.peer_addr = peer_addr self.peer_addr = peer_addr
@ -357,7 +361,7 @@ class Peer(PrintError):
def close_and_cleanup(self): def close_and_cleanup(self):
try: try:
if self.transport: if self.transport:
self.transport.writer.close() self.transport.close()
except: except:
pass pass
for chan in self.channels.values(): for chan in self.channels.values():

2
electrum/lnchan.py

@ -3,7 +3,7 @@ from collections import namedtuple, defaultdict
import binascii import binascii
import json import json
from enum import Enum, auto from enum import Enum, auto
from typing import Optional from typing import Optional, Mapping, List
from .util import bfh, PrintError, bh2u from .util import bfh, PrintError, bh2u
from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS

9
electrum/lnchannelverifier.py

@ -25,6 +25,7 @@
import asyncio import asyncio
import threading import threading
from typing import TYPE_CHECKING
import aiorpcx import aiorpcx
@ -38,6 +39,10 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
from .transaction import Transaction from .transaction import Transaction
from .interface import GracefulDisconnect from .interface import GracefulDisconnect
if TYPE_CHECKING:
from .network import Network
from .lnrouter import ChannelDB
class LNChannelVerifier(NetworkJobOnDefaultServer): class LNChannelVerifier(NetworkJobOnDefaultServer):
""" Verify channel announcements for the Channel DB """ """ Verify channel announcements for the Channel DB """
@ -46,7 +51,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
# will start throttling us, making it even slower. one option would be to # will start throttling us, making it even slower. one option would be to
# spread it over multiple servers. # spread it over multiple servers.
def __init__(self, network, channel_db): def __init__(self, network: 'Network', channel_db: 'ChannelDB'):
NetworkJobOnDefaultServer.__init__(self, network) NetworkJobOnDefaultServer.__init__(self, network)
self.channel_db = channel_db self.channel_db = channel_db
self.lock = threading.Lock() self.lock = threading.Lock()
@ -105,7 +110,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id)) await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
#self.print_error('requested short_channel_id', bh2u(short_channel_id)) #self.print_error('requested short_channel_id', bh2u(short_channel_id))
async def verify_channel(self, block_height, tx_pos, short_channel_id): async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes):
# we are verifying channel announcements as they are from untrusted ln peers. # we are verifying channel announcements as they are from untrusted ln peers.
# we use electrum servers to do this. however we don't trust electrum servers either... # we use electrum servers to do this. however we don't trust electrum servers either...
try: try:

8
electrum/lnonion.py

@ -24,7 +24,7 @@
# SOFTWARE. # SOFTWARE.
import hashlib import hashlib
from typing import Sequence, List, Tuple, NamedTuple from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING
from enum import IntEnum, IntFlag from enum import IntEnum, IntFlag
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
@ -34,7 +34,9 @@ from . import ecc
from .crypto import sha256, hmac_oneshot from .crypto import sha256, hmac_oneshot
from .util import bh2u, profiler, xor_bytes, bfh from .util import bh2u, profiler, xor_bytes, bfh
from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH
from .lnrouter import RouteEdge
if TYPE_CHECKING:
from .lnrouter import RouteEdge
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04 HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
@ -186,7 +188,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes,
hmac=next_hmac) hmac=next_hmac)
def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_cltv: int) \ def calc_hops_data_for_payment(route: List['RouteEdge'], amount_msat: int, final_cltv: int) \
-> Tuple[List[OnionHopsDataSingle], int, int]: -> Tuple[List[OnionHopsDataSingle], int, int]:
"""Returns the hops_data to be used for constructing an onion packet, """Returns the hops_data to be used for constructing an onion packet,
and the amount_msat and cltv to be used on our immediate channel. and the amount_msat and cltv to be used on our immediate channel.

12
electrum/lnrouter.py

@ -27,8 +27,8 @@ import queue
import os import os
import json import json
import threading import threading
from collections import namedtuple, defaultdict from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING
import binascii import binascii
import base64 import base64
import asyncio import asyncio
@ -41,6 +41,10 @@ from .crypto import Hash
from . import ecc from . import ecc
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH
if TYPE_CHECKING:
from .lnchan import Channel
from .network import Network
class UnknownEvenFeatureBits(Exception): pass class UnknownEvenFeatureBits(Exception): pass
@ -272,7 +276,7 @@ class ChannelDB(JsonDB):
NUM_MAX_RECENT_PEERS = 20 NUM_MAX_RECENT_PEERS = 20
def __init__(self, network): def __init__(self, network: 'Network'):
self.network = network self.network = network
path = os.path.join(get_headers_dir(network.config), 'channel_db') path = os.path.join(get_headers_dir(network.config), 'channel_db')
@ -597,7 +601,7 @@ class LNPathFinder(PrintError):
@profiler @profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, invoice_amount_msat: int,
my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]: my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
"""Return a path from nodeA to nodeB. """Return a path from nodeA to nodeB.
Returns a list of (node_id, short_channel_id) representing a path. Returns a list of (node_id, short_channel_id) representing a path.

35
electrum/lntransport.py

@ -1,10 +1,11 @@
import hmac
import hashlib import hashlib
from asyncio import StreamReader, StreamWriter
import cryptography.hazmat.primitives.ciphers.aead as AEAD import cryptography.hazmat.primitives.ciphers.aead as AEAD
from .crypto import sha256 from .crypto import sha256, hmac_oneshot
from .lnutil import get_ecdh, privkey_to_pubkey from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed HandshakeFailed)
from . import ecc from . import ecc
from .util import bh2u from .util import bh2u
@ -49,13 +50,13 @@ def get_bolt8_hkdf(salt, ikm):
Return as two 32 byte fields. Return as two 32 byte fields.
""" """
#Extract #Extract
prk = hmac.new(salt, msg=ikm, digestmod=hashlib.sha256).digest() prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
assert len(prk) == 32 assert len(prk) == 32
#Expand #Expand
info = b"" info = b""
T0 = b"" T0 = b""
T1 = hmac.new(prk, T0 + info + b"\x01", digestmod=hashlib.sha256).digest() T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
T2 = hmac.new(prk, T1 + info + b"\x02", digestmod=hashlib.sha256).digest() T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
assert len(T1 + T2) == 64 assert len(T1 + T2) == 64
return T1, T2 return T1, T2
@ -76,6 +77,11 @@ def create_ephemeral_key() -> (bytes, bytes):
return privkey.get_secret_bytes(), privkey.get_public_key_bytes() return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
class LNTransportBase: class LNTransportBase:
def __init__(self, reader: StreamReader, writer: StreamWriter):
self.reader = reader
self.writer = writer
def send_bytes(self, msg): def send_bytes(self, msg):
l = len(msg).to_bytes(2, 'big') l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l) lc = aead_encrypt(self.sk, self.sn(), b'', l)
@ -132,11 +138,14 @@ class LNTransportBase:
self.r_ck = ck self.r_ck = ck
self.s_ck = ck self.s_ck = ck
def close(self):
self.writer.close()
class LNResponderTransport(LNTransportBase): class LNResponderTransport(LNTransportBase):
def __init__(self, privkey, reader, writer): def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self, reader, writer)
self.privkey = privkey self.privkey = privkey
self.reader = reader
self.writer = writer
async def handshake(self, **kwargs): async def handshake(self, **kwargs):
hs = HandshakeState(privkey_to_pubkey(self.privkey)) hs = HandshakeState(privkey_to_pubkey(self.privkey))
@ -187,12 +196,12 @@ class LNResponderTransport(LNTransportBase):
return rs return rs
class LNTransport(LNTransportBase): class LNTransport(LNTransportBase):
def __init__(self, privkey, remote_pubkey, reader, writer): def __init__(self, privkey: bytes, remote_pubkey: bytes,
reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self, reader, writer)
assert type(privkey) is bytes and len(privkey) == 32 assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey self.privkey = privkey
self.remote_pubkey = remote_pubkey self.remote_pubkey = remote_pubkey
self.reader = reader
self.writer = writer
async def handshake(self): async def handshake(self):
hs = HandshakeState(self.remote_pubkey) hs = HandshakeState(self.remote_pubkey)

7
electrum/lnwatcher.py

@ -1,5 +1,5 @@
import threading import threading
from typing import NamedTuple, Iterable from typing import NamedTuple, Iterable, TYPE_CHECKING
import os import os
from collections import defaultdict from collections import defaultdict
import asyncio import asyncio
@ -11,6 +11,9 @@ from . import wallet
from .storage import WalletStorage from .storage import WalletStorage
from .address_synchronizer import AddressSynchronizer from .address_synchronizer import AddressSynchronizer
if TYPE_CHECKING:
from .network import Network
TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4) TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4)
@ -21,7 +24,7 @@ class LNWatcher(PrintError):
# maybe we should disconnect from server in these cases # maybe we should disconnect from server in these cases
verbosity_filter = 'W' verbosity_filter = 'W'
def __init__(self, network): def __init__(self, network: 'Network'):
self.network = network self.network = network
self.config = network.config self.config = network.config
path = os.path.join(network.config.path, "watcher_db") path = os.path.join(network.config.path, "watcher_db")

9
electrum/lnworker.py

@ -3,7 +3,7 @@ import os
from decimal import Decimal from decimal import Decimal
import random import random
import time import time
from typing import Optional, Sequence, Tuple, List, Dict from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
import threading import threading
import socket import socket
@ -31,6 +31,11 @@ from .lnaddr import lndecode
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use from .lnrouter import RouteEdge, is_route_sane_to_use
if TYPE_CHECKING:
from .network import Network
from .wallet import Abstract_Wallet
NUM_PEERS_TARGET = 4 NUM_PEERS_TARGET = 4
PEER_RETRY_INTERVAL = 600 # seconds PEER_RETRY_INTERVAL = 600 # seconds
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
@ -45,7 +50,7 @@ FALLBACK_NODE_LIST_MAINNET = (
class LNWorker(PrintError): class LNWorker(PrintError):
def __init__(self, wallet, network): def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'):
self.wallet = wallet self.wallet = wallet
self.sweep_address = wallet.get_receiving_address() self.sweep_address = wallet.get_receiving_address()
self.network = network self.network = network

Loading…
Cancel
Save