"""ZeroMQ RPC"""
import asyncio
import os
import random
import struct
import sys
import time
from collections import ChainMap
from functools import partial
import zmq
from aiozmq import create_zmq_connection
from .base import (
GenericError,
NotFoundError,
ParametersError,
Service,
ServiceClosedError,
_BaseProtocol,
_BaseServerProtocol,
)
from .log import logger
from .util import (
_MethodCall,
_fill_error_table,
)
__all__ = [
'connect_rpc',
'serve_rpc',
]
@asyncio.coroutine
[docs]def connect_rpc(*, connect=None, bind=None, loop=None,
error_table=None, translation_table=None, timeout=None):
"""A coroutine that creates and connects/binds RPC client.
Usually for this function you need to use *connect* parameter, but
ZeroMQ does not forbid to use *bind*.
error_table -- an optional table for custom exception translators.
timeout -- an optional timeout for RPC calls. If timeout is not
None and remote call takes longer than timeout seconds then
asyncio.TimeoutError will be raised at client side. If the server
will return an answer after timeout has been raised that answer
**is ignored**.
translation_table -- an optional table for custom value translators.
loop -- an optional parameter to point ZmqEventLoop instance. If
loop is None then default event loop will be given by
asyncio.get_event_loop call.
Returns a RPCClient instance.
"""
if loop is None:
loop = asyncio.get_event_loop()
transp, proto = yield from create_zmq_connection(
lambda: _ClientProtocol(loop, error_table=error_table,
translation_table=translation_table),
zmq.DEALER, connect=connect, bind=bind, loop=loop)
return RPCClient(loop, proto, timeout=timeout)
@asyncio.coroutine
[docs]def serve_rpc(handler, *, connect=None, bind=None, loop=None,
translation_table=None, log_exceptions=False,
exclude_log_exceptions=(), timeout=None):
"""A coroutine that creates and connects/binds RPC server instance.
Usually for this function you need to use *bind* parameter, but
ZeroMQ does not forbid to use *connect*.
handler -- an object which processes incoming RPC calls. Usually
you like to pass AttrHandler instance.
log_exceptions -- log exceptions from remote calls if True.
exclude_log_exceptions -- sequence of exception classes than should not
be logged.
translation_table -- an optional table for custom value translators.
timeout -- timeout for performing handling of async server calls.
loop -- an optional parameter to point ZmqEventLoop instance. If
loop is None then default event loop will be given by
asyncio.get_event_loop call.
Returns Service instance.
"""
if loop is None:
loop = asyncio.get_event_loop()
transp, proto = yield from create_zmq_connection(
lambda: _ServerProtocol(loop, handler,
translation_table=translation_table,
log_exceptions=log_exceptions,
exclude_log_exceptions=exclude_log_exceptions,
timeout=timeout),
zmq.ROUTER, connect=connect, bind=bind, loop=loop)
return Service(loop, proto)
_default_error_table = _fill_error_table()
class _ClientProtocol(_BaseProtocol):
"""Client protocol implementation."""
REQ_PREFIX = struct.Struct('=HH')
REQ_SUFFIX = struct.Struct('=Ld')
RESP = struct.Struct('=HHLd?')
def __init__(self, loop, *, error_table=None, translation_table=None):
super().__init__(loop, translation_table=translation_table)
self.calls = {}
self.prefix = self.REQ_PREFIX.pack(os.getpid() % 0x10000,
random.randrange(0x10000))
self.counter = 0
if error_table is None:
self.error_table = _default_error_table
else:
self.error_table = ChainMap(error_table, _default_error_table)
def msg_received(self, data):
try:
header, banswer = data
pid, rnd, req_id, timestamp, is_error = self.RESP.unpack(header)
answer = self.packer.unpackb(banswer)
except Exception:
logger.critical("Cannot unpack %r", data, exc_info=sys.exc_info())
return
call = self.calls.pop(req_id, None)
if call is None:
logger.critical("Unknown answer id: %d (%d %d %f %d) -> %s",
req_id, pid, rnd, timestamp, is_error, answer)
elif call.cancelled():
logger.debug("The future for request #%08x has been cancelled, "
"skip the received result.", req_id)
else:
if is_error:
call.set_exception(self._translate_error(*answer))
else:
call.set_result(answer)
def connection_lost(self, exc):
super().connection_lost(exc)
for call in self.calls.values():
if not call.cancelled():
call.cancel()
def _translate_error(self, exc_type, exc_args, exc_repr):
found = self.error_table.get(exc_type)
if found is None:
return GenericError(exc_type, exc_args, exc_repr)
else:
return found(*exc_args)
def _new_id(self):
self.counter += 1
if self.counter > 0xffffffff:
self.counter = 0
return (self.prefix + self.REQ_SUFFIX.pack(self.counter, time.time()),
self.counter)
def call(self, name, args, kwargs):
if self.transport is None:
raise ServiceClosedError()
bname = name.encode('utf-8')
bargs = self.packer.packb(args)
bkwargs = self.packer.packb(kwargs)
header, req_id = self._new_id()
assert req_id not in self.calls, (req_id, self.calls)
fut = asyncio.Future(loop=self.loop)
self.calls[req_id] = fut
self.transport.write([header, bname, bargs, bkwargs])
return fut
class RPCClient(Service):
def __init__(self, loop, proto, *, timeout):
super().__init__(loop, proto)
self._timeout = timeout
@property
def call(self):
"""Return object for dynamic RPC calls.
The usage is:
ret = yield from client.call.ns.func(1, 2)
"""
return _MethodCall(self._proto, timeout=self._timeout)
def with_timeout(self, timeout):
"""Return a new RPCClient instance with overriden timeout"""
return self.__class__(self._loop, self._proto, timeout=timeout)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
return
class _ServerProtocol(_BaseServerProtocol):
REQ = struct.Struct('=HHLd')
RESP_PREFIX = struct.Struct('=HH')
RESP_SUFFIX = struct.Struct('=Ld?')
def __init__(self, loop, handler, *,
translation_table=None,
log_exceptions=False,
exclude_log_exceptions=(),
timeout=None):
super().__init__(loop, handler,
translation_table=translation_table,
log_exceptions=log_exceptions,
exclude_log_exceptions=exclude_log_exceptions,
timeout=timeout)
self.prefix = self.RESP_PREFIX.pack(os.getpid() % 0x10000,
random.randrange(0x10000))
def msg_received(self, data):
try:
*pre, header, bname, bargs, bkwargs = data
pid, rnd, req_id, timestamp = self.REQ.unpack(header)
name = bname.decode('utf-8')
args = self.packer.unpackb(bargs)
kwargs = self.packer.unpackb(bkwargs)
except Exception as exc:
logger.critical("Cannot unpack %r", data, exc_info=sys.exc_info())
return
try:
func = self.dispatch(name)
args, kwargs, ret_ann = self.check_args(func, args, kwargs)
except (NotFoundError, ParametersError) as exc:
fut = asyncio.Future(loop=self.loop)
fut.add_done_callback(partial(self.process_call_result,
req_id=req_id, pre=pre,
name=name, args=args, kwargs=kwargs))
fut.set_exception(exc)
else:
if asyncio.iscoroutinefunction(func):
fut = self.add_pending(func(*args, **kwargs))
else:
fut = asyncio.Future(loop=self.loop)
try:
fut.set_result(func(*args, **kwargs))
except Exception as exc:
fut.set_exception(exc)
fut.add_done_callback(partial(self.process_call_result,
req_id=req_id, pre=pre,
return_annotation=ret_ann,
name=name,
args=args,
kwargs=kwargs))
def process_call_result(self, fut, *, req_id, pre, name,
args, kwargs,
return_annotation=None):
self.discard_pending(fut)
self.try_log(fut, name, args, kwargs)
if self.transport is None:
return
try:
ret = fut.result()
if return_annotation is not None:
ret = return_annotation(ret)
prefix = self.prefix + self.RESP_SUFFIX.pack(req_id,
time.time(), False)
self.transport.write(pre + [prefix, self.packer.packb(ret)])
except asyncio.CancelledError:
return
except Exception as exc:
prefix = self.prefix + self.RESP_SUFFIX.pack(req_id,
time.time(), True)
exc_type = exc.__class__
exc_info = (exc_type.__module__ + '.' + exc_type.__qualname__,
exc.args, repr(exc))
self.transport.write(pre + [prefix, self.packer.packb(exc_info)])