|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Invocation-side implementation of gRPC Asyncio Python.""" |
|
|
|
import asyncio |
|
import sys |
|
from typing import Any, Iterable, List, Optional, Sequence |
|
|
|
import grpc |
|
from grpc import _common |
|
from grpc import _compression |
|
from grpc import _grpcio_metadata |
|
from grpc._cython import cygrpc |
|
|
|
from . import _base_call |
|
from . import _base_channel |
|
from ._call import StreamStreamCall |
|
from ._call import StreamUnaryCall |
|
from ._call import UnaryStreamCall |
|
from ._call import UnaryUnaryCall |
|
from ._interceptor import ClientInterceptor |
|
from ._interceptor import InterceptedStreamStreamCall |
|
from ._interceptor import InterceptedStreamUnaryCall |
|
from ._interceptor import InterceptedUnaryStreamCall |
|
from ._interceptor import InterceptedUnaryUnaryCall |
|
from ._interceptor import StreamStreamClientInterceptor |
|
from ._interceptor import StreamUnaryClientInterceptor |
|
from ._interceptor import UnaryStreamClientInterceptor |
|
from ._interceptor import UnaryUnaryClientInterceptor |
|
from ._metadata import Metadata |
|
from ._typing import ChannelArgumentType |
|
from ._typing import DeserializingFunction |
|
from ._typing import RequestIterableType |
|
from ._typing import SerializingFunction |
|
from ._utils import _timeout_to_deadline |
|
|
|
# Primary user-agent string attached to every channel created by this module.
_USER_AGENT = f'grpc-python-asyncio/{_grpcio_metadata.__version__}'
|
|
|
# Compare the full version tuple rather than only the minor number, so a
# future major release with a small minor number does not pick the legacy
# ``asyncio.Task.all_tasks`` API.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the tasks of the current event loop (pre-3.7 API)."""
        return asyncio.Task.all_tasks()

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the not-yet-finished tasks of the running event loop."""
        return asyncio.all_tasks()
|
|
|
|
|
def _augment_channel_arguments(base_options: ChannelArgumentType,
                               compression: Optional[grpc.Compression]):
    """Extends the user's channel options with module-level defaults.

    Appends, in order, whatever option ``_compression.create_channel_option``
    produces for ``compression`` and a primary user-agent option carrying
    ``_USER_AGENT``.

    Args:
      base_options: The channel options supplied by the caller.
      compression: An optional channel-wide compression setting.

    Returns:
      A tuple of channel options.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_option = (cygrpc.ChannelArgKey.primary_user_agent_string,
                         _USER_AGENT)

    augmented = tuple(base_options)
    augmented += compression_options
    augmented += (user_agent_option,)
    return augmented
|
|
|
|
|
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Stores everything needed to start calls for a single method.

        Args:
          channel: The underlying Cython channel the calls are issued on.
          method: The fully-qualified method name, already encoded to bytes.
          request_serializer: Serializer applied to outgoing requests.
          response_deserializer: Deserializer applied to incoming responses.
          interceptors: Interceptors for this RPC shape; falsy means the calls
            are created without interception.
          references: Objects stored on the instance. The Channel in this
            module passes ``[self]``, presumably to keep itself alive while
            the multi-callable exists (TODO confirm).
          loop: The event loop the calls are bound to.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
            metadata: Optional[Metadata] = None,
            compression: Optional[grpc.Compression] = None) -> Metadata:
        """Based on the provided values for <metadata> or <compression> initialise the final
        metadata, as it should be used for the current call.

        Args:
          metadata: Caller-provided metadata; a falsy value is replaced by an
            empty ``Metadata``.
          compression: If truthy, the metadata is rebuilt from
            ``_compression.augment_metadata(metadata, compression)``.

        Returns:
          The ``Metadata`` to send with the call.
        """
        metadata = metadata or Metadata()
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression))
        return metadata
|
|
|
|
|
class UnaryUnaryMultiCallable(_BaseMultiCallable,
                              _base_channel.UnaryUnaryMultiCallable):
    """Multi-callable issuing single-request, single-response RPCs."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryUnaryCall:
        """Starts a unary-unary RPC.

        Args:
          request: The request message.
          timeout: An optional relative timeout for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedUnaryUnaryCall`` when interceptors are configured,
          otherwise a plain ``UnaryUnaryCall``.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            # The intercepted call receives the relative timeout unchanged.
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, call_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        deadline = _timeout_to_deadline(timeout)
        return UnaryUnaryCall(request, deadline, call_metadata, credentials,
                              wait_for_ready, self._channel, self._method,
                              self._request_serializer,
                              self._response_deserializer, self._loop)
|
|
|
|
|
class UnaryStreamMultiCallable(_BaseMultiCallable,
                               _base_channel.UnaryStreamMultiCallable):
    """Multi-callable issuing single-request, streaming-response RPCs."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryStreamCall:
        """Starts a unary-stream RPC.

        Args:
          request: The request message.
          timeout: An optional relative timeout for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedUnaryStreamCall`` when interceptors are configured,
          otherwise a plain ``UnaryStreamCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = UnaryStreamCall(request, _timeout_to_deadline(timeout),
                                   metadata, credentials, wait_for_ready,
                                   self._channel, self._method,
                                   self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Pass the relative timeout, as UnaryUnaryMultiCallable does for
            # its intercepted call. Handing over an already-computed absolute
            # deadline would get it converted to a deadline a second time.
            # It would also expose a wrong timeout value to the interceptors.
            call = InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, metadata, credentials,
                wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
|
|
|
|
|
class StreamUnaryMultiCallable(_BaseMultiCallable,
                               _base_channel.StreamUnaryMultiCallable):
    """Multi-callable issuing streaming-request, single-response RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamUnaryCall:
        """Starts a stream-unary RPC.

        Args:
          request_iterator: Optional iterable of request messages.
          timeout: An optional relative timeout for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedStreamUnaryCall`` when interceptors are configured,
          otherwise a plain ``StreamUnaryCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = StreamUnaryCall(request_iterator,
                                   _timeout_to_deadline(timeout), metadata,
                                   credentials, wait_for_ready, self._channel,
                                   self._method, self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Pass the relative timeout, as UnaryUnaryMultiCallable does for
            # its intercepted call. Passing an absolute deadline would get it
            # converted a second time and mislead interceptors.
            call = InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
|
|
|
|
|
class StreamStreamMultiCallable(_BaseMultiCallable,
                                _base_channel.StreamStreamMultiCallable):
    """Multi-callable issuing streaming-request, streaming-response RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamStreamCall:
        """Starts a stream-stream RPC.

        Args:
          request_iterator: Optional iterable of request messages.
          timeout: An optional relative timeout for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedStreamStreamCall`` when interceptors are configured,
          otherwise a plain ``StreamStreamCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = StreamStreamCall(request_iterator,
                                    _timeout_to_deadline(timeout), metadata,
                                    credentials, wait_for_ready, self._channel,
                                    self._method, self._request_serializer,
                                    self._response_deserializer, self._loop)
        else:
            # Pass the relative timeout, as UnaryUnaryMultiCallable does for
            # its intercepted call. Passing an absolute deadline would get it
            # converted a second time and mislead interceptors.
            call = InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
|
|
|
|
|
class Channel(_base_channel.Channel):
    """An asyncio gRPC client channel backed by a ``cygrpc.AioChannel``.

    Interceptors given to the constructor are partitioned by type. Each
    multi-callable factory (``unary_unary``, ``unary_stream``, ...) hands over
    only the interceptors of the matching kind.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(self, target: str, options: ChannelArgumentType,
                 credentials: Optional[grpc.ChannelCredentials],
                 compression: Optional[grpc.Compression],
                 interceptors: Optional[Sequence[ClientInterceptor]]):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # The first matching type wins: each interceptor is added to
                # at most one list.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor) +
                        "{} or ".format(UnaryUnaryClientInterceptor.__name__) +
                        "{} or ".format(UnaryStreamClientInterceptor.__name__) +
                        "{} or ".format(StreamUnaryClientInterceptor.__name__) +
                        "{}. ".format(StreamStreamClientInterceptor.__name__))

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression), credentials,
            self._loop)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context closes immediately, without a grace period.
        await self._close(None)

    def _collect_ongoing_calls(self):
        """Finds call objects of this channel among the event loop's tasks.

        For every task from ``_all_tasks()``, the first frame returned by
        ``task.get_stack(limit=1)`` is inspected. If that frame's ``self``
        local is a ``_base_call.Call`` bound to this channel, the call and its
        task are collected.

        Returns:
          A ``(calls, call_tasks)`` pair of parallel lists.

        Raises:
          cygrpc.InternalError: If a ``Call`` exposes neither ``_channel``
            nor ``_cython_call``.
        """
        calls = []
        call_tasks = []
        for task in _all_tasks():
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # An AttributeError mentioning 'frame' is taken to mean the
                # task's stack cannot be inspected, and the task is skipped.
                # Presumably such tasks wrap C-implemented awaitables without
                # Python frames; TODO confirm. Any other AttributeError
                # propagates.
                if 'frame' in str(attribute_error):
                    continue
                raise

            # Tasks that report no frames cannot be attributed to a call.
            if not stack:
                continue

            candidate = stack[0].f_locals.get('self')
            if not isinstance(candidate, _base_call.Call):
                continue

            # Call objects expose their channel either directly or through
            # their Cython call.
            if hasattr(candidate, '_channel'):
                owner = candidate._channel
            elif hasattr(candidate, '_cython_call'):
                owner = candidate._cython_call._channel
            else:
                raise cygrpc.InternalError(
                    f'Unrecognized call object: {candidate}')

            if owner is not self._channel:
                continue

            calls.append(candidate)
            call_tasks.append(task)
        return calls, call_tasks

    async def _close(self, grace):
        """Closes the channel, optionally waiting for ongoing calls first.

        Args:
          grace: Seconds to wait for the tasks of this channel's calls before
            cancelling them. A falsy value (``None`` or ``0``) skips the wait.
        """
        if self._channel.closed():
            return

        # Flag the channel as closing before scanning for in-flight calls.
        self._channel.closing()

        calls, call_tasks = self._collect_ongoing_calls()

        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # cancel() is invoked on every collected call, including any that may
        # have finished during the grace period.
        for call in calls:
            call.cancel()

        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Closes the channel; see ``_close`` for the meaning of ``grace``."""
        await self._close(grace)

    def __del__(self):
        # ``_channel`` may be missing if ``__init__`` raised before setting it.
        if hasattr(self, '_channel'):
            if not self._channel.closed():
                self._channel.close()

    def get_state(self,
                  try_to_connect: bool = False) -> grpc.ChannelConnectivity:
        """Returns the channel's current connectivity state.

        ``try_to_connect`` is forwarded to the Cython channel's connectivity
        check.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Waits until the connectivity state differs from the given one."""
        # Await outside the assert: under ``python -O`` assert statements are
        # stripped, which previously removed the watch entirely and made
        # channel_ready() spin.
        changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None)
        assert changed

    async def channel_ready(self) -> None:
        """Waits until the channel reaches ``ChannelConnectivity.READY``."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryUnaryMultiCallable:
        """Creates a multi-callable for a unary-unary method."""
        return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
                                       request_serializer,
                                       response_deserializer,
                                       self._unary_unary_interceptors, [self],
                                       self._loop)

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryStreamMultiCallable:
        """Creates a multi-callable for a unary-stream method."""
        return UnaryStreamMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._unary_stream_interceptors, [self],
                                        self._loop)

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamUnaryMultiCallable:
        """Creates a multi-callable for a stream-unary method."""
        return StreamUnaryMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._stream_unary_interceptors, [self],
                                        self._loop)

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamStreamMultiCallable:
        """Creates a multi-callable for a stream-stream method."""
        return StreamStreamMultiCallable(self._channel, _common.encode(method),
                                         request_serializer,
                                         response_deserializer,
                                         self._stream_stream_interceptors,
                                         [self], self._loop)
|
|
|
|
|
def insecure_channel(
        target: str,
        options: Optional[ChannelArgumentType] = None,
        compression: Optional[grpc.Compression] = None,
        interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      A Channel.
    """
    if options is None:
        options = ()
    # No credentials: the channel is created without transport security.
    return Channel(target, options, None, compression, interceptors)
|
|
|
|
|
def secure_channel(target: str,
                   credentials: grpc.ChannelCredentials,
                   options: Optional[ChannelArgumentType] = None,
                   compression: Optional[grpc.Compression] = None,
                   interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    if options is None:
        options = ()
    # The Channel constructor receives the wrapped credentials object stored
    # on the public ChannelCredentials instance.
    wrapped_credentials = credentials._credentials
    return Channel(target, options, wrapped_credentials, compression,
                   interceptors)
|
|