Browse Source

aiorpcx: pin certificates

3.3.3.1
Janus 7 years ago
committed by SomberNight
parent
commit
89a01a6463
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/__init__.py
  2. 404
      electrum/interface.py
  3. 37
      electrum/network.py

2
electrum/__init__.py

@ -4,7 +4,7 @@ from .wallet import Wallet
from .storage import WalletStorage from .storage import WalletStorage
from .coinchooser import COIN_CHOOSERS from .coinchooser import COIN_CHOOSERS
from .network import Network, pick_random_server from .network import Network, pick_random_server
from .interface import Connection, Interface from .interface import Interface
from .simple_config import SimpleConfig, get_config, set_config from .simple_config import SimpleConfig, get_config, set_config
from . import bitcoin from . import bitcoin
from . import transaction from . import transaction

404
electrum/interface.py

@ -28,343 +28,131 @@ import socket
import ssl import ssl
import sys import sys
import threading import threading
import time
import traceback import traceback
import aiorpcx
import asyncio
import requests import requests
from .util import print_error from .util import PrintError
ca_path = requests.certs.where() ca_path = requests.certs.where()
from . import util from . import util
from . import x509 from . import x509
from . import pem from . import pem
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
class Interface(PrintError):
def Connection(server, queue, config_path): def __init__(self, server, config_path, connecting):
"""Makes asynchronous connections to a remote Electrum server. self.connecting = connecting
Returns the running thread that is making the connection.
Once the thread has connected, it finishes, placing a tuple on the
queue of the form (server, socket), where socket is None if
connection failed.
"""
host, port, protocol = server.rsplit(':', 2)
if not protocol in 'st':
raise Exception('Unknown protocol: %s' % protocol)
c = TcpConnection(server, queue, config_path)
c.start()
return c
class TcpConnection(threading.Thread, util.PrintError):
verbosity_filter = 'i'
def __init__(self, server, queue, config_path):
threading.Thread.__init__(self)
self.config_path = config_path
self.queue = queue
self.server = server self.server = server
self.host, self.port, self.protocol = self.server.rsplit(':', 2) self.host, self.port, self.protocol = self.server.split(':')
self.host = str(self.host) self.config_path = config_path
self.port = int(self.port) self.cert_path = os.path.join(self.config_path, 'certs', self.host)
self.use_ssl = (self.protocol == 's') self.fut = asyncio.get_event_loop().create_task(self.run())
self.daemon = True
def diagnostic_name(self): def diagnostic_name(self):
return self.host return self.host
def check_host_name(self, peercert, name): async def is_server_ca_signed(self, sslc):
"""Simple certificate/host name checker. Returns True if the
certificate matches, False otherwise. Does not support
wildcards."""
# Check that the peer has supplied a certificate.
# None/{} is not acceptable.
if not peercert:
return False
if 'subjectAltName' in peercert:
for typ, val in peercert["subjectAltName"]:
if typ == "DNS" and val == name:
return True
else:
# Only check the subject DN if there is no subject alternative
# name.
cn = None
for attr, val in peercert["subject"]:
# Use most-specific (last) commonName attribute.
if attr == "commonName":
cn = val
if cn is not None:
return cn == name
return False
def get_simple_socket(self):
try: try:
l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) await self.open_session(sslc, do_sleep=False)
except socket.gaierror: except ssl.SSLError as e:
self.print_error("cannot resolve hostname") assert e.reason == 'CERTIFICATE_VERIFY_FAILED'
return return False
e = None return True
for res in l:
try:
s = socket.socket(res[0], socket.SOCK_STREAM)
s.settimeout(10)
s.connect(res[4])
s.settimeout(2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
return s
except BaseException as _e:
e = _e
continue
else:
self.print_error("failed to connect", str(e))
@staticmethod
def get_ssl_context(cert_reqs, ca_certs):
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs)
context.check_hostname = False
context.verify_mode = cert_reqs
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
return context
def get_socket(self):
if self.use_ssl:
cert_path = os.path.join(self.config_path, 'certs', self.host)
if not os.path.exists(cert_path):
is_new = True
s = self.get_simple_socket()
if s is None:
return
# try with CA first
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path)
s = context.wrap_socket(s, do_handshake_on_connect=True)
except ssl.SSLError as e:
self.print_error(e)
except:
return
else:
try:
peer_cert = s.getpeercert()
except OSError:
return
if self.check_host_name(peer_cert, self.host):
self.print_error("SSL certificate signed by CA")
return s
# get server certificate.
# Do not use ssl.get_server_certificate because it does not work with proxy
s = self.get_simple_socket()
if s is None:
return
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None)
s = context.wrap_socket(s)
except ssl.SSLError as e:
self.print_error("SSL error retrieving SSL certificate:", e)
return
except:
return
try:
dercert = s.getpeercert(True)
except OSError:
return
s.close()
cert = ssl.DER_cert_to_PEM_cert(dercert)
# workaround android bug
cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
temporary_path = cert_path + '.temp'
util.assert_datadir_available(self.config_path)
with open(temporary_path, "w", encoding='utf-8') as f:
f.write(cert)
f.flush()
os.fsync(f.fileno())
else:
is_new = False
s = self.get_simple_socket() @util.aiosafe
if s is None: async def run(self):
if self.protocol != 's':
await self.open_session(None, execute_after_connect=lambda: self.connecting.remove(self.server))
return return
if self.use_ssl: ca_sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
try: exists = os.path.exists(self.cert_path)
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, if exists:
ca_certs=(temporary_path if is_new else cert_path)) with open(self.cert_path, 'r') as f:
s = context.wrap_socket(s, do_handshake_on_connect=True) contents = f.read()
except socket.timeout: if contents != '': # if not CA signed
self.print_error('timeout')
return
except ssl.SSLError as e:
self.print_error("SSL error:", e)
if e.errno != 1:
return
if is_new:
rej = cert_path + '.rej'
if os.path.exists(rej):
os.unlink(rej)
os.rename(temporary_path, rej)
else:
util.assert_datadir_available(self.config_path)
with open(cert_path, encoding='utf-8') as f:
cert = f.read()
try: try:
b = pem.dePem(cert, 'CERTIFICATE') b = pem.dePem(contents, 'CERTIFICATE')
except SyntaxError:
exists = False
else:
x = x509.X509(b) x = x509.X509(b)
except: try:
traceback.print_exc(file=sys.stderr) x.check_date()
self.print_error("wrong certificate") except x509.CertificateError:
return self.print_error("certificate has expired:", self.cert_path)
try: os.unlink(self.cert_path)
x.check_date() exists = False
except: if not exists:
self.print_error("certificate has expired:", cert_path) ca_signed = await self.is_server_ca_signed(ca_sslc)
os.unlink(cert_path) if ca_signed:
return with open(self.cert_path, 'w') as f:
self.print_error("wrong certificate") # empty file means this is CA signed, not self-signed
if e.errno == 104: f.write('')
return else:
return await self.save_certificate()
except BaseException as e: siz = os.stat(self.cert_path).st_size
self.print_error(e) if siz == 0: # if CA signed
traceback.print_exc(file=sys.stderr) sslc = ca_sslc
return else:
sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
if is_new: sslc.check_hostname = 0
self.print_error("saving certificate") await self.open_session(sslc, execute_after_connect=lambda: self.connecting.remove(self.server))
os.rename(temporary_path, cert_path)
async def save_certificate(self):
return s if not os.path.exists(self.cert_path):
# we may need to retry this a few times, in case the handshake hasn't completed
def run(self): for _ in range(10):
socket = self.get_socket() dercert = await self.get_certificate()
if socket: if dercert:
self.print_error("connected") self.print_error("succeeded in getting cert")
self.queue.put((self.server, socket)) with open(self.cert_path, 'w') as f:
cert = ssl.DER_cert_to_PEM_cert(dercert)
# workaround android bug
class Interface(util.PrintError): cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
"""The Interface class handles a socket connected to a single remote f.write(cert)
Electrum server. Its exposed API is: # even though close flushes we can't fsync when closed.
# and we must flush before fsyncing, cause flush flushes to OS buffer
- Member functions close(), fileno(), get_responses(), has_timed_out(), # fsync writes to OS buffer to disk
ping_required(), queue_request(), send_requests() f.flush()
- Member variable server. os.fsync(f.fileno())
""" break
await asyncio.sleep(1)
def __init__(self, server, socket): assert False, "could not get certificate"
self.server = server
self.host, _, _ = server.rsplit(':', 2)
self.socket = socket
self.pipe = util.SocketPipe(socket)
self.pipe.set_timeout(0.0) # Don't wait for data
# Dump network messages. Set at runtime from the console.
self.debug = False
self.unsent_requests = []
self.unanswered_requests = {}
self.last_send = time.time()
self.closed_remotely = False
def diagnostic_name(self):
return self.host
def fileno(self):
# Needed for select
return self.socket.fileno()
def close(self):
if not self.closed_remotely:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
def queue_request(self, *args): # method, params, _id
'''Queue a request, later to be send with send_requests when the
socket is available for writing.
'''
self.request_time = time.time()
self.unsent_requests.append(args)
def num_requests(self):
'''Keep unanswered requests below 100'''
n = 100 - len(self.unanswered_requests)
return min(n, len(self.unsent_requests))
def send_requests(self): async def get_certificate(self):
'''Sends queued requests. Returns False on failure.''' sslc = ssl.SSLContext()
self.last_send = time.time()
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
n = self.num_requests()
wire_requests = self.unsent_requests[0:n]
try: try:
self.pipe.send_all([make_dict(*r) for r in wire_requests]) async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
except BaseException as e: return session.transport._ssl_protocol._sslpipe._sslobj.getpeercert(True)
self.print_error("pipe send error:", e) except ValueError:
return False return None
self.unsent_requests = self.unsent_requests[n:]
for request in wire_requests: async def open_session(self, sslc, do_sleep=True, execute_after_connect=lambda: None):
if self.debug: async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
self.print_error("-->", request) ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
self.unanswered_requests[request[2]] = request print(ver)
return True connect_hook_executed = False
while do_sleep:
def ping_required(self): if not connect_hook_executed:
'''Returns True if a ping should be sent.''' connect_hook_executed = True
return time.time() - self.last_send > 300 execute_after_connect()
await asyncio.wait_for(session.send_request('server.ping'), 5)
await asyncio.sleep(300)
def has_timed_out(self): def has_timed_out(self):
'''Returns True if the interface has timed out.''' return self.fut.done()
if (self.unanswered_requests and time.time() - self.request_time > 10
and self.pipe.idle_time() > 10):
self.print_error("timeout", len(self.unanswered_requests))
return True
return False
def get_responses(self):
'''Call if there is data available on the socket. Returns a list of
(request, response) pairs. Notifications are singleton
unsolicited responses presumably as a result of prior
subscriptions, so request is None and there is no 'id' member.
Otherwise it is a response, which has an 'id' member and a
corresponding request. If the connection was closed remotely
or the remote server is misbehaving, a (None, None) will appear.
'''
responses = []
while True:
try:
response = self.pipe.get()
except util.timeout:
break
if not type(response) is dict:
responses.append((None, None))
if response is None:
self.closed_remotely = True
self.print_error("connection closed remotely")
break
if self.debug:
self.print_error("<--", response)
wire_id = response.get('id', None)
if wire_id is None: # Notification
responses.append((None, response))
else:
request = self.unanswered_requests.pop(wire_id, None)
if request:
responses.append((request, response))
else:
self.print_error("unknown wire ID", wire_id)
responses.append((None, None)) # Signal
break
return responses def queue_request(self, method, params, msg_id):
pass
def close(self):
self.fut.cancel()
def check_cert(host, cert): def check_cert(host, cert):
try: try:

37
electrum/network.py

@ -47,38 +47,14 @@ from . import blockchain
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
from .i18n import _ from .i18n import _
from .blockchain import InvalidHeader from .blockchain import InvalidHeader
from .interface import Interface
import aiorpcx, asyncio, ssl import asyncio
import concurrent.futures import concurrent.futures
NODES_RETRY_INTERVAL = 60 NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10 SERVER_RETRY_INTERVAL = 10
class Interface(PrintError):
@util.aiosafe
async def run(self):
self.host, self.port, self.protocol = self.server.split(':')
sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) if self.protocol == 's' else None
async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
print(ver)
while True:
print("sleeping")
await asyncio.sleep(1)
def __init__(self, server):
self.exception = None
self.server = server
self.fut = asyncio.get_event_loop().create_task(self.run())
def has_timed_out(self):
return self.fut.done()
def queue_request(self, method, params, msg_id):
pass
def close(self):
self.fut.cancel()
def parse_servers(result): def parse_servers(result):
""" parse servers list into dict format""" """ parse servers list into dict format"""
@ -539,7 +515,7 @@ class Network(PrintError):
self.close_interface(self.interface) self.close_interface(self.interface)
assert self.interface is None assert self.interface is None
assert not self.interfaces assert not self.interfaces
self.connecting = set() self.connecting.clear()
# Get a new queue - no old pending connections thanks! # Get a new queue - no old pending connections thanks!
self.socket_queue = queue.Queue() self.socket_queue = queue.Queue()
@ -810,7 +786,7 @@ class Network(PrintError):
def new_interface(self, server): def new_interface(self, server):
# todo: get tip first, then decide which checkpoint to use. # todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server) self.add_recent_server(server)
interface = Interface(server) interface = Interface(server, self.config.path, self.connecting)
interface.blockchain = None interface.blockchain = None
interface.tip_header = None interface.tip_header = None
interface.tip = 0 interface.tip = 0
@ -1368,9 +1344,12 @@ class Network(PrintError):
for k, i in self.interfaces.items(): for k, i in self.interfaces.items():
if i.has_timed_out(): if i.has_timed_out():
remove.append(k) remove.append(k)
changed = False
for k in remove: for k in remove:
self.connection_down(k) self.connection_down(k)
changed = True
for i in range(self.num_server - len(self.interfaces)): for i in range(self.num_server - len(self.interfaces)):
self.start_random_interface() self.start_random_interface()
self.notify('updated') changed = True
if changed: self.notify('updated')
await asyncio.sleep(1) await asyncio.sleep(1)

Loading…
Cancel
Save