import abc
import asyncio
import inspect
import pprint
import textwrap
from types import MethodType
from .log import logger
from .packer import _Packer
from aiozmq import interface
class Error(Exception):
    """Root of the RPC exception hierarchy.

    RPC-specific exceptions such as NotFoundError, ParametersError and
    ServiceClosedError derive from it, so ``except Error`` catches them all.
    """
class GenericError(Error):
    """Stand-in for an exception from an RPC method call that has no
    translation to a local exception class.

    Attributes:
        exc_type: the untranslated exception's type, as passed in.
        arguments: the untranslated exception's arguments, as passed in.
        exc_repr: the untranslated exception's repr, as passed in.
    """

    def __init__(self, exc_type, args, exc_repr):
        super().__init__(exc_type, args, exc_repr)
        # ``args`` is stored as ``arguments``: ``self.args`` is already
        # owned by BaseException.
        self.exc_type, self.arguments, self.exc_repr = exc_type, args, exc_repr

    def __repr__(self):
        template = '<Generic RPC Error {}{}: {}>'
        return template.format(self.exc_type, self.arguments, self.exc_repr)
class NotFoundError(Error, LookupError):
    """Raised by the server when an RPC name cannot be resolved to a
    namespace or method."""
class ParametersError(Error, ValueError):
    """Raised by the server when the parameters of an RPC method call do
    not fit its signature or fail validation against its annotations."""
class ServiceClosedError(Error):
    """The RPC service is closed (its transport is gone)."""
class AbstractHandler(metaclass=abc.ABCMeta):
    """Abstract base for server-side RPC handlers.

    A handler resolves names through ``handler[key]`` and raises KeyError
    for unknown keys. Any class defining ``__getitem__`` somewhere in its
    MRO counts as a virtual subclass, except ``str``/``bytes`` (and their
    subclasses).
    """

    __slots__ = ()

    @abc.abstractmethod
    def __getitem__(self, key):
        raise KeyError

    @classmethod
    def __subclasshook__(cls, C):
        # Strings support indexing too, but presumably must never be
        # mistaken for handlers.
        if issubclass(C, (str, bytes)):
            return False
        # Structural check only for AbstractHandler itself; concrete
        # subclasses fall back to normal inheritance rules.
        if cls is not AbstractHandler:
            return NotImplemented
        for klass in C.__mro__:
            if "__getitem__" in klass.__dict__:
                return True
        return NotImplemented
class AttrHandler(AbstractHandler):
    """Base class for RPC handlers that resolve keys via attribute lookup.

    ``handler[key]`` returns ``getattr(handler, key)``; a missing attribute
    is reported as KeyError, the lookup-failure signal AbstractHandler
    users catch.
    """

    def __getitem__(self, key):
        try:
            return getattr(self, key)
        except AttributeError:
            # Carry the missing key, as dict does, so the error is
            # self-explanatory.
            raise KeyError(key)
def method(func):
    """Mark *func* as an RPC endpoint handler and return it.

    *func* may carry parameter and/or return annotations. Each annotation
    that is present must be callable; annotations are used to validate
    received arguments and/or the return value.

    On success *func* gets an ``__rpc__`` marker attribute and a cached
    ``__signature__``.

    Raises:
        ValueError: if a parameter annotation or the return annotation is
            not callable. *func* is left unmarked in that case.
    """
    sig = inspect.signature(func)
    for name, param in sig.parameters.items():
        ann = param.annotation
        if ann is not param.empty and not callable(ann):
            raise ValueError("Expected {!r} annotation to be callable"
                             .format(name))
    ann = sig.return_annotation
    if ann is not sig.empty and not callable(ann):
        raise ValueError("Expected return annotation to be callable")
    # Mark only after validation succeeded: dispatch treats any function
    # carrying ``__rpc__`` as an endpoint, so a rejected function must not
    # keep the marker.
    func.__rpc__ = {}
    func.__signature__ = sig
    return func
class Service(asyncio.AbstractServer):
    """RPC service wrapping a protocol object.

    Instances of Service (or its subclasses) are what the coroutines that
    create clients or servers return. Implements the asyncio.AbstractServer
    interface (``close`` / ``wait_closed``).
    """

    def __init__(self, loop, proto):
        self._loop = loop
        self._proto = proto

    @property
    def transport(self):
        """The underlying transport.

        Use it to bind/unbind, connect/disconnect etc. dynamically.

        Raises:
            ServiceClosedError: if the transport is already gone.
        """
        current = self._proto.transport
        if current is None:
            raise ServiceClosedError()
        return current

    def close(self):
        """Mark the protocol as closing and close its transport, if any.

        Calls after the first one do nothing.
        """
        proto = self._proto
        if proto.closing:
            return
        proto.closing = True
        if proto.transport is not None:
            proto.transport.close()

    @asyncio.coroutine
    def wait_closed(self):
        """Wait until the protocol resolves its done-waiters.

        Returns immediately when no transport is attached. Otherwise a
        fresh future is appended to the protocol's ``done_waiters`` and
        awaited; the protocol is expected to resolve it when the
        connection is lost.
        """
        proto = self._proto
        if proto.transport is not None:
            done = asyncio.Future(loop=self._loop)
            proto.done_waiters.append(done)
            yield from done
class _BaseProtocol(interface.ZmqProtocol):
    """State and connection bookkeeping shared by RPC protocols.

    Attributes:
        loop: the event loop passed in.
        transport: the transport while connected, ``None`` otherwise.
        done_waiters: futures resolved when the connection is lost
            (``Service.wait_closed`` appends to this list).
        packer: a ``_Packer`` built with *translation_table*.
        pending_waiters: set of in-flight futures, used by subclasses.
        closing: flipped to True when closing is requested
            (see ``Service.close``).
    """

    def __init__(self, loop, *, translation_table=None):
        self.loop = loop
        self.transport = None
        self.done_waiters = []
        self.packer = _Packer(translation_table=translation_table)
        self.pending_waiters = set()
        self.closing = False

    def connection_made(self, transport):
        self.transport = transport

    def connection_lost(self, exc):
        self.transport = None
        # Detach the list first so every waiter is resolved exactly once,
        # even if connection_lost were to run again.
        waiters, self.done_waiters = self.done_waiters, []
        for waiter in waiters:
            # A wait_closed() caller may have been cancelled (e.g. by a
            # timeout). set_result() on a done future raises
            # InvalidStateError, which would strand the remaining waiters.
            if not waiter.done():
                waiter.set_result(None)
class _BaseServerProtocol(_BaseProtocol):
    """Server-side machinery shared by RPC protocols.

    Resolves dotted method names against a handler tree, validates call
    arguments against annotations, optionally logs failed calls and tracks
    in-flight calls so they can be cancelled when the connection is lost.
    """

    def __init__(self, loop, handler, *,
                 translation_table=None,
                 log_exceptions=False,
                 exclude_log_exceptions=(),
                 timeout=None):
        super().__init__(loop, translation_table=translation_table)
        if not isinstance(handler, AbstractHandler):
            raise TypeError('handler must implement AbstractHandler')
        self.handler = handler
        self.log_exceptions = log_exceptions
        self.exclude_log_exceptions = exclude_log_exceptions
        # Not read in this class; presumably used by subclasses - confirm.
        self.timeout = timeout

    def connection_lost(self, exc):
        super().connection_lost(exc)
        # Cancel every in-flight call; iterate over a snapshot of the set.
        for waiter in list(self.pending_waiters):
            if not waiter.cancelled():
                waiter.cancel()

    def dispatch(self, name):
        """Resolve the dotted RPC *name* to a callable endpoint.

        Every component but the last must resolve (via ``handler[part]``)
        to another AbstractHandler. The last must resolve to an object
        whose function carries the ``__rpc__`` marker set by ``method``.

        Raises:
            NotFoundError: if *name* is empty or any lookup step fails.
        """
        if not name:
            raise NotFoundError(name)
        namespaces, _, method_name = name.rpartition('.')
        handler = self.handler
        if namespaces:
            for part in namespaces.split('.'):
                try:
                    handler = handler[part]
                except KeyError:
                    raise NotFoundError(name)
                else:
                    if not isinstance(handler, AbstractHandler):
                        raise NotFoundError(name)
        try:
            func = handler[method_name]
        except KeyError:
            raise NotFoundError(name)
        # For bound methods the marker lives on the underlying function.
        holder = func.__func__ if isinstance(func, MethodType) else func
        if not hasattr(holder, '__rpc__'):
            raise NotFoundError(name)
        return func

    def check_args(self, func, args, kwargs):
        """Bind *args*/*kwargs* to *func* and coerce annotated arguments.

        Each annotated parameter that was actually passed is replaced by
        ``annotation(value)``. Parameters left at their defaults are not
        touched.

        Returns:
            ``(args, kwargs, return_annotation)``: the bound (and coerced)
            positional and keyword arguments, plus the return annotation,
            or ``None`` when there is none.

        Raises:
            ParametersError: if the arguments do not fit the signature or
                an annotation rejects a value with TypeError/ValueError.
        """
        try:
            sig = inspect.signature(func)
            bargs = sig.bind(*args, **kwargs)
        except TypeError as exc:
            raise ParametersError(repr(exc)) from exc
        arguments = bargs.arguments
        for name, param in sig.parameters.items():
            if param.annotation is param.empty or name not in arguments:
                continue  # nothing to validate, or default value kept
            try:
                arguments[name] = param.annotation(arguments[name])
            except (TypeError, ValueError) as exc:
                raise ParametersError(
                    'Invalid value for argument {!r}: {!r}'
                    .format(name, exc)) from exc
        ret = sig.return_annotation
        return bargs.args, bargs.kwargs, (None if ret is sig.empty else ret)

    def try_log(self, fut, name, args, kwargs):
        """Log the exception raised by the RPC call behind *fut*, if any.

        Nothing is logged unless ``log_exceptions`` is set, nor for
        exceptions that are instances of a type in
        ``exclude_log_exceptions``. Cancelled futures (e.g. cancelled in
        ``connection_lost``) are skipped. Cancellation is not a method
        failure, and since Python 3.8 ``fut.result()`` would raise
        CancelledError, which ``except Exception`` does not catch.
        """
        if fut.cancelled():
            return
        try:
            fut.result()
        except Exception as exc:
            if not self.log_exceptions:
                return
            if any(isinstance(exc, e) for e in self.exclude_log_exceptions):
                return
            logger.exception(textwrap.dedent("""\
                An exception %r from method %r call occurred.
                args = %s
                kwargs = %s
                """),
                exc, name,
                pprint.pformat(args), pprint.pformat(kwargs))  # noqa

    def add_pending(self, coro):
        """Schedule *coro* on ``self.loop``, track it and return its future.

        Use ``discard_pending`` to stop tracking the returned future.
        """
        # asyncio.async() was renamed ensure_future() in 3.4.4 and is a
        # SyntaxError since 3.7, where ``async`` became a keyword.
        fut = asyncio.ensure_future(coro, loop=self.loop)
        self.pending_waiters.add(fut)
        return fut

    def discard_pending(self, fut):
        """Stop tracking *fut*; a no-op if it is not tracked."""
        self.pending_waiters.discard(fut)