Browse Source

locks in network.py

3.2.x
SomberNight 7 years ago
parent
commit
cd41a451f6
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 203
      lib/network.py

203
lib/network.py

@ -171,7 +171,7 @@ class Network(util.DaemonThread):
util.DaemonThread.__init__(self)
self.config = SimpleConfig(config) if isinstance(config, dict) else config
self.num_server = 10 if not self.config.get('oneserver') else 0
self.blockchains = blockchain.read_blockchains(self.config)
self.blockchains = blockchain.read_blockchains(self.config) # note: needs self.blockchains_lock
self.print_error("blockchains", self.blockchains.keys())
self.blockchain_index = config.get('blockchain_index', 0)
if self.blockchain_index not in self.blockchains.keys():
@ -187,27 +187,35 @@ class Network(util.DaemonThread):
self.default_server = None
if not self.default_server:
self.default_server = pick_random_server()
self.lock = threading.Lock()
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
self.interface_lock = threading.RLock() # <- re-entrant
self.callback_lock = threading.Lock()
self.pending_sends_lock = threading.Lock()
self.recent_servers_lock = threading.RLock() # <- re-entrant
self.subscribed_addresses_lock = threading.Lock()
self.blockchains_lock = threading.Lock()
self.pending_sends = []
self.message_id = 0
self.debug = False
self.irc_servers = {} # returned by interface (list from irc)
self.recent_servers = self.read_recent_servers()
self.recent_servers = self.read_recent_servers() # note: needs self.recent_servers_lock
self.banner = ''
self.donation_address = ''
self.relay_fee = None
# callbacks passed with subscriptions
self.subscriptions = defaultdict(list)
self.sub_cache = {}
self.subscriptions = defaultdict(list) # note: needs self.callback_lock
self.sub_cache = {} # note: needs self.interface_lock
# callbacks set by the GUI
self.callbacks = defaultdict(list)
self.callbacks = defaultdict(list) # note: needs self.callback_lock
dir_path = os.path.join( self.config.path, 'certs')
util.make_dir(dir_path)
# subscriptions and requests
self.subscribed_addresses = set()
self.subscribed_addresses = set() # note: needs self.subscribed_addresses_lock
self.h2addr = {}
# Requests from client we've not seen a response to
self.unanswered_requests = {}
@ -217,8 +225,8 @@ class Network(util.DaemonThread):
# kick off the network. interface is the main server we are currently
# communicating with. interfaces is the set of servers we are connecting
# to or have an ongoing connection with
self.interface = None
self.interfaces = {}
self.interface = None # note: needs self.interface_lock
self.interfaces = {} # note: needs self.interface_lock
self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set()
self.requested_chunks = set()
@ -226,19 +234,31 @@ class Network(util.DaemonThread):
self.start_network(deserialize_server(self.default_server)[2],
deserialize_proxy(self.config.get('proxy')))
def with_interface_lock(func):
def func_wrapper(self, *args, **kwargs):
with self.interface_lock:
return func(self, *args, **kwargs)
return func_wrapper
def with_recent_servers_lock(func):
def func_wrapper(self, *args, **kwargs):
with self.recent_servers_lock:
return func(self, *args, **kwargs)
return func_wrapper
def register_callback(self, callback, events):
with self.lock:
with self.callback_lock:
for event in events:
self.callbacks[event].append(callback)
def unregister_callback(self, callback):
with self.lock:
with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
def trigger_callback(self, event, *args):
with self.lock:
with self.callback_lock:
callbacks = self.callbacks[event][:]
[callback(event, *args) for callback in callbacks]
@ -253,6 +273,7 @@ class Network(util.DaemonThread):
except:
return []
@with_recent_servers_lock
def save_recent_servers(self):
if not self.config.path:
return
@ -264,6 +285,7 @@ class Network(util.DaemonThread):
except:
pass
@with_interface_lock
def get_server_height(self):
return self.interface.tip if self.interface else 0
@ -291,11 +313,15 @@ class Network(util.DaemonThread):
def is_up_to_date(self):
return self.unanswered_requests == {}
@with_interface_lock
def queue_request(self, method, params, interface=None):
# If you want to queue a request on any interface it must go
# through this function so message ids are properly tracked
if interface is None:
interface = self.interface
if interface is None:
self.print_error('warning: dropping request', method, params)
return
message_id = self.message_id
self.message_id += 1
if self.debug:
@ -303,7 +329,9 @@ class Network(util.DaemonThread):
interface.queue_request(method, params, message_id)
return message_id
@with_interface_lock
def send_subscriptions(self):
assert self.interface
self.print_error('sending subscriptions to', self.interface.server, len(self.unanswered_requests), len(self.subscribed_addresses))
self.sub_cache.clear()
# Resend unanswered requests
@ -317,8 +345,9 @@ class Network(util.DaemonThread):
self.queue_request('server.peers.subscribe', [])
self.request_fee_estimates()
self.queue_request('blockchain.relayfee', [])
for h in list(self.subscribed_addresses):
self.queue_request('blockchain.scripthash.subscribe', [h])
with self.subscribed_addresses_lock:
for h in self.subscribed_addresses:
self.queue_request('blockchain.scripthash.subscribe', [h])
def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS
@ -358,10 +387,12 @@ class Network(util.DaemonThread):
if self.is_connected():
return self.donation_address
@with_interface_lock
def get_interfaces(self):
'''The interfaces that are in connected state'''
return list(self.interfaces.keys())
@with_recent_servers_lock
def get_servers(self):
out = constants.net.DEFAULT_SERVERS
if self.irc_servers:
@ -376,6 +407,7 @@ class Network(util.DaemonThread):
out[host] = { protocol:port }
return out
@with_interface_lock
def start_interface(self, server):
if (not server in self.interfaces and not server in self.connecting):
if server == self.default_server:
@ -385,7 +417,8 @@ class Network(util.DaemonThread):
c = Connection(server, self.socket_queue, self.config.path)
def start_random_interface(self):
exclude_set = self.disconnected_servers.union(set(self.interfaces))
with self.interface_lock:
exclude_set = self.disconnected_servers.union(set(self.interfaces))
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
if server:
self.start_interface(server)
@ -433,15 +466,17 @@ class Network(util.DaemonThread):
else:
socket.getaddrinfo = socket._getaddrinfo
@with_interface_lock
def start_network(self, protocol, proxy):
assert not self.interface and not self.interfaces
assert not self.connecting and self.socket_queue.empty()
self.print_error('starting network')
self.disconnected_servers = set([])
self.disconnected_servers = set([]) # note: needs self.interface_lock
self.protocol = protocol
self.set_proxy(proxy)
self.start_interfaces()
@with_interface_lock
def stop_network(self):
self.print_error("stopping network")
for interface in list(self.interfaces.values()):
@ -491,6 +526,7 @@ class Network(util.DaemonThread):
if servers:
self.switch_to_interface(random.choice(servers))
@with_interface_lock
def switch_lagging_interface(self):
'''If auto_connect and lagging, switch interface'''
if self.server_is_lagging() and self.auto_connect:
@ -501,6 +537,7 @@ class Network(util.DaemonThread):
choice = random.choice(filtered)
self.switch_to_interface(choice)
@with_interface_lock
def switch_to_interface(self, server):
'''Switch to server as our interface. If no connection exists nor
being opened, start a thread to connect. The actual switch will
@ -522,6 +559,7 @@ class Network(util.DaemonThread):
self.set_status('connected')
self.notify('updated')
@with_interface_lock
def close_interface(self, interface):
if interface:
if interface.server in self.interfaces:
@ -530,6 +568,7 @@ class Network(util.DaemonThread):
self.interface = None
interface.close()
@with_recent_servers_lock
def add_recent_server(self, server):
# list is ordered
if server in self.recent_servers:
@ -587,7 +626,8 @@ class Network(util.DaemonThread):
for callback in callbacks:
callback(response)
def get_index(self, method, params):
@classmethod
def get_index(cls, method, params):
""" hashable index for subscriptions and cache"""
return str(method) + (':' + str(params[0]) if params else '')
@ -602,12 +642,15 @@ class Network(util.DaemonThread):
# and are placed in the unanswered_requests dictionary
client_req = self.unanswered_requests.pop(message_id, None)
if client_req:
assert interface == self.interface
if interface != self.interface:
# we probably changed the current interface
# in the meantime; drop this.
return
callbacks = [client_req[2]]
else:
# fixme: will only work for subscriptions
k = self.get_index(method, params)
callbacks = self.subscriptions.get(k, [])
callbacks = list(self.subscriptions.get(k, []))
# Copy the request method and params to the response
response['method'] = method
@ -615,7 +658,8 @@ class Network(util.DaemonThread):
# Only once we've received a response to an addr subscription
# add it to the list; avoids double-sends on reconnection
if method == 'blockchain.scripthash.subscribe':
self.subscribed_addresses.add(params[0])
with self.subscribed_addresses_lock:
self.subscribed_addresses.add(params[0])
else:
if not response: # Closed remotely / misbehaving
self.connection_down(interface.server)
@ -630,27 +674,29 @@ class Network(util.DaemonThread):
elif method == 'blockchain.scripthash.subscribe':
response['params'] = [params[0]] # addr
response['result'] = params[1]
callbacks = self.subscriptions.get(k, [])
callbacks = list(self.subscriptions.get(k, []))
# update cache if it's a subscription
if method.endswith('.subscribe'):
self.sub_cache[k] = response
with self.interface_lock:
self.sub_cache[k] = response
# Response is now in canonical form
self.process_response(interface, response, callbacks)
def send(self, messages, callback):
'''Messages is a list of (method, params) tuples'''
messages = list(messages)
with self.lock:
with self.pending_sends_lock:
self.pending_sends.append((messages, callback))
@with_interface_lock
def process_pending_sends(self):
# Requests needs connectivity. If we don't have an interface,
# we cannot process them.
if not self.interface:
return
with self.lock:
with self.pending_sends_lock:
sends = self.pending_sends
self.pending_sends = []
@ -660,10 +706,11 @@ class Network(util.DaemonThread):
if method.endswith('.subscribe'):
k = self.get_index(method, params)
# add callback to list
l = self.subscriptions.get(k, [])
l = list(self.subscriptions.get(k, []))
if callback not in l:
l.append(callback)
self.subscriptions[k] = l
with self.callback_lock:
self.subscriptions[k] = l
# check cached response for subscriptions
r = self.sub_cache.get(k)
@ -679,11 +726,12 @@ class Network(util.DaemonThread):
# Note: we can't unsubscribe from the server, so if we receive
# subsequent notifications process_response() will emit a harmless
# "received unexpected notification" warning
with self.lock:
with self.callback_lock:
for v in self.subscriptions.values():
if callback in v:
v.remove(callback)
@with_interface_lock
def connection_down(self, server):
'''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.'''
@ -693,9 +741,10 @@ class Network(util.DaemonThread):
if server in self.interfaces:
self.close_interface(self.interfaces[server])
self.notify('interfaces')
for b in self.blockchains.values():
if b.catch_up == server:
b.catch_up = None
with self.blockchains_lock:
for b in self.blockchains.values():
if b.catch_up == server:
b.catch_up = None
def new_interface(self, server, socket):
# todo: get tip first, then decide which checkpoint to use.
@ -706,7 +755,8 @@ class Network(util.DaemonThread):
interface.tip = 0
interface.mode = 'default'
interface.request = None
self.interfaces[server] = interface
with self.interface_lock:
self.interfaces[server] = interface
# server.version should be the first message
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
self.queue_request('server.version', params, interface)
@ -729,7 +779,9 @@ class Network(util.DaemonThread):
# Send pings and shut down stale interfaces
# must use copy of values
for interface in list(self.interfaces.values()):
with self.interface_lock:
interfaces = list(self.interfaces.values())
for interface in interfaces:
if interface.has_timed_out():
self.connection_down(interface.server)
elif interface.ping_required():
@ -737,28 +789,30 @@ class Network(util.DaemonThread):
now = time.time()
# nodes
if len(self.interfaces) + len(self.connecting) < self.num_server:
self.start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections')
self.disconnected_servers = set([])
self.nodes_retry_time = now
with self.interface_lock:
if len(self.interfaces) + len(self.connecting) < self.num_server:
self.start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections')
self.disconnected_servers = set([])
self.nodes_retry_time = now
# main interface
if not self.is_connected():
if self.auto_connect:
if not self.is_connecting():
self.switch_to_random_interface()
else:
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
self.disconnected_servers.remove(self.default_server)
self.server_retry_time = now
with self.interface_lock:
if not self.is_connected():
if self.auto_connect:
if not self.is_connecting():
self.switch_to_random_interface()
else:
self.switch_to_interface(self.default_server)
else:
if self.config.is_fee_estimates_update_required():
self.request_fee_estimates()
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
self.disconnected_servers.remove(self.default_server)
self.server_retry_time = now
else:
self.switch_to_interface(self.default_server)
else:
if self.config.is_fee_estimates_update_required():
self.request_fee_estimates()
def request_chunk(self, interface, index):
if index in self.requested_chunks:
@ -876,7 +930,8 @@ class Network(util.DaemonThread):
if bh > interface.good:
if not interface.blockchain.check_header(interface.bad_header):
b = interface.blockchain.fork(interface.bad_header)
self.blockchains[interface.bad] = b
with self.blockchains_lock:
self.blockchains[interface.bad] = b
interface.blockchain = b
interface.print_error("new chain", b.checkpoint)
interface.mode = 'catch_up'
@ -928,7 +983,9 @@ class Network(util.DaemonThread):
self.notify('interfaces')
def maintain_requests(self):
for interface in list(self.interfaces.values()):
with self.interface_lock:
interfaces = list(self.interfaces.values())
for interface in interfaces:
if interface.request and time.time() - interface.request_time > 20:
interface.print_error("blockchain request timed out")
self.connection_down(interface.server)
@ -940,14 +997,14 @@ class Network(util.DaemonThread):
if not self.interfaces:
time.sleep(0.1)
return
rin = [i for i in self.interfaces.values()]
win = [i for i in self.interfaces.values() if i.num_requests()]
with self.interface_lock:
interfaces = list(self.interfaces.values())
rin = [i for i in interfaces]
win = [i for i in interfaces if i.num_requests()]
try:
rout, wout, xout = select.select(rin, win, [], 0.1)
except socket.error as e:
# TODO: py3, get code from e
code = None
if code == errno.EINTR:
if e.errno == errno.EINTR:
return
raise
assert not xout
@ -1004,7 +1061,8 @@ class Network(util.DaemonThread):
self.notify('updated')
self.notify('interfaces')
return
tip = max([x.height() for x in self.blockchains.values()])
with self.blockchains_lock:
tip = max([x.height() for x in self.blockchains.values()])
if tip >=0:
interface.mode = 'backward'
interface.bad = height
@ -1016,19 +1074,24 @@ class Network(util.DaemonThread):
chain.catch_up = interface
interface.mode = 'catch_up'
interface.blockchain = chain
self.print_error("switching to catchup mode", tip, self.blockchains)
with self.blockchains_lock:
self.print_error("switching to catchup mode", tip, self.blockchains)
self.request_header(interface, 0)
else:
self.print_error("chain already catching up with", chain.catch_up.server)
@with_interface_lock
def blockchain(self):
if self.interface and self.interface.blockchain is not None:
self.blockchain_index = self.interface.blockchain.checkpoint
return self.blockchains[self.blockchain_index]
@with_interface_lock
def get_blockchains(self):
out = {}
for k, b in self.blockchains.items():
with self.blockchains_lock:
blockchain_items = list(self.blockchains.items())
for k, b in blockchain_items:
r = list(filter(lambda i: i.blockchain==b, list(self.interfaces.values())))
if r:
out[k] = r
@ -1039,18 +1102,21 @@ class Network(util.DaemonThread):
if blockchain:
self.blockchain_index = index
self.config.set_key('blockchain_index', index)
for i in self.interfaces.values():
with self.interface_lock:
interfaces = list(self.interfaces.values())
for i in interfaces:
if i.blockchain == blockchain:
self.switch_to_interface(i.server)
break
else:
raise Exception('blockchain not found', index)
if self.interface:
server = self.interface.server
host, port, protocol, proxy, auto_connect = self.get_parameters()
host, port, protocol = server.split(':')
self.set_parameters(host, port, protocol, proxy, auto_connect)
with self.interface_lock:
if self.interface:
server = self.interface.server
host, port, protocol, proxy, auto_connect = self.get_parameters()
host, port, protocol = server.split(':')
self.set_parameters(host, port, protocol, proxy, auto_connect)
def get_local_height(self):
return self.blockchain().height()
@ -1189,5 +1255,6 @@ class Network(util.DaemonThread):
with open(path, 'w', encoding='utf-8') as f:
f.write(json.dumps(cp, indent=4))
def max_checkpoint(self):
@classmethod
def max_checkpoint(cls):
return max(0, len(constants.net.CHECKPOINTS) * 2016 - 1)

Loading…
Cancel
Save