# Copyright (C) 2012- Takafumi Arakaki # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # You should have received a copy of the GNU General Public License # along with this program. If not, see . import sys import itertools import threading from sexpdata import loads, dumps, Symbol, String from .py3compat import SocketServer, Queue from .utils import autolog, LockingDict, newthread, callwith class BaseRemoteError(Exception): """ All exceptions from remote method are derived from this class. """ class CallerUnknown(BaseRemoteError): """ Error raised in remote method, but caller of the method is unknown. """ class EPCError(BaseRemoteError): """ Error returned by `epc-error` protocol. """ class ReturnError(BaseRemoteError): """ Error returned by `return-error` protocol. """ class EPCErrorCallerUnknown(CallerUnknown, EPCError): """ Same as :class:`EPCError`, but caller is unknown. """ class ReturnErrorCallerUnknown(CallerUnknown, ReturnError): """ Same as :class:`ReturnError`, but caller is unknown. """ class EPCClosed(Exception): """ Trying to send to a closed socket. """ def encode_string(string): data = string.encode('utf-8') datalen = '{0:06x}'.format(len(data) + 1).encode() return _JOIN_BYTES([datalen, data, _NEWLINE_BYTE]) _JOIN_BYTES = ''.encode().join _NEWLINE_BYTE = '\n'.encode() def encode_object(obj, **kwds): return encode_string(dumps(obj, **kwds)) def encode_message(name, *args, **kwds): return encode_object([Symbol(name)] + list(args), **kwds) def unpack_message(bytes): data = loads(bytes.decode('utf-8')) return (data[0].value(), data[1], data[2:]) def itermessage(read): while True: head = read(6) if not head: return length = int(head, 16) data = read(length) if len(data) < length: raise ValueError('need {0}-length data; got {1}' .format(length, len(data))) yield data class BlockingCallback(object): def __init__(self): self.queue = q = Queue.Queue() self.callback = lambda x: q.put(('return', x)) self.errback = lambda x: q.put(('error', x)) self.cbs = {'callback': self.callback, 'errback': self.errback} def result(self, timeout): (rtype, reply) = self.queue.get(timeout=timeout) if rtype == 'return': return reply else: raise reply class EPCCallManager: Dict = LockingDict # FIXME: make it configurable from server class. """ Dictionary class used to store callbacks. """ def __init__(self): self.callbacks = self.Dict() counter = itertools.count(1) self.get_uid = callwith(threading.Lock())(lambda: next(counter)) # Wrapping by threading.Lock is useless for non-threading # handler. Probably it is better to make it optional. def call(self, handler, name, args=[], callback=None, errback=None): uid = self.get_uid() self.callbacks[uid] = (callback, errback) handler._send('call', uid, Symbol(name), args) def methods(self, handler, callback=None, errback=None): uid = self.get_uid() self.callbacks[uid] = (callback, errback) handler._send('methods', uid) def handle_return(self, uid, reply): try: (callback, _) = self.callbacks.pop(uid) except (KeyError, TypeError): raise CallerUnknown(reply) if callback is not None: callback(reply) def _handle_error_reply(self, uid, reply, eclass, notfound): try: (_, errback) = self.callbacks.pop(uid) except (KeyError, TypeError): raise notfound(reply) error = eclass(reply) if errback is None: raise error else: errback(error) def handle_return_error(self, uid, reply): self._handle_error_reply(uid, reply, ReturnError, ReturnErrorCallerUnknown) def handle_epc_error(self, uid, reply): self._handle_error_reply(uid, reply, EPCError, EPCErrorCallerUnknown) class EPCHandler(SocketServer.StreamRequestHandler): # These attribute are defined in `SocketServer.BaseRequestHandler` # self.server : an instance of `EPCServer` # self.request : # self.client_address # These attribute are defined in `SocketServer.StreamRequestHandler` # self.connection : = self.request # self.rfile : stream from client # self.wfile : stream to client @property def logger(self): return self.server.logger @autolog('debug') def setup(self): SocketServer.StreamRequestHandler.setup(self) self.callmanager = EPCCallManager() self.server.add_client(self) @autolog('debug') def finish(self): try: SocketServer.StreamRequestHandler.finish(self) finally: self.server.remove_client(self) def _rfile_read_safely(self, size): try: return self.rfile.read(size) except (AttributeError, ValueError): if self.rfile.closed: # Calling read on closed socket raises # AttributeError in 2.x and ValueError in 3.x. # http://bugs.python.org/issue9177 raise StopIteration else: raise # if not, just re-raise it. def _recv(self): self.logger.debug('receiving...') for data in itermessage(self._rfile_read_safely): self.logger.debug( 'received: length = %r; data = %r', len(data), data) yield data self.logger.debug('receiving...') @autolog('debug') def _send(self, *args): string = encode_message(*args) try: self.wfile.write(string) except (AttributeError, ValueError): # See also: :meth:`_rfile_read_safely` raise EPCClosed @autolog('debug') def handle(self): for sexp in self._recv(): self._handle(sexp) @autolog('debug') def _handle(self, sexp): uid = undefined = [] # default: nil try: (name, uid, args) = unpack_message(sexp) pyname = name.replace('-', '_') getattr(self, '_validate_{0}'.format(pyname))(uid, args) handler = getattr(self, '_handle_{0}'.format(pyname)) reply = handler(uid, *args) if reply is not None: self._send(*reply) except Exception as err: if self.handle_error(err): return if self.server.debugger or self.server.log_traceback: exc_info = sys.exc_info() self.logger.error('Unexpected error', exc_info=exc_info) if self.server.debugger: self.server.debugger.post_mortem(exc_info[2]) name = 'epc-error' if uid is undefined else 'return-error' self._send(name, uid, repr(err)) @autolog('debug') def _handle_call(self, uid, meth, args): # See: `epc:handler-called-method` name = meth.value() try: func = self.server.get_method(name) except AttributeError: return ['epc-error', uid, "EPC-ERROR: No such method : {0}".format(name)] return ['return', uid, func(*args)] def _handle_methods(self, uid): return ['return', uid, [ (Symbol(name), [], String(func.__doc__ or "")) # FIXNE: implement arg-specs for (name, func) in self.server.funcs.items()]] def _handle_return(self, uid, reply): self.callmanager.handle_return(uid, reply) def _handle_return_error(self, uid, reply=None, *_): self.callmanager.handle_return_error(uid, reply) def _handle_epc_error(self, uid, reply=None, *_): self.callmanager.handle_epc_error(uid, reply) _epc_error_template = \ "(%s %d ...): Got %s arguments in the reply: %r" def _validate_call(self, uid, args, num_expect=2, name='call'): len_args = len(args) if len_args == num_expect: return elif len_args < num_expect: message = 'Not enough arguments {0!r}'.format(args) else: message = 'Too many arguments {0!r}'.format(args) self._send("epc-error", uid, message) raise EPCError('({0} {1} ...): {2}'.format(name, uid, message)) def _validate_methods(self, uid, args): self._validate_call(uid, args, 0, 'methods') def _validate_return(self, uid, args): len_args = len(args) error = lambda x: self._epc_error_template % ('return', uid, x, args) if len_args == 0: message = error('not enough') elif len_args > 1: message = error('too many') else: return self.logger.error(message) self._handle_epc_error(uid, message) raise EPCError(message) def _validate_return_error(self, uid, args): self._log_extra_argument_error('return-error', uid, args) def _validate_epc_error(self, uid, args): self._log_extra_argument_error('epc-error', uid, args) def _log_extra_argument_error(self, name, uid, args): if len(args) > 1: self.logger.error(self._epc_error_template, 'return-error', uid, 'too many', args) def handle_error(self, err): """ Handle error which is not handled by errback. :type err: Exception :arg err: An error not handled by other mechanisms. :rtype: boolean Return True from this function means that error is properly handled, so the error is not sent to client. Do not confuse this with :meth:`SocketServer.BaseServer.handle_error`. This method is for handling error for each client, not for entire server. Default implementation logs the error and returns True if the error is coming from remote [#]_ or returns False otherwise. Therefore, only the error occurs in this handler class is sent to remote. .. [#] More specifically, it returns True if `err` is an instance of :class:`BaseRemoteError` or :class:`EPCClosed`. """ self.logger.error(repr(err)) if isinstance(err, (BaseRemoteError, EPCClosed)): # BaseRemoteError: do not send error back # EPCClosed: no exception from thread return True def call(self, name, *args, **kwds): """ Call method connected to this handler. :type name: str :arg name: Method name to call. :type args: list :arg args: Arguments for remote method to call. :type callback: callable :arg callback: A function to be called with returned value of the remote method. :type errback: callable :arg errback: A function to be called with an error occurred in the remote method. It is either an instance of :class:`ReturnError` or :class:`EPCError`. """ self.callmanager.call(self, name, *args, **kwds) def methods(self, *args, **kwds): """ Request info of callable remote methods. Arguments for :meth:`call` except for `name` can be applied to this function too. """ self.callmanager.methods(self, *args, **kwds) @staticmethod def _blocking_request(call, timeout, *args): bc = BlockingCallback() call(*args, **bc.cbs) return bc.result(timeout=timeout) def call_sync(self, name, args, timeout=None): """ Blocking version of :meth:`call`. :type name: str :arg name: Remote function name to call. :type args: list :arg args: Arguments passed to the remote function. :type timeout: int or None :arg timeout: Timeout in second. None means no timeout. If the called remote function raise an exception, this method raise an exception. If you give `timeout`, this method may raise an `Empty` exception. """ return self._blocking_request(self.call, timeout, name, args) def methods_sync(self, timeout=None): """ Blocking version of :meth:`methods`. See also :meth:`call_sync`. """ return self._blocking_request(self.methods, timeout) class ThreadingEPCHandler(EPCHandler): def _handle(self, sexp): newthread(self, target=EPCHandler._handle, args=(self, sexp)).start()