diff --git a/lib/network.py b/lib/network.py index c5493e306..748381eba 100644 --- a/lib/network.py +++ b/lib/network.py @@ -458,7 +458,7 @@ class Network(util.DaemonThread): self.switch_lagging_interface(i.server) self.notify('updated') - def process_response(self, interface, response): + def process_response(self, interface, response, callbacks): if self.debug: self.print_error("<--", response) error = response.get('error') @@ -490,12 +490,8 @@ class Network(util.DaemonThread): elif method == 'blockchain.block.get_header': self.on_get_header(interface, response) - elif method.endswith('.subscribe'): - k = repr((method, params)) - self.sub_cache[k] = response - callbacks = self.subscriptions.get(k, []) - for callback in callbacks: - callback(response) + for callback in callbacks: + callback(response) def process_responses(self, interface): responses = interface.get_responses() @@ -509,6 +505,9 @@ class Network(util.DaemonThread): client_req = self.unanswered_requests.pop(message_id, None) if client_req: assert interface == self.interface + callbacks = [client_req[2]] + else: + callbacks = [] # Copy the request method and params to the response response['method'] = method response['params'] = params @@ -529,31 +528,17 @@ class Network(util.DaemonThread): elif method == 'blockchain.address.subscribe': response['params'] = [params[0]] # addr response['result'] = params[1] - + callbacks = self.subscriptions.get(repr((method, params)), []) + # update cache if it's a subscription + if method.endswith('.subscribe'): + self.sub_cache[repr((method, params))] = response # Response is now in canonical form - self.process_response(interface, response) + self.process_response(interface, response, callbacks) def send(self, messages, callback): '''Messages is a list of (method, params) tuples''' with self.lock: - subs = filter(lambda (m,v): m.endswith('.subscribe'), messages) - for message in messages: - method, params = message - if method.endswith('.subscribe'): - k = repr((method, params)) - l = self.subscriptions.get(k, []) - if callback not in l: - l.append(callback) - self.subscriptions[k] = l - # check cached response - r = self.sub_cache.get(k) - if r is not None: - util.print_error("cache hit", k) - callback(r) - else: - self.pending_sends.append(message) - else: - self.pending_sends.append(message) + self.pending_sends.append((messages, callback)) def process_pending_sends(self): @@ -566,9 +551,24 @@ class Network(util.DaemonThread): sends = self.pending_sends self.pending_sends = [] - for method, params in sends: - message_id = self.queue_request(method, params) - self.unanswered_requests[message_id] = method, params + for messages, callback in sends: + for method, params in messages: + k = repr((method, params)) + if method.endswith('.subscribe'): + # add callback to list + l = self.subscriptions.get(k, []) + if callback not in l: + l.append(callback) + self.subscriptions[k] = l + + # check cached response for subscriptions + r = self.sub_cache.get(k) + if r is not None: + util.print_error("cache hit", k) + callback(r) + else: + message_id = self.queue_request(method, params) + self.unanswered_requests[message_id] = method, params, callback def unsubscribe(self, callback): '''Unsubscribe a callback to free object references to enable GC.'''