# Copyright (c) 2016, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.

'''Class for handling JSON RPC 2.0 connections, server or client.'''

import asyncio
import json
import numbers
import time

from lib.util import LoggedClass


class RequestBase:
    '''Common base of the request objects queued for a session.

    ``remaining`` is the number of payload items still to be processed;
    subclasses count it down as they work through their payload.
    '''

    def __init__(self, remaining):
        self.remaining = remaining

class SingleRequest(RequestBase):
    '''A queued request holding exactly one (non-batch) JSON payload.'''

    def __init__(self, payload):
        super().__init__(1)
        self.payload = payload

    async def process(self, session):
        '''Have the session handle the payload and send back any reply.

        The request is marked finished before the payload is handled;
        nothing is sent when the session produces an empty result.
        '''
        self.remaining = 0
        reply = await session.process_single_payload(self.payload)
        if not reply:
            return
        session._send_bytes(reply)

    def __str__(self):
        return str(self.payload)


class BatchRequest(RequestBase):
    '''An object that represents a batch request and its processing state.

    Batches are processed in chunks: each call to process() handles at
    most CHUNK_SIZE items, and the request is complete once remaining
    reaches zero.
    '''

    # Maximum number of batch items handled per call to process().  Must
    # be positive, otherwise process() never makes progress.
    CHUNK_SIZE = 4

    def __init__(self, payload):
        super().__init__(len(payload))
        self.payload = payload
        # Non-empty encoded results of the items processed so far.
        self.parts = []
        # Byte length of the eventual combined response: every part plus
        # two bytes for its ', ' separator or share of the brackets.
        self.total_len = 0

    async def process(self, session):
        '''Asynchronously handle the next chunk of the JSON batch according
        to the JSON 2.0 spec.

        session.check_oversized_request() may raise if the accumulated
        response grows too large.  After the final chunk the combined
        response is sent, unless no item produced a reply.
        '''
        target = max(self.remaining - self.CHUNK_SIZE, 0)
        while self.remaining > target:
            item = self.payload[len(self.payload) - self.remaining]
            self.remaining -= 1
            part = await session.process_single_payload(item)
            if part:
                self.parts.append(part)
                # Keep a running total instead of re-summing every part on
                # each chunk, which is quadratic in the batch length.
                self.total_len += len(part) + 2

        session.check_oversized_request(self.total_len)

        if not self.remaining and self.parts:
            binary = b'[' + b', '.join(self.parts) + b']'
            session._send_bytes(binary)

    def __str__(self):
        return str(self.payload)


class JSONRPC(asyncio.Protocol, LoggedClass):
    '''Manages a JSONRPC connection.

    Assumes JSON messages are newline-separated and that newlines
    cannot appear in the JSON other than to separate lines.  Incoming
    requests are passed to enqueue_request(), which should arrange for
    their asynchronous handling via the request's process() method.

    Derived classes may want to override connection_made() and
    connection_lost() but should be sure to call the implementation in
    this base class first.  They will also want to implement some or
    all of the asynchronous functions handle_notification(),
    handle_response() and handle_request().

    handle_request() returns the result to pass over the network, and
    must raise an RPCError if there is an error.
    handle_notification() and handle_response() should not return
    anything or raise any exceptions.  All three functions have
    default "ignore" implementations supplied by this class.
    '''

    # Standard JSON-RPC 2.0 error codes.
    # See http://www.jsonrpc.org/specification
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_PARAMS = -32602
    INTERNAL_ERROR = -32603

    # Types an incoming message's 'id' member may have; any other type is
    # rejected with INVALID_REQUEST by process_single_payload().
    ID_TYPES = (type(None), str, numbers.Number)
    # Class-wide counter from which each new session takes its id_.
    NEXT_SESSION_ID = 0

    class RPCError(Exception):
        '''RPC handlers raise this error.

        msg is the human-readable error message and code the JSON-RPC
        error code; process_single_payload() turns both into an error
        response.  code defaults to -1.
        '''
        def __init__(self, msg, code=-1, **kw_args):
            # Hand msg to Exception so that str(e), e.args and tracebacks
            # show the message instead of being empty.
            # NOTE(review): Exception.__init__ rejects keyword arguments,
            # so passing any kw_args raises TypeError.
            super().__init__(msg, **kw_args)
            self.msg = msg
            self.code = code

    @classmethod
    def request_payload(cls, method, id_, params=None):
        '''Return a JSON-RPC 2.0 request object as a dict.

        The 'params' member is left out when params is falsy (None or
        empty).
        '''
        optional = {'params': params} if params else {}
        return {'jsonrpc': '2.0', 'id': id_, 'method': method, **optional}

    @classmethod
    def response_payload(cls, result, id_):
        '''Return a JSON-RPC 2.0 success response for the request id_.'''
        # We should not respond to notifications, so a real id is required.
        assert id_ is not None
        response = {'jsonrpc': '2.0', 'result': result}
        response['id'] = id_
        return response

    @classmethod
    def notification_payload(cls, method, params=None):
        '''Return a notification: a request object whose id is None.'''
        no_id = None
        return cls.request_payload(method, no_id, params)

    @classmethod
    def error_payload(cls, message, code, id_=None):
        '''Return a JSON-RPC 2.0 error response; id_ defaults to None.'''
        return {
            'jsonrpc': '2.0',
            'error': {'message': message, 'code': code},
            'id': id_,
        }

    @classmethod
    def payload_id(cls, payload):
        '''Return the 'id' member of payload, or None when payload is not a
        dict or has no id.'''
        if not isinstance(payload, dict):
            return None
        return payload.get('id')

    def __init__(self):
        super().__init__()
        # Unique session number, shown as the prefix of log lines.
        self.id_ = JSONRPC.NEXT_SESSION_ID
        JSONRPC.NEXT_SESSION_ID += 1
        self.log_prefix = '[{:d}] '.format(self.id_)
        self.log_me = False
        self.anon_logs = False

        # Timing: session start, session stop (0 while open), and when a
        # complete message was last received.
        created = time.time()
        self.start = created
        self.stop = 0
        self.last_recv = created

        # Bandwidth accounting maintained by using_bandwidth().
        self.bandwidth_start = created
        self.bandwidth_interval = 3600
        self.bandwidth_used = 0
        self.bandwidth_limit = 5000000

        # Connection state.  transport and peer_info are filled in by
        # connection_made(); pause is True while asyncio has paused writes.
        self.transport = None
        self.peer_info = None
        self.pause = False

        # Pieces of an incomplete JSON line, buffered until its newline.
        self.parts = []
        # If buffered incoming data exceeds this the connection is closed.
        self.max_buffer_size = 1000000
        # Sends longer than max_send are prevented, instead returning an
        # oversized request error to the other end of the network
        # connection.  Values under 1000 are treated as 1000.
        self.max_send = 0

        # Traffic counters.  recv_count counts JSON messages, not calls
        # to data_received().
        self.recv_count = 0
        self.recv_size = 0
        self.send_count = 0
        self.send_size = 0
        self.error_count = 0

    def peername(self, *, for_log=True):
        '''Return the peer address as "host:port".

        Gives 'unknown' before the peer address is known, and an
        anonymised placeholder when for_log and anon_logs are both set.
        '''
        if not self.peer_info:
            return 'unknown'
        if for_log and self.anon_logs:
            return 'xx.xx.xx.xx:xx'
        host, port = self.peer_info[0], self.peer_info[1]
        return '{}:{}'.format(host, port)

    def connection_made(self, transport):
        '''asyncio callback for a new connection: store the transport and
        peer address, and set the transport's write-buffer high-water mark.
        '''
        self.transport = transport
        self.peer_info = transport.get_extra_info('peername')
        # asyncio calls pause_writing() once the write buffer goes over
        # this many bytes.
        transport.set_write_buffer_limits(high=500000)

    def connection_lost(self, exc):
        '''asyncio callback when the connection is closed; exc is the
        exception that caused it, or None.  This base implementation does
        nothing.'''

    def pause_writing(self):
        '''asyncio callback: the write buffer went over its high-water
        mark, so record that writing is paused.'''
        self.log_info('pausing writing')
        self.pause = True

    def resume_writing(self):
        '''asyncio callback: the write buffer drained below its low-water
        mark, so record that writing may continue.'''
        self.log_info('resuming writing')
        self.pause = False

    def close_connection(self):
        '''Record the stop time and close the transport, if there is one.'''
        self.stop = time.time()
        transport = self.transport
        if transport:
            transport.close()

    def using_bandwidth(self, amount):
        '''Charge amount bytes to this session's bandwidth usage.

        Before adding amount, the recorded usage is reduced in proportion
        to the time elapsed since the previous call, at a rate of
        bandwidth_limit bytes per bandwidth_interval seconds.  The refund
        never exceeds the usage recorded so far.
        '''
        now = time.time()
        # time.time() is wall-clock time and can step backwards (e.g. an
        # NTP correction).  Clamp so a negative interval can never turn the
        # refund into an extra charge.
        elapsed = max(now - self.bandwidth_start, 0)
        self.bandwidth_start = now
        refund = int(elapsed / self.bandwidth_interval * self.bandwidth_limit)
        refund = min(refund, self.bandwidth_used)
        self.bandwidth_used += amount - refund

    def data_received(self, data):
        '''Handle incoming data (synchronously).

        Requests end in newline characters.  Each complete line is passed
        to decode_message(); a trailing incomplete line is buffered in
        self.parts until its newline arrives.  The connection is closed if
        buffered plus new data exceeds max_buffer_size.
        '''
        self.recv_size += len(data)
        self.using_bandwidth(len(data))

        # Close abusive connections where buffered data exceeds limit
        buffer_size = len(data) + sum(len(part) for part in self.parts)
        if buffer_size > self.max_buffer_size:
            self.log_error('read buffer of {:,d} bytes exceeds {:,d} '
                           'byte limit, closing {}'
                           .format(buffer_size, self.max_buffer_size,
                                   self.peername()))
            self.close_connection()

        # Do nothing while this connection is closing.  Check before every
        # line: decode_message() can close the connection (e.g. on a parse
        # error) part-way through a chunk holding several messages, and
        # the later messages must not be queued after that.
        while not self.transport.is_closing():
            npos = data.find(b'\n')
            if npos == -1:
                self.parts.append(data)
                break
            self.last_recv = time.time()
            self.recv_count += 1
            tail, data = data[:npos], data[npos + 1:]
            parts, self.parts = self.parts, []
            parts.append(tail)
            self.decode_message(b''.join(parts))

    def decode_message(self, message):
        '''Decode a binary message and queue it for asynchronous handling.

        A message that is not valid UTF-8 or not valid JSON gets a
        PARSE_ERROR reply and the connection is closed.  An empty batch
        gets an INVALID_REQUEST reply.  Anything else is wrapped in a
        BatchRequest (for a JSON array) or a SingleRequest and passed to
        enqueue_request().
        '''
        try:
            message = message.decode()
        except UnicodeDecodeError as e:
            msg = 'cannot decode binary bytes: {}'.format(e)
            self.send_json_error(msg, self.PARSE_ERROR, close=True)
            return

        try:
            message = json.loads(message)
        except (json.JSONDecodeError, RecursionError) as e:
            # The decoder raises RecursionError for input nested too
            # deeply.  That is just as malformed as a syntax error and must
            # not escape into the transport's read callback.
            msg = 'cannot decode JSON: {}'.format(e)
            self.send_json_error(msg, self.PARSE_ERROR, close=True)
            return

        if isinstance(message, list):
            # Batches must have at least one request.
            if not message:
                self.send_json_error('empty batch', self.INVALID_REQUEST)
                return
            request = BatchRequest(message)
        else:
            request = SingleRequest(message)

        # Queue the request for asynchronous handling.
        self.enqueue_request(request)
        if self.log_me:
            self.log_info('queued {}'.format(message))

    def encode_payload(self, payload):
        '''Return payload serialized as JSON bytes and update the send
        statistics (send_count, send_size, bandwidth usage).

        If payload cannot be serialized, the failure is logged and the
        bytes of an INTERNAL_ERROR error response are returned instead.
        Raises RPCError if the encoding exceeds the send limit (see
        check_oversized_request()).
        '''
        try:
            binary = json.dumps(payload).encode()
        except (TypeError, ValueError):
            # TypeError: an object json cannot serialize; ValueError: e.g.
            # a circular reference.
            msg = 'JSON encoding failure: {}'.format(payload)
            self.log_error(msg)
            return self.json_error_bytes(msg, self.INTERNAL_ERROR,
                                         self.payload_id(payload))

        self.check_oversized_request(len(binary))
        self.send_count += 1
        self.send_size += len(binary)
        self.using_bandwidth(len(binary))
        return binary

    def _send_bytes(self, binary, close=False):
        '''Write binary followed by a newline to the transport.

        Nothing is written once the transport is closing.  Afterwards the
        connection is closed if close is true or the session has recorded
        more than 10 errors.
        '''
        transport = self.transport
        # Confirmed this happens, sometimes a lot
        if transport.is_closing():
            return
        transport.write(binary)
        transport.write(b'\n')
        too_many_errors = self.error_count > 10
        if close or too_many_errors:
            self.close_connection()

    def send_json_error(self, message, code, id_=None, close=False):
        '''Send a JSON error response with the given message and code.

        The connection stays open unless close is true, or the session
        has recorded more than 10 errors (see _send_bytes()).
        '''
        self._send_bytes(self.json_error_bytes(message, code, id_), close)

    def encode_and_send_payload(self, payload):
        '''Serialize payload with encode_payload() and write it out.'''
        binary = self.encode_payload(payload)
        self._send_bytes(binary)

    def json_notification_bytes(self, method, params):
        '''Return the encoded bytes of a notification of method.'''
        payload = self.notification_payload(method, params)
        return self.encode_payload(payload)

    def json_request_bytes(self, method, id_, params=None):
        '''Return the encoded bytes of a request of method with id id_.'''
        payload = self.request_payload(method, id_, params)
        return self.encode_payload(payload)

    def json_response_bytes(self, result, id_):
        '''Return the encoded bytes of a success response to request id_.'''
        payload = self.response_payload(result, id_)
        return self.encode_payload(payload)

    def json_error_bytes(self, message, code, id_=None):
        '''Return the encoded bytes of an error response, counting it in
        this session's error_count.'''
        self.error_count += 1
        payload = self.error_payload(message, code, id_)
        return self.encode_payload(payload)

    async def process_single_payload(self, payload):
        '''Return the binary JSON result of a single JSON request, response or
        notification.

        The result is empty if nothing is to be sent.  RPCErrors raised
        while processing are turned into an encoded error response.
        '''

        if not isinstance(payload, dict):
            return self.json_error_bytes('request must be a dict',
                                         self.INVALID_REQUEST)

        try:
            if 'id' not in payload:
                return await self.process_json_notification(payload)

            id_ = payload['id']
            if not isinstance(id_, self.ID_TYPES):
                return self.json_error_bytes('invalid id: {}'.format(id_),
                                             self.INVALID_REQUEST)

            if 'method' in payload:
                if id_ is None:
                    # notification_payload() sends notifications as
                    # requests with a null id, so treat incoming ones the
                    # same way.  They cannot be answered anyway:
                    # response_payload() asserts that the id is not None.
                    return await self.process_json_notification(payload)
                return await self.process_json_request(payload)

            return await self.process_json_response(payload)
        except self.RPCError as e:
            return self.json_error_bytes(e.msg, e.code,
                                         self.payload_id(payload))

    @classmethod
    def method_and_params(cls, payload):
        '''Return the (method, params) pair of a request payload.

        params defaults to an empty list.  Raises RPCError with code
        INVALID_REQUEST unless method is a str and params a list.
        '''
        method = payload.get('method')
        if not isinstance(method, str):
            raise cls.RPCError('invalid method: {}'.format(method),
                               cls.INVALID_REQUEST)

        params = payload.get('params', [])
        if not isinstance(params, list):
            raise cls.RPCError('params should be an array',
                               cls.INVALID_REQUEST)

        return method, params

    async def process_json_notification(self, payload):
        '''Pass a notification to handle_notification(); return b''.

        A malformed notification is dropped silently, since JSON-RPC 2.0
        forbids replying to notifications, even with an error.
        '''
        try:
            method, params = self.method_and_params(payload)
        except self.RPCError:
            return b''
        await self.handle_notification(method, params)
        return b''

    async def process_json_request(self, payload):
        '''Run handle_request() for a request payload and return the
        encoded success response.  RPCErrors propagate to the caller.'''
        method, params = self.method_and_params(payload)
        result = await self.handle_request(method, params)
        request_id = payload['id']
        return self.json_response_bytes(result, request_id)

    async def process_json_response(self, payload):
        '''Pass a response payload to handle_response(); return b''.

        Only one of 'result' and 'error' should be present; 'error' wins
        if both are.  A payload with neither is ignored.
        '''
        if 'error' in payload:
            result, error = None, payload['error']
        elif 'result' in payload:
            result, error = payload['result'], None
        else:
            return b''
        await self.handle_response(result, error, payload['id'])
        return b''

    def check_oversized_request(self, total_len):
        '''Raise RPCError if total_len exceeds the send limit, which is
        max_send but never less than 1000.'''
        limit = max(1000, self.max_send)
        if total_len > limit:
            raise self.RPCError('request too large', self.INVALID_REQUEST)

    def raise_unknown_method(self, method):
        '''Respond to a request with an unknown method by raising RPCError
        with code METHOD_NOT_FOUND.'''
        message = "unknown method: '{}'".format(method)
        raise self.RPCError(message, self.METHOD_NOT_FOUND)

    # Common parameter verification routines
    @classmethod
    def param_to_non_negative_integer(cls, param):
        '''Return param if it is or can be converted to a non-negative
        integer, otherwise raise an RPCError.

        Params typically come from the peer's decoded JSON, so besides
        numbers and numeric strings they may be null, arrays, objects or
        non-finite floats; all of those are rejected with RPCError.
        '''
        try:
            param = int(param)
        except (TypeError, ValueError, OverflowError):
            # TypeError: None/list/dict; ValueError: a non-numeric string
            # or NaN; OverflowError: infinity.
            pass
        else:
            if param >= 0:
                return param

        raise cls.RPCError('param {} should be a non-negative integer'
                           .format(param))

    @classmethod
    def params_to_non_negative_integer(cls, params):
        '''Return the single non-negative integer in params; raise RPCError
        unless params holds exactly one such value.'''
        if len(params) != 1:
            raise cls.RPCError('params {} should contain one non-negative '
                               'integer'.format(params))
        return cls.param_to_non_negative_integer(params[0])

    @classmethod
    def require_empty_params(cls, params):
        '''Raise RPCError unless params is empty.'''
        if not params:
            return
        raise cls.RPCError('params {} should be empty'.format(params))


    # --- derived classes are intended to override these functions
    def enqueue_request(self, request):
        '''Hook: arrange for request (a SingleRequest or BatchRequest) to be
        processed later via its process() method.  Derived classes must
        override this.'''
        raise NotImplementedError()

    async def handle_notification(self, method, params):
        '''Hook for an incoming notification; by default it is ignored.'''

    async def handle_request(self, method, params):
        '''Hook for an incoming request: return the result to send back, or
        raise RPCError on failure.  By default the result is None.'''

    async def handle_response(self, result, error, id_):
        '''Handle a response.

        process_json_response() calls this with (result, None, id_) for a
        successful response and (None, error, id_) for an error response.
        '''