diff --git a/lib/interface.py b/lib/interface.py index db2ad7d9b..cb27b1c4a 100644 --- a/lib/interface.py +++ b/lib/interface.py @@ -55,15 +55,12 @@ class TcpInterface(threading.Thread): self.daemon = True self.config = config if config is not None else SimpleConfig() self.lock = threading.Lock() - self.is_connected = False self.debug = False # dump network messages. can be changed at runtime using the console self.message_id = 0 self.unanswered_requests = {} - # are we waiting for a pong? self.is_ping = False - # parse server self.server = server self.host, self.port, self.protocol = self.server.split(':') @@ -72,6 +69,12 @@ class TcpInterface(threading.Thread): self.proxy = self.parse_proxy_options(self.config.get('proxy')) if self.proxy: self.proxy_mode = proxy_modes.index(self.proxy["mode"]) + 1 + socks.setdefaultproxy(self.proxy_mode, self.proxy["host"], int(self.proxy["port"])) + socket.socket = socks.socksocket + # prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy + def getaddrinfo(*args): + return [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))] + socket.getaddrinfo = getaddrinfo def process_response(self, response): @@ -116,54 +119,48 @@ class TcpInterface(threading.Thread): queue.put((self, {'method':method, 'params':params, 'result':result, 'id':_id})) - def get_socket(self): + def get_simple_socket(self): + try: + l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.gaierror: + print_error("error: cannot resolve", self.host) + return + for res in l: + try: + s = socket.socket(res[0], socket.SOCK_STREAM) + s.connect(res[4]) + return s + except: + continue + else: + print_error("failed to connect", self.host, self.port) - if self.proxy is not None: - socks.setdefaultproxy(self.proxy_mode, self.proxy["host"], int(self.proxy["port"])) - socket.socket = socks.socksocket - # prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy - def getaddrinfo(*args): - return [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))] - socket.getaddrinfo = getaddrinfo + def get_socket(self): if self.use_ssl: cert_path = os.path.join( self.config.path, 'certs', self.host) if not os.path.exists(cert_path): is_new = True - # get server certificate. - # Do not use ssl.get_server_certificate because it does not work with proxy - try: - l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.gaierror: - print_error("error: cannot resolve", self.host) + s = self.get_simple_socket() + if s is None: return - for res in l: - try: - s = socket.socket( res[0], socket.SOCK_STREAM ) - s.connect(res[4]) - except: - s = None - continue - - # first try with ca - try: - ca_certs = os.path.join(self.config.path, 'ca', 'ca-bundle.crt') - s = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_SSLv3, cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs, do_handshake_on_connect=True) - print_error("SSL with ca:", self.host) - return s - except ssl.SSLError, e: - pass - - try: - s = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_SSLv3, cert_reqs=ssl.CERT_NONE, ca_certs=None) - except ssl.SSLError, e: - print_error("SSL error retrieving SSL certificate:", self.host, e) - s = None - - break + # try with CA first + try: + ca_certs = os.path.join(self.config.path, 'ca', 'ca-bundle.crt') + s = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_SSLv3, cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs, do_handshake_on_connect=True) + print_error("SSL with ca:", self.host) + return s + except ssl.SSLError, e: + pass - if s is None: + # get server certificate. + # Do not use ssl.get_server_certificate because it does not work with proxy + s = self.get_simple_socket() + try: + s = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_SSLv3, cert_reqs=ssl.CERT_NONE, ca_certs=None) + except ssl.SSLError, e: + print_error("SSL error retrieving SSL certificate:", self.host, e) return dercert = s.getpeercert(True) @@ -174,31 +171,16 @@ class TcpInterface(threading.Thread): temporary_path = cert_path + '.temp' with open(temporary_path,"w") as f: f.write(cert) - else: is_new = False - try: - addrinfo = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.gaierror: - print_error("error: cannot resolve", self.host) - return - - for res in addrinfo: - try: - s = socket.socket( res[0], socket.SOCK_STREAM ) - s.settimeout(2) - s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - s.connect(res[4]) - except: - s = None - continue - break - + s = self.get_simple_socket() if s is None: - print_error("failed to connect", self.host, self.port) return + s.settimeout(2) + s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if self.use_ssl: try: s = ssl.wrap_socket(s,