2020-04-08 10:38:14 -03:00

415 lines
13 KiB
Python

# Copyright (C) 2012- Takafumi Arakaki
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys
import itertools
import threading
from sexpdata import loads, dumps, Symbol, String
from .py3compat import SocketServer, Queue
from .utils import autolog, LockingDict, newthread, callwith
class BaseRemoteError(Exception):

    """Common base class of every exception that comes from a remote method."""
class CallerUnknown(BaseRemoteError):

    """A remote method reported an error, but nobody is waiting for its reply."""
class EPCError(BaseRemoteError):

    """Error delivered by the remote side through an `epc-error` message."""
class ReturnError(BaseRemoteError):

    """Error delivered by the remote side through a `return-error` message."""
class EPCErrorCallerUnknown(CallerUnknown, EPCError):

    """An :class:`EPCError` whose caller could not be identified."""
class ReturnErrorCallerUnknown(CallerUnknown, ReturnError):

    """A :class:`ReturnError` whose caller could not be identified."""
class EPCClosed(Exception):

    """Raised when a message is written to a socket that is already closed."""
_JOIN_BYTES = ''.encode().join
_NEWLINE_BYTE = '\n'.encode()

# The length header is exactly six hexadecimal digits, so the longest
# frame (payload plus the trailing newline) it can describe is 0xffffff.
_MAX_FRAME_LENGTH = 0xffffff


def encode_string(string):
    """
    Frame `string` as one EPC wire message.

    :type string: str
    :arg string: Serialized S-expression to send.
    :rtype: bytes
    :returns: A six-digit lowercase hexadecimal length, the UTF-8 encoded
              payload and a trailing newline.  The length counts the
              payload bytes plus that newline.
    :raises ValueError: if the frame is too long to be described by the
                        six-digit length header.
    """
    data = string.encode('utf-8')
    length = len(data) + 1  # +1 for the trailing newline
    if length > _MAX_FRAME_LENGTH:
        # '{0:06x}' would silently widen to seven or more digits, and the
        # reader always consumes exactly six header bytes, which would
        # desynchronize the whole stream.
        raise ValueError(
            'message too long to encode: {0} bytes (max {1})'
            .format(length, _MAX_FRAME_LENGTH))
    datalen = '{0:06x}'.format(length).encode()
    return _JOIN_BYTES([datalen, data, _NEWLINE_BYTE])
def encode_object(obj, **kwds):
    """
    Serialize `obj` as an S-expression and frame it for the wire.

    Keyword arguments are handed to ``dumps`` untouched.
    """
    sexp = dumps(obj, **kwds)
    return encode_string(sexp)
def encode_message(name, *args, **kwds):
    """
    Build the framed message ``(name arg...)``.

    `name` becomes a symbol at the head of the list and `args` follow it
    unchanged.  Keyword arguments are passed on to :func:`encode_object`.
    """
    message = [Symbol(name)]
    message.extend(args)
    return encode_object(message, **kwds)
def unpack_message(bytes):
    """
    Decode one raw message payload into ``(name, uid, args)``.

    The payload is decoded as UTF-8 and parsed as an S-expression list.
    `name` is the value of its first element (a symbol), `uid` is the
    second element and `args` is the list of remaining elements.
    """
    message = loads(bytes.decode('utf-8'))
    name = message[0].value()
    uid = message[1]
    return (name, uid, message[2:])
def itermessage(read):
    """
    Yield the payload of every length-prefixed message read via `read`.

    :type read: callable
    :arg read: ``read(size)`` returns up to `size` bytes.  An empty result
               where a header is expected, or raising
               :exc:`StopIteration`, signals the end of the stream.

    Each message is a six-digit hexadecimal length followed by exactly
    that many bytes of payload.

    :raises ValueError: if the stream ends in the middle of a header or a
                        payload, or if a header is not hexadecimal.
    """
    while True:
        try:
            head = read(6)
        except StopIteration:
            # Since Python 3.7 (PEP 479) a StopIteration escaping this
            # generator would be turned into RuntimeError, so honor the
            # "end of stream" signal with a plain return instead.
            return
        if not head:
            return
        if len(head) < 6:
            # A partial header used to be parsed as a bogus short length.
            raise ValueError('need 6-length header; got {0}'
                             .format(len(head)))
        length = int(head, 16)
        try:
            data = read(length)
        except StopIteration:
            return
        if len(data) < length:
            raise ValueError('need {0}-length data; got {1}'
                             .format(length, len(data)))
        yield data
class BlockingCallback(object):

    """
    Adapt the callback/errback interface to a blocking wait.

    Pass :attr:`cbs` as keyword arguments to a request, then call
    :meth:`result` to wait for whatever reply is delivered.
    """

    def __init__(self):
        queue = Queue.Queue()
        self.queue = queue

        def on_return(value):
            queue.put(('return', value))

        def on_error(error):
            queue.put(('error', error))

        self.callback = on_return
        self.errback = on_error
        self.cbs = dict(callback=on_return, errback=on_error)

    def result(self, timeout):
        """
        Wait for the reply and return it, or raise the delivered error.

        `timeout` is forwarded to the queue's ``get``.
        """
        kind, payload = self.queue.get(timeout=timeout)
        if kind != 'return':
            raise payload
        return payload
class EPCCallManager:

    """
    Book-keeping for requests this side sends to the remote peer.

    Each outgoing ``call``/``methods`` request gets a fresh uid, and its
    ``(callback, errback)`` pair is stored under that uid.  The matching
    ``return`` / ``return-error`` / ``epc-error`` reply pops the pair and
    invokes the appropriate function.
    """

    # Dictionary class used to store callbacks.
    Dict = LockingDict  # FIXME: make it configurable from server class.

    def __init__(self):
        self.callbacks = self.Dict()
        counter = itertools.count(1)
        # Wrapping by threading.Lock is useless for non-threading
        # handler. Probably it is better to make it optional.
        self.get_uid = callwith(threading.Lock())(lambda: next(counter))

    def _register(self, callback, errback):
        """Allocate a new uid, remember its callbacks and return the uid."""
        uid = self.get_uid()
        self.callbacks[uid] = (callback, errback)
        return uid

    def _forget(self, uid):
        """Drop the callbacks stored for `uid` if they are still present."""
        try:
            self.callbacks.pop(uid)
        except KeyError:
            pass

    def _send_request(self, handler, uid, *message):
        """
        Send `message` through ``handler._send``.

        If sending fails (e.g. the handler raises :class:`EPCClosed`), no
        reply can ever arrive for `uid`.  Its callbacks are therefore
        discarded before the error is re-raised instead of being kept in
        :attr:`callbacks` forever.
        """
        try:
            handler._send(*message)
        except Exception:
            self._forget(uid)
            raise

    def call(self, handler, name, args=None, callback=None, errback=None):
        """
        Send a ``call`` request for the remote method `name`.

        :arg handler: Object whose ``_send`` writes the message.
        :arg name: Remote method name (sent as a symbol).
        :arg args: Arguments for the remote method.  ``None`` (the
                   default) means no arguments.
        :arg callback: Called with the returned value, if not None.
        :arg errback: Called with a :class:`ReturnError` or
                      :class:`EPCError`, if not None.
        """
        if args is None:  # avoid a shared mutable default
            args = []
        uid = self._register(callback, errback)
        self._send_request(handler, uid, 'call', uid, Symbol(name), args)

    def methods(self, handler, callback=None, errback=None):
        """Send a ``methods`` request; see :meth:`call` for the arguments."""
        uid = self._register(callback, errback)
        self._send_request(handler, uid, 'methods', uid)

    def handle_return(self, uid, reply):
        """
        Deliver a successful `reply` to the callback registered for `uid`.

        :raises CallerUnknown: if no request with this uid is pending.
        """
        try:
            (callback, _) = self.callbacks.pop(uid)
        except (KeyError, TypeError):
            # TypeError: the remote may send an unhashable uid.
            raise CallerUnknown(reply)
        if callback is not None:
            callback(reply)

    def _handle_error_reply(self, uid, reply, eclass, notfound):
        """
        Deliver an error `reply` for `uid` as an `eclass` instance.

        The error goes to the registered errback, or is raised when no
        errback was given.  `notfound` is raised if `uid` is unknown.
        """
        try:
            (_, errback) = self.callbacks.pop(uid)
        except (KeyError, TypeError):
            raise notfound(reply)
        error = eclass(reply)
        if errback is None:
            raise error
        else:
            errback(error)

    def handle_return_error(self, uid, reply):
        """Handle a ``return-error`` reply for `uid`."""
        self._handle_error_reply(uid, reply, ReturnError,
                                 ReturnErrorCallerUnknown)

    def handle_epc_error(self, uid, reply):
        """Handle an ``epc-error`` reply for `uid`."""
        self._handle_error_reply(uid, reply, EPCError,
                                 EPCErrorCallerUnknown)
class EPCHandler(SocketServer.StreamRequestHandler):

    """
    Handle one EPC connection: read messages, dispatch them, send replies.

    Every incoming message ``(NAME UID ARGS...)`` is routed to
    ``_validate_NAME`` and then ``_handle_NAME``, with ``-`` in NAME
    replaced by ``_``.  Requests originating from this side
    (:meth:`call`, :meth:`methods`) are tracked by an
    :class:`EPCCallManager`, so their replies reach the right callback.
    """

    # These attributes are defined in `SocketServer.BaseRequestHandler`
    # self.server  : an instance of `EPCServer`
    # self.request :
    # self.client_address

    # These attributes are defined in `SocketServer.StreamRequestHandler`
    # self.connection : = self.request
    # self.rfile      : stream from client
    # self.wfile      : stream to client

    @property
    def logger(self):
        """The logger of the server this handler belongs to."""
        return self.server.logger

    @autolog('debug')
    def setup(self):
        """Set up the streams and register this handler with the server."""
        SocketServer.StreamRequestHandler.setup(self)
        self.callmanager = EPCCallManager()
        self.server.add_client(self)

    @autolog('debug')
    def finish(self):
        """Finish the stream handling and always unregister from the server."""
        try:
            SocketServer.StreamRequestHandler.finish(self)
        finally:
            self.server.remove_client(self)

    def _rfile_read_safely(self, size):
        """
        Read up to `size` bytes from :attr:`rfile`.

        Return an empty bytes object if the stream is already closed.
        :func:`itermessage` treats an empty header read as the end of the
        stream; a close in the middle of a payload surfaces there as a
        short read (ValueError).
        """
        try:
            return self.rfile.read(size)
        except (AttributeError, ValueError):
            if self.rfile.closed:
                # Calling read on closed socket raises
                # AttributeError in 2.x and ValueError in 3.x.
                # http://bugs.python.org/issue9177
                # Signal end of stream with an empty read instead of
                # `raise StopIteration`: this is called from inside the
                # `itermessage` generator, where (PEP 479, Python 3.7+)
                # an escaping StopIteration is turned into RuntimeError.
                return b''
            raise  # if not, just re-raise it.

    def _recv(self):
        """Yield raw message payloads received from the client."""
        self.logger.debug('receiving...')
        for data in itermessage(self._rfile_read_safely):
            self.logger.debug(
                'received: length = %r; data = %r', len(data), data)
            yield data
            self.logger.debug('receiving...')

    @autolog('debug')
    def _send(self, *args):
        """
        Encode `args` as one message and write it to the client.

        :raises EPCClosed: if the output stream is no longer usable.
        """
        string = encode_message(*args)
        try:
            self.wfile.write(string)
        except (AttributeError, ValueError):
            # See also: :meth:`_rfile_read_safely`
            raise EPCClosed

    @autolog('debug')
    def handle(self):
        """Process incoming messages until the client's stream ends."""
        for sexp in self._recv():
            self._handle(sexp)

    @autolog('debug')
    def _handle(self, sexp):
        """
        Validate and dispatch one raw message, replying when appropriate.

        An error not accepted by :meth:`handle_error` is reported back to
        the client.  It is sent as ``return-error`` once the message's uid
        is known, and otherwise as ``epc-error`` with a nil uid.
        """
        uid = undefined = []  # default: nil
        try:
            (name, uid, args) = unpack_message(sexp)
            pyname = name.replace('-', '_')
            getattr(self, '_validate_{0}'.format(pyname))(uid, args)
            handler = getattr(self, '_handle_{0}'.format(pyname))
            reply = handler(uid, *args)
            if reply is not None:
                self._send(*reply)
        except Exception as err:
            if self.handle_error(err):
                return
            if self.server.debugger or self.server.log_traceback:
                exc_info = sys.exc_info()
                self.logger.error('Unexpected error', exc_info=exc_info)
            if self.server.debugger:
                self.server.debugger.post_mortem(exc_info[2])
            name = 'epc-error' if uid is undefined else 'return-error'
            self._send(name, uid, repr(err))

    @autolog('debug')
    def _handle_call(self, uid, meth, args):
        """Run the requested server method and build the reply message."""
        # See: `epc:handler-called-method`
        name = meth.value()
        try:
            func = self.server.get_method(name)
        except AttributeError:
            return ['epc-error', uid,
                    "EPC-ERROR: No such method : {0}".format(name)]
        return ['return', uid, func(*args)]

    def _handle_methods(self, uid):
        """Reply with (name, arg-spec, docstring) for every server method."""
        return ['return', uid, [
            (Symbol(name), [], String(func.__doc__ or ""))
            # FIXME: implement arg-specs
            for (name, func)
            in self.server.funcs.items()]]

    def _handle_return(self, uid, reply):
        self.callmanager.handle_return(uid, reply)

    def _handle_return_error(self, uid, reply=None, *_):
        self.callmanager.handle_return_error(uid, reply)

    def _handle_epc_error(self, uid, reply=None, *_):
        self.callmanager.handle_epc_error(uid, reply)

    # `%s` (not `%d`) for the uid: it comes from the remote side and is
    # not guaranteed to be an integer; `%d` would raise TypeError while
    # building the error message itself.
    _epc_error_template = \
        "(%s %s ...): Got %s arguments in the reply: %r"

    def _validate_call(self, uid, args, num_expect=2, name='call'):
        """
        Check that `args` has exactly `num_expect` elements.

        On mismatch, send ``epc-error`` to the client and raise
        :class:`EPCError`.
        """
        len_args = len(args)
        if len_args == num_expect:
            return
        elif len_args < num_expect:
            message = 'Not enough arguments {0!r}'.format(args)
        else:
            message = 'Too many arguments {0!r}'.format(args)
        self._send("epc-error", uid, message)
        raise EPCError('({0} {1} ...): {2}'.format(name, uid, message))

    def _validate_methods(self, uid, args):
        self._validate_call(uid, args, 0, 'methods')

    def _validate_return(self, uid, args):
        """
        Check that a ``return`` message carries exactly one value.

        Otherwise log the problem, hand it to the pending request's
        errback and raise :class:`EPCError`.
        """
        len_args = len(args)
        error = lambda x: self._epc_error_template % ('return', uid, x, args)
        if len_args == 0:
            message = error('not enough')
        elif len_args > 1:
            message = error('too many')
        else:
            return
        self.logger.error(message)
        self._handle_epc_error(uid, message)
        raise EPCError(message)

    def _validate_return_error(self, uid, args):
        self._log_extra_argument_error('return-error', uid, args)

    def _validate_epc_error(self, uid, args):
        self._log_extra_argument_error('epc-error', uid, args)

    def _log_extra_argument_error(self, name, uid, args):
        """Log (but tolerate) a `name` message with more than one argument."""
        if len(args) > 1:
            # Use the actual message name; this used to hard-code
            # 'return-error' even for `epc-error` messages.
            self.logger.error(self._epc_error_template,
                              name, uid, 'too many', args)

    def handle_error(self, err):
        """
        Handle error which is not handled by errback.

        :type err: Exception
        :arg err: An error not handled by other mechanisms.
        :rtype: boolean

        Return True from this function means that error is properly
        handled, so the error is not sent to client.  Do not confuse
        this with :meth:`SocketServer.BaseServer.handle_error`.  This
        method is for handling error for each client, not for entire
        server.  Default implementation logs the error and returns
        True if the error is coming from remote [#]_ or returns False
        otherwise.  Therefore, only the error occurs in this handler
        class is sent to remote.

        .. [#] More specifically, it returns True if `err` is an
           instance of :class:`BaseRemoteError` or :class:`EPCClosed`.
        """
        self.logger.error(repr(err))
        if isinstance(err, (BaseRemoteError, EPCClosed)):
            # BaseRemoteError: do not send error back
            # EPCClosed: no exception from thread
            return True
        return False

    def call(self, name, *args, **kwds):
        """
        Call method connected to this handler.

        :type name: str
        :arg name: Method name to call.
        :type args: list
        :arg args: Arguments for remote method to call.
        :type callback: callable
        :arg callback: A function to be called with returned value of
                       the remote method.
        :type errback: callable
        :arg errback: A function to be called with an error occurred
                      in the remote method.  It is either an instance
                      of :class:`ReturnError` or :class:`EPCError`.
        """
        self.callmanager.call(self, name, *args, **kwds)

    def methods(self, *args, **kwds):
        """
        Request info of callable remote methods.

        Arguments for :meth:`call` except for `name` can be applied to
        this function too.
        """
        self.callmanager.methods(self, *args, **kwds)

    @staticmethod
    def _blocking_request(call, timeout, *args):
        """Issue `call` with blocking callbacks and wait for its result."""
        bc = BlockingCallback()
        call(*args, **bc.cbs)
        return bc.result(timeout=timeout)

    def call_sync(self, name, args, timeout=None):
        """
        Blocking version of :meth:`call`.

        :type name: str
        :arg name: Remote function name to call.
        :type args: list
        :arg args: Arguments passed to the remote function.
        :type timeout: int or None
        :arg timeout: Timeout in second.  None means no timeout.

        If the called remote function raise an exception, this method
        raise an exception.  If you give `timeout`, this method may
        raise an `Empty` exception.
        """
        return self._blocking_request(self.call, timeout, name, args)

    def methods_sync(self, timeout=None):
        """
        Blocking version of :meth:`methods`.  See also :meth:`call_sync`.
        """
        return self._blocking_request(self.methods, timeout)
class ThreadingEPCHandler(EPCHandler):

    """
    :class:`EPCHandler` that processes each message on a new thread.

    The thread comes from ``newthread`` and runs ``EPCHandler._handle``.
    """

    def _handle(self, sexp):
        worker = newthread(self, target=EPCHandler._handle, args=(self, sexp))
        worker.start()