Browse Source

network dns hacks: split from network.py into its own file

hard-fail-on-bad-server-string
SomberNight 5 years ago
committed by ghost43
parent
commit
11452722af
  1. 100
      electrum/dns_hacks.py
  2. 77
      electrum/network.py

100
electrum/dns_hacks.py

@ -0,0 +1,100 @@
# Copyright (C) 2020 The Electrum developers
# Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
import sys
import socket
import concurrent
from concurrent import futures
import ipaddress
from typing import Optional
import dns
import dns.resolver
from .logging import get_logger
_logger = get_logger(__name__)
_dns_threads_executor = None # type: Optional[concurrent.futures.Executor]
def configure_dns_depending_on_proxy(is_proxy: bool) -> None:
# Store this somewhere so we can un-monkey-patch:
if not hasattr(socket, "_getaddrinfo"):
socket._getaddrinfo = socket.getaddrinfo
if is_proxy:
# prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy
socket.getaddrinfo = lambda *args: [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))]
else:
if sys.platform == 'win32':
# On Windows, socket.getaddrinfo takes a mutex, and might hold it for up to 10 seconds
# when dns-resolving. To speed it up drastically, we resolve dns ourselves, outside that lock.
# See https://github.com/spesmilo/electrum/issues/4421
_prepare_windows_dns_hack()
socket.getaddrinfo = _fast_getaddrinfo
else:
socket.getaddrinfo = socket._getaddrinfo
def _prepare_windows_dns_hack():
# enable dns cache
resolver = dns.resolver.get_default_resolver()
if resolver.cache is None:
resolver.cache = dns.resolver.Cache()
# prepare threads
global _dns_threads_executor
if _dns_threads_executor is None:
_dns_threads_executor = concurrent.futures.ThreadPoolExecutor(max_workers=20,
thread_name_prefix='dns_resolver')
def _fast_getaddrinfo(host, *args, **kwargs):
def needs_dns_resolving(host):
try:
ipaddress.ip_address(host)
return False # already valid IP
except ValueError:
pass # not an IP
if str(host) in ('localhost', 'localhost.',):
return False
return True
def resolve_with_dnspython(host):
addrs = []
expected_errors = (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer,
concurrent.futures.CancelledError, concurrent.futures.TimeoutError)
ipv6_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.AAAA)
ipv4_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.A)
# try IPv6
try:
answers = ipv6_fut.result()
addrs += [str(answer) for answer in answers]
except expected_errors as e:
pass
except BaseException as e:
_logger.info(f'dnspython failed to resolve dns (AAAA) for {repr(host)} with error: {repr(e)}')
# try IPv4
try:
answers = ipv4_fut.result()
addrs += [str(answer) for answer in answers]
except expected_errors as e:
# dns failed for some reason, e.g. dns.resolver.NXDOMAIN this is normal.
# Simply report back failure; except if we already have some results.
if not addrs:
raise socket.gaierror(11001, 'getaddrinfo failed') from e
except BaseException as e:
# Possibly internal error in dnspython :( see #4483 and #5638
_logger.info(f'dnspython failed to resolve dns (A) for {repr(host)} with error: {repr(e)}')
if addrs:
return addrs
# Fall back to original socket.getaddrinfo to resolve dns.
return [host]
addrs = [host]
if needs_dns_resolving(host):
addrs = resolve_with_dnspython(host)
list_of_list_of_socketinfos = [socket._getaddrinfo(addr, *args, **kwargs) for addr in addrs]
list_of_socketinfos = [item for lst in list_of_list_of_socketinfos for item in lst]
return list_of_socketinfos

77
electrum/network.py

@ -31,15 +31,12 @@ import threading
import socket import socket
import json import json
import sys import sys
import ipaddress
import asyncio import asyncio
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable
import traceback import traceback
import concurrent import concurrent
from concurrent import futures from concurrent import futures
import dns
import dns.resolver
import aiorpcx import aiorpcx
from aiorpcx import TaskGroup from aiorpcx import TaskGroup
from aiohttp import ClientResponse from aiohttp import ClientResponse
@ -53,6 +50,7 @@ from .bitcoin import COIN
from . import constants from . import constants
from . import blockchain from . import blockchain
from . import bitcoin from . import bitcoin
from . import dns_hacks
from .transaction import Transaction from .transaction import Transaction
from .blockchain import Blockchain, HEADER_SIZE from .blockchain import Blockchain, HEADER_SIZE
from .interface import (Interface, serialize_server, deserialize_server, from .interface import (Interface, serialize_server, deserialize_server,
@ -228,8 +226,6 @@ class UntrustedServerReturnedError(NetworkException):
return f"<UntrustedServerReturnedError original_exception: {repr(self.original_exception)}>" return f"<UntrustedServerReturnedError original_exception: {repr(self.original_exception)}>"
_dns_threads_executor = None # type: Optional[concurrent.futures.Executor]
_INSTANCE = None _INSTANCE = None
@ -557,77 +553,10 @@ class Network(Logger):
def _set_proxy(self, proxy: Optional[dict]): def _set_proxy(self, proxy: Optional[dict]):
self.proxy = proxy self.proxy = proxy
# Store these somewhere so we can un-monkey-patch dns_hacks.configure_dns_depending_on_proxy(bool(proxy))
if not hasattr(socket, "_getaddrinfo"): self.logger.info(f'setting proxy {proxy}')
socket._getaddrinfo = socket.getaddrinfo
if proxy:
self.logger.info(f'setting proxy {proxy}')
# prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy
socket.getaddrinfo = lambda *args: [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))]
else:
if sys.platform == 'win32':
# On Windows, socket.getaddrinfo takes a mutex, and might hold it for up to 10 seconds
# when dns-resolving. To speed it up drastically, we resolve dns ourselves, outside that lock.
# see #4421
resolver = dns.resolver.get_default_resolver()
if resolver.cache is None:
resolver.cache = dns.resolver.Cache()
global _dns_threads_executor
if _dns_threads_executor is None:
_dns_threads_executor = concurrent.futures.ThreadPoolExecutor(max_workers=20)
socket.getaddrinfo = self._fast_getaddrinfo
else:
socket.getaddrinfo = socket._getaddrinfo
self.trigger_callback('proxy_set', self.proxy) self.trigger_callback('proxy_set', self.proxy)
@staticmethod
def _fast_getaddrinfo(host, *args, **kwargs):
def needs_dns_resolving(host):
try:
ipaddress.ip_address(host)
return False # already valid IP
except ValueError:
pass # not an IP
if str(host) in ('localhost', 'localhost.',):
return False
return True
def resolve_with_dnspython(host):
addrs = []
expected_errors = (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer,
concurrent.futures.CancelledError, concurrent.futures.TimeoutError)
ipv6_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.AAAA)
ipv4_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.A)
# try IPv6
try:
answers = ipv6_fut.result()
addrs += [str(answer) for answer in answers]
except expected_errors as e:
pass
except BaseException as e:
_logger.info(f'dnspython failed to resolve dns (AAAA) for {repr(host)} with error: {repr(e)}')
# try IPv4
try:
answers = ipv4_fut.result()
addrs += [str(answer) for answer in answers]
except expected_errors as e:
# dns failed for some reason, e.g. dns.resolver.NXDOMAIN this is normal.
# Simply report back failure; except if we already have some results.
if not addrs:
raise socket.gaierror(11001, 'getaddrinfo failed') from e
except BaseException as e:
# Possibly internal error in dnspython :( see #4483 and #5638
_logger.info(f'dnspython failed to resolve dns (A) for {repr(host)} with error: {repr(e)}')
if addrs:
return addrs
# Fall back to original socket.getaddrinfo to resolve dns.
return [host]
addrs = [host]
if needs_dns_resolving(host):
addrs = resolve_with_dnspython(host)
list_of_list_of_socketinfos = [socket._getaddrinfo(addr, *args, **kwargs) for addr in addrs]
list_of_socketinfos = [item for lst in list_of_list_of_socketinfos for item in lst]
return list_of_socketinfos
@log_exceptions @log_exceptions
async def set_parameters(self, net_params: NetworkParameters): async def set_parameters(self, net_params: NetworkParameters):
proxy = net_params.proxy proxy = net_params.proxy

Loading…
Cancel
Save