Browse Source

Tweak request handling

Pause serving sessions whose socket buffer is full (anti-DoS)
Serve requests in batches of 8
Don't store the session in the request
RPC has priority 0; every other session at least 1
Periodically consolidate small session groups into 1
master
Neil Booth 8 years ago
parent
commit
263e88ad57
  1. 2
      electrumx_rpc.py
  2. 72
      lib/jsonrpc.py
  3. 75
      server/protocol.py

2
electrumx_rpc.py

@ -43,7 +43,7 @@ class RPCClient(JSONRPC):
future.cancel()
print('request timed out after {}s'.format(timeout))
else:
await request.process(1)
await request.process(self)
async def handle_response(self, result, error, method):
if result and method in ('groups', 'sessions'):

72
lib/jsonrpc.py

@ -15,63 +15,59 @@ import time
from lib.util import LoggedClass
class SingleRequest(object):
class RequestBase(object):
'''An object that represents a queued request.'''
def __init__(self, remaining):
self.remaining = remaining
class SingleRequest(RequestBase):
'''An object that represents a single request.'''
def __init__(self, session, payload):
self.payload = payload
self.session = session
self.count = 1
def remaining(self):
return self.count
def __init__(self, payload):
super().__init__(1)
self.payload = payload
async def process(self, limit):
async def process(self, session):
'''Asynchronously handle the JSON request.'''
binary = await self.session.process_single_payload(self.payload)
self.remaining = 0
binary = await session.process_single_payload(self.payload)
if binary:
self.session._send_bytes(binary)
self.count = 0
return 1
session._send_bytes(binary)
def __str__(self):
return str(self.payload)
class BatchRequest(object):
class BatchRequest(RequestBase):
'''An object that represents a batch request and its processing state.
Batches are processed in chunks.
'''
def __init__(self, session, payload):
self.session = session
def __init__(self, payload):
super().__init__(len(payload))
self.payload = payload
self.done = 0
self.parts = []
def remaining(self):
return len(self.payload) - self.done
async def process(self, limit):
async def process(self, session):
'''Asynchronously handle the JSON batch according to the JSON 2.0
spec.'''
count = min(limit, self.remaining())
for n in range(count):
item = self.payload[self.done]
part = await self.session.process_single_payload(item)
target = max(self.remaining - 4, 0)
while self.remaining > target:
item = self.payload[len(self.payload) - self.remaining]
self.remaining -= 1
part = await session.process_single_payload(item)
if part:
self.parts.append(part)
self.done += 1
total_len = sum(len(part) + 2 for part in self.parts)
self.session.check_oversized_request(total_len)
session.check_oversized_request(total_len)
if not self.remaining():
if not self.remaining:
if self.parts:
binary = b'[' + b', '.join(self.parts) + b']'
self.session._send_bytes(binary)
return count
session._send_bytes(binary)
def __str__(self):
return str(self.payload)
@ -151,6 +147,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
self.bandwidth_used = 0
self.bandwidth_limit = 5000000
self.transport = None
self.pause = False
# Parts of an incomplete JSON line. We buffer them until
# getting a newline.
self.parts = []
@ -186,11 +183,22 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
'''Handle an incoming client connection.'''
self.transport = transport
self.peer_info = transport.get_extra_info('peername')
transport.set_write_buffer_limits(high=500000)
def connection_lost(self, exc):
'''Handle client disconnection.'''
pass
def pause_writing(self):
'''Called by asyncio when the write buffer is full.'''
self.log_info('pausing writing')
self.pause = True
def resume_writing(self):
'''Called by asyncio when the write buffer has room.'''
self.log_info('resuming writing')
self.pause = False
def close_connection(self):
self.stop = time.time()
if self.transport:
@ -263,9 +271,9 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
if not message:
self.send_json_error('empty batch', self.INVALID_REQUEST)
return
request = BatchRequest(self, message)
request = BatchRequest(message)
else:
request = SingleRequest(self, message)
request = SingleRequest(message)
'''Queue the request for asynchronous handling.'''
self.enqueue_request(request)

75
server/protocol.py

@ -21,7 +21,7 @@ from functools import partial
import pylru
from lib.hash import sha256, double_sha256, hash_to_str, hex_str_to_hash
from lib.jsonrpc import JSONRPC
from lib.jsonrpc import JSONRPC, RequestBase
from lib.tx import Deserializer
import lib.util as util
from server.block_processor import BlockProcessor
@ -217,16 +217,15 @@ class ServerManager(util.LoggedClass):
BANDS = 5
class NotificationRequest(object):
def __init__(self, fn_call):
self.fn_call = fn_call
class NotificationRequest(RequestBase):
def __init__(self, height, touched):
super().__init__(1)
self.height = height
self.touched = touched
def remaining(self):
return 0
async def process(self, limit):
await self.fn_call()
return 0
async def process(self, session):
self.remaining = 0
await session.notify(self.height, self.touched)
def __init__(self, env):
super().__init__()
@ -294,8 +293,8 @@ class ServerManager(util.LoggedClass):
if isinstance(session, LocalRPC):
return 0
group_bandwidth = sum(s.bandwidth_used for s in self.sessions[session])
return (bisect_left(self.bands, session.bandwidth_used)
+ bisect_left(self.bands, group_bandwidth) + 1) // 2
return 1 + (bisect_left(self.bands, session.bandwidth_used)
+ bisect_left(self.bands, group_bandwidth) + 1) // 2
async def enqueue_delayed_sessions(self):
now = time.time()
@ -317,9 +316,14 @@ class ServerManager(util.LoggedClass):
item = (priority, self.next_queue_id, session)
self.next_queue_id += 1
secs = priority - self.BANDS
if secs >= 0:
secs = int(session.pause)
if secs:
session.log_info('delaying processing whilst paused')
excess = priority - self.BANDS
if excess > 0:
secs = excess
session.log_info('delaying response {:d}s'.format(secs))
if secs:
self.delayed_sessions.append((time.time() + secs, item))
else:
self.queue.put_nowait(item)
@ -403,8 +407,8 @@ class ServerManager(util.LoggedClass):
for session in self.sessions:
if isinstance(session, ElectrumX):
fn_call = partial(session.notify, self.bp.db_height, touched)
session.enqueue_request(self.NotificationRequest(fn_call))
request = self.NotificationRequest(self.bp.db_height, touched)
session.enqueue_request(request)
# Periodically log sessions
if self.env.log_sessions and time.time() > self.next_log_sessions:
data = self.session_data(for_log=True)
@ -480,7 +484,7 @@ class ServerManager(util.LoggedClass):
if now > self.next_stale_check:
self.next_stale_check = now + 60
self.clear_stale_sessions()
group = self.groups[int(session.start - self.start) // 60]
group = self.groups[int(session.start - self.start) // 180]
group.add(session)
self.sessions[session] = group
session.log_info('connection from {}, {:,d} total'
@ -521,9 +525,14 @@ class ServerManager(util.LoggedClass):
if stale:
self.logger.info('closing stale connections {}'.format(stale))
# Clear out empty groups
for key in [k for k, v in self.groups.items() if not v]:
del self.groups[key]
# Consolidate small groups
keys = [k for k, v in self.groups.items() if len(v) <= 2
and sum(session.bandwidth_used for session in v) < 10000]
if len(keys) > 1:
group = set.union(*(self.groups[key] for key in keys))
for key in keys:
del self.groups[key]
self.groups[max(keys)] = group
def new_subscription(self):
if self.subscription_count >= self.max_subs:
@ -728,7 +737,7 @@ class Session(JSONRPC):
return status
def requests_remaining(self):
return sum(request.remaining() for request in self.requests)
return sum(request.remaining for request in self.requests)
def enqueue_request(self, request):
'''Add a request to the session's list.'''
@ -738,28 +747,28 @@ class Session(JSONRPC):
async def serve_requests(self):
'''Serve requests in batches.'''
done_reqs = 0
done_jobs = 0
limit = 4
total = 0
errs = []
# Process 8 items at a time
for request in self.requests:
try:
done_jobs += await request.process(limit - done_jobs)
initial = request.remaining
await request.process(self)
total += initial - request.remaining
except asyncio.CancelledError:
raise
except Exception:
# Getting here should probably be considered a bug and fixed
# Should probably be considered a bug and fixed
self.log_error('error handling request {}'.format(request))
traceback.print_exc()
done_reqs += 1
else:
if not request.remaining():
done_reqs += 1
if done_jobs >= limit:
errs.append(request)
if total >= 8:
break
self.log_info('done {:,d} items'.format(total))
# Remove completed requests and re-enqueue ourself if any remain.
if done_reqs:
self.requests = self.requests[done_reqs:]
self.requests = [req for req in self.requests
if req.remaining and not req in errs]
if self.requests:
self.manager.enqueue_session(self)

Loading…
Cancel
Save