Spaces:
Running
Running
# Copyright 2019 gRPC authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Interceptors implementation of gRPC Asyncio Python.""" | |
from abc import ABCMeta | |
from abc import abstractmethod | |
import asyncio | |
import collections | |
import functools | |
from typing import ( | |
AsyncIterable, | |
Awaitable, | |
Callable, | |
Iterator, | |
List, | |
Optional, | |
Sequence, | |
Union, | |
) | |
import grpc | |
from grpc._cython import cygrpc | |
from . import _base_call | |
from ._call import AioRpcError | |
from ._call import StreamStreamCall | |
from ._call import StreamUnaryCall | |
from ._call import UnaryStreamCall | |
from ._call import UnaryUnaryCall | |
from ._call import _API_STYLE_ERROR | |
from ._call import _RPC_ALREADY_FINISHED_DETAILS | |
from ._call import _RPC_HALF_CLOSED_DETAILS | |
from ._metadata import Metadata | |
from ._typing import DeserializingFunction | |
from ._typing import DoneCallbackType | |
from ._typing import EOFType | |
from ._typing import RequestIterableType | |
from ._typing import RequestType | |
from ._typing import ResponseIterableType | |
from ._typing import ResponseType | |
from ._typing import SerializingFunction | |
from ._utils import _timeout_to_deadline | |
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!" | |
class ServerInterceptor(metaclass=ABCMeta):
    """Affords intercepting incoming RPCs on the service-side.

    This is an EXPERIMENTAL API.
    """

    # Marked abstract so that a subclass which forgets to implement the hook
    # fails at instantiation time instead of silently returning None.
    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler:
        """Intercepts incoming RPCs before handing them over to a handler.

        State can be passed from an interceptor to downstream interceptors
        via contextvars. The first interceptor is called from an empty
        contextvars.Context, and the same Context is used for downstream
        interceptors and for the final handler call. Note that there are no
        guarantees that interceptors and handlers will be called from the
        same thread.

        Args:
          continuation: A function that takes a HandlerCallDetails and
            proceeds to invoke the next interceptor in the chain, if any,
            or the RPC handler lookup logic, with the call details passed
            as an argument, and returns an RpcMethodHandler instance if
            the RPC is considered serviced, or None otherwise.
          handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
          An RpcMethodHandler with which the RPC may be serviced if the
          interceptor chooses to service this RPC, or None otherwise.
        """
class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        [
            "method",
            "timeout",
            "metadata",
            "credentials",
            "wait_for_ready",
        ],
    ),
    grpc.ClientCallDetails,
):
    """Immutable description of an RPC that is about to be invoked.

    This is an EXPERIMENTAL API.

    The five fields are stored positionally in a named tuple, in the order
    listed below.

    Args:
      method: The method name of the RPC.
      timeout: An optional duration of time in seconds to allow for the RPC.
      metadata: Optional metadata to be transmitted to the service-side of
        the RPC.
      credentials: An optional CallCredentials for the RPC.
      wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism.
    """

    method: str
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
class ClientInterceptor(metaclass=ABCMeta):
    """Common root of every Aio client interceptor class.

    It declares no methods of its own; the arity-specific interceptor
    classes derive from it.
    """
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-unary invocations."""

    # Marked abstract so that subclasses are forced to provide the hook.
    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Intercepts a unary-unary invocation asynchronously.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          Either the call object for the RPC or the RPC response itself.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-stream invocations."""

    # Marked abstract so that subclasses are forced to provide the hook.
    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryStreamCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Intercepts a unary-stream invocation asynchronously.

        The function may return either the call object or an asynchronous
        iterator. When an asynchronous iterator is returned, it becomes the
        source of the reads done by the caller.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    # Marked abstract so that subclasses are forced to provide the hook.
    @abstractmethod
    async def intercept_stream_unary(
        self,
        # The continuation is invoked with the request iterator, so its
        # second parameter is an iterable of requests, not a single request.
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Within the interceptor the usage of the call methods like `write` or
        even awaiting the call should be done carefully, since the caller
        could be expecting an untouched call, for example to start writing
        messages to it.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    # Marked abstract so that subclasses are forced to provide the hook.
    @abstractmethod
    async def intercept_stream_stream(
        self,
        # The continuation is invoked with the request iterator, so its
        # second parameter is an iterable of requests, not a single request.
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamStreamCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Within the interceptor the usage of the call methods like `write` or
        even awaiting the call should be done carefully, since the caller
        could be expecting an untouched call, for example to start writing
        messages to it.

        The function may return either the call object or an asynchronous
        iterator. When an asynchronous iterator is returned, it becomes the
        source of the reads done by the caller.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors might have some work to do before the RPC invocation with
    the capacity of changing the invocation parameters, and some work to do
    after the RPC invocation with the capacity for accessing the wrapped
    `UnaryUnaryCall`.

    It also handles early and late cancellations: when the RPC has not even
    started and the execution is still held by the interceptors, or when the
    RPC has finished but the execution is still held by the interceptors.

    Once the RPC is finally executed, all methods are forwarded to the
    intercepted call, which is the same call returned to the interceptors.

    As a base class for all of the intercepted calls it implements the logic
    around final status, metadata and cancellation.
    """

    # Task running the interceptor chain; its result is the intercepted call.
    _interceptors_task: asyncio.Task
    # Callbacks registered before the interceptors task finished. Declared as
    # a List because entries are appended to it.
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        """Stores the task and arranges for pending callbacks to be flushed.

        Args:
          interceptors_task: The task that runs the interceptors and resolves
            to the intercepted call.
        """
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # The finalizer also runs for instances whose __init__ never ran to
        # completion; without the task attribute there is nothing to cancel
        # and calling cancel() would only raise AttributeError.
        if hasattr(self, "_interceptors_task"):
            self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Flushes callbacks that were registered before the task finished.

        If the interceptors failed, were cancelled, or returned a call that is
        already done, every pending callback is invoked right away with this
        object. Otherwise each one is attached to the intercepted call.
        """
        if not self._pending_add_done_callbacks:
            return

        try:
            call = interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True
        else:
            call_completed = call.done()

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                # Wrap so the callback receives this object, not the inner call.
                callback = functools.partial(
                    self._wrap_add_done_callback, callback
                )
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        """Invokes `callback` with this object instead of the inner call."""
        callback(self)

    def cancel(self) -> bool:
        """Cancels the interceptors task or, once available, the inner call.

        Returns:
          The result of cancelling the task or the intercepted call; False if
          the interceptors task already ended with AioRpcError or cancellation.
        """
        if not self._interceptors_task.done():
            # The intercepted call is not available yet, so fall back to
            # the generic Asyncio cancellation of the interceptors task.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except AioRpcError:
            return False
        except asyncio.CancelledError:
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        """Returns whether the RPC ended up cancelled."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        """Returns whether the interceptors and the inner call are finished."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Registers `callback` to be invoked with this object when done.

        While the interceptors are still running the callback is queued. If
        everything is already finished it is invoked immediately.
        """
        if not self._interceptors_task.done():
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        """Not supported for intercepted calls."""
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Returns the initial metadata; None if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Returns the trailing metadata; None if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Returns the status code; CANCELLED if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        """Returns the status details, or the local-cancellation message."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Returns the debug error string; empty if the interceptors were cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        """Waits for the interceptors, then delegates to the inner call."""
        call = await self._interceptors_task
        return await call.wait_for_connection()
class _InterceptedUnaryResponseMixin:
    """Makes an intercepted call awaitable for its single response."""

    def __await__(self):
        # First wait for the interceptors to hand back the call, then
        # delegate to that call's own awaitable to obtain the response.
        intercepted_call = yield from self._interceptors_task.__await__()
        return (yield from intercepted_call.__await__())
class _InterceptedStreamResponseMixin:
    """Exposes the responses of an intercepted call as a stream."""

    # Lazily created async generator shared by `__aiter__` and `read`.
    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # The generator is created on first use; creating it here would make
        # Asyncio log a warning whenever it is never consumed.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> ResponseType:
        """Yields every response of the call produced by the interceptors."""
        intercepted_call = await self._interceptors_task
        async for message in intercepted_call:
            yield message

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        """Returns the shared response generator, creating it if needed."""
        stream = self._response_aiter
        if stream is None:
            stream = self._wait_for_interceptor_task_response_iterator()
            self._response_aiter = stream
        return stream

    async def read(self) -> Union[EOFType, ResponseType]:
        """Returns the next response, or `cygrpc.EOF` once exhausted."""
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        try:
            message = await self._response_aiter.asend(None)
        except StopAsyncIteration:
            return cygrpc.EOF
        return message
class _InterceptedStreamRequestMixin:
    """Adds `write` / `done_writing` support to an intercepted call.

    When the caller does not provide a request iterator, the mixin supplies
    its own: an async generator fed through a one-slot queue by `write`.
    """

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]
    _status_code_task: Optional[asyncio.Task]

    # Queue item that tells the proxy generator to stop.
    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Returns the request iterator to use for the RPC.

        Args:
          request_iterator: The iterator given by the caller, or None when
            the caller intends to use `write` / `done_writing`.

        Returns:
          The given iterator, or the internal proxy generator when None.
        """
        if request_iterator is None:
            # We provide our own request iterator which is a proxy
            # of the future writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            self._status_code_task = None
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        """Yields queued requests until the finish sentinel is received."""
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self, request: RequestType, call: InterceptedCall
    ):
        """Queues `request`, giving up as soon as `call` terminates.

        The write races against a task awaiting `call.code()`, so an abrupt
        termination of the call interrupts a write blocked on a full queue.
        """
        if self._status_code_task is None:
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        try:
            await asyncio.wait(
                (put_task, self._status_code_task),
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            # If the call finished first (or this coroutine was cancelled)
            # the put is still pending; cancel it so no task is left behind.
            if not put_task.done():
                put_task.cancel()

    async def write(self, request: RequestType) -> None:
        """Writes one request to the RPC.

        Raises:
          cygrpc.UsageError: If the call was created with a request iterator.
          asyncio.InvalidStateError: If the RPC already finished or the
            writing side is already half-closed.
        """
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        # The write may have been interrupted by the call terminating.
        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.

        Raises:
          cygrpc.UsageError: If the call was created with a request iterator.
          asyncio.InvalidStateError: If the interceptors task was cancelled.
        """
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )
class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Runs a `UnaryUnaryCall` behind a chain of interceptors.

    Awaiting this object is proxied to the intercepted call, and only
    happens once the interceptors task has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Schedules the interceptor chain as a task on `loop`."""
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            request,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryUnaryCall:
        """Runs the interceptors in order and finally creates the RPC call."""

        async def _run_interceptor(
            interceptors: List[UnaryUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryUnaryCall:
            # End of the chain: issue the real RPC with the (possibly
            # modified) call details.
            if not interceptors:
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            # The continuation handed to the interceptor runs the rest of
            # the chain.
            continuation = functools.partial(
                _run_interceptor, interceptors[1:]
            )
            outcome = await interceptors[0].intercept_unary_unary(
                continuation, client_call_details, request
            )
            if isinstance(outcome, _base_call.UnaryUnaryCall):
                return outcome
            # A bare response was returned; wrap it in a finished call.
            return UnaryUnaryCallResponse(outcome)

        initial_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), initial_details, request
        )

    def time_remaining(self) -> Optional[float]:
        """Not supported for intercepted calls."""
        raise NotImplementedError()
class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Annotation, not assignment: the value is set per instance in __init__.
    # Assigning `Optional[...]` here would store a typing object as a class
    # attribute.
    _last_returned_call_from_interceptors: Optional[_base_call.UnaryStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Schedules the interceptor chain as a task on `loop`."""
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryStreamCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryStreamCall:
            if interceptors:
                # The continuation runs the remainder of the chain.
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a plain iterator: pair it
                    # with the most recently returned call.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                # End of the chain: issue the real RPC.
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        """Not supported for intercepted calls."""
        raise NotImplementedError()
class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Used for running a `StreamUnaryCall` wrapped by interceptors.

    Awaiting this object is proxied to the intercepted call, and only
    happens once the interceptors task has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Schedules the interceptor chain as a task on `loop`."""
        self._loop = loop
        self._channel = channel
        # May replace a None iterator with the internal write proxy.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> StreamUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            # Indexed and sliced below, so this is a list, not an iterator.
            interceptors: List[StreamUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            if interceptors:
                # The continuation runs the remainder of the chain.
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                return await interceptors[0].intercept_stream_unary(
                    continuation, client_call_details, request_iterator
                )
            else:
                # End of the chain: issue the real RPC.
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        """Not supported for intercepted calls."""
        raise NotImplementedError()
class InterceptedStreamStreamCall( | |
_InterceptedStreamResponseMixin, | |
_InterceptedStreamRequestMixin, | |
InterceptedCall, | |
_base_call.StreamStreamCall, | |
): | |
"""Used for running a `StreamStreamCall` wrapped by interceptors.""" | |
_loop: asyncio.AbstractEventLoop | |
_channel: cygrpc.AioChannel | |
_last_returned_call_from_interceptors = Optional[ | |
_base_call.StreamStreamCall | |
] | |
# pylint: disable=too-many-arguments | |
def __init__( | |
self, | |
interceptors: Sequence[StreamStreamClientInterceptor], | |
request_iterator: Optional[RequestIterableType], | |
timeout: Optional[float], | |
metadata: Metadata, | |
credentials: Optional[grpc.CallCredentials], | |
wait_for_ready: Optional[bool], | |
channel: cygrpc.AioChannel, | |
method: bytes, | |
request_serializer: SerializingFunction, | |
response_deserializer: DeserializingFunction, | |
loop: asyncio.AbstractEventLoop, | |
) -> None: | |
self._loop = loop | |
self._channel = channel | |
self._init_stream_response_mixin() | |
request_iterator = self._init_stream_request_mixin(request_iterator) | |
self._last_returned_call_from_interceptors = None | |
interceptors_task = loop.create_task( | |
self._invoke( | |
interceptors, | |
method, | |
timeout, | |
metadata, | |
credentials, | |
wait_for_ready, | |
request_iterator, | |
request_serializer, | |
response_deserializer, | |
) | |
) | |
super().__init__(interceptors_task) | |
# pylint: disable=too-many-arguments | |
async def _invoke( | |
self, | |
interceptors: Sequence[StreamStreamClientInterceptor], | |
method: bytes, | |
timeout: Optional[float], | |
metadata: Optional[Metadata], | |
credentials: Optional[grpc.CallCredentials], | |
wait_for_ready: Optional[bool], | |
request_iterator: RequestIterableType, | |
request_serializer: SerializingFunction, | |
response_deserializer: DeserializingFunction, | |
) -> StreamStreamCall: | |
"""Run the RPC call wrapped in interceptors""" | |
async def _run_interceptor( | |
interceptors: List[StreamStreamClientInterceptor], | |
client_call_details: ClientCallDetails, | |
request_iterator: RequestIterableType, | |
) -> _base_call.StreamStreamCall: | |
if interceptors: | |
continuation = functools.partial( | |
_run_interceptor, interceptors[1:] | |
) | |
call_or_response_iterator = await interceptors[ | |
0 | |
].intercept_stream_stream( | |
continuation, client_call_details, request_iterator | |
) | |
if isinstance( | |
call_or_response_iterator, _base_call.StreamStreamCall | |
): | |
self._last_returned_call_from_interceptors = ( | |
call_or_response_iterator | |
) | |
else: | |
self._last_returned_call_from_interceptors = ( | |
StreamStreamCallResponseIterator( | |
self._last_returned_call_from_interceptors, | |
call_or_response_iterator, | |
) | |
) | |
return self._last_returned_call_from_interceptors | |
else: | |
self._last_returned_call_from_interceptors = StreamStreamCall( | |
request_iterator, | |
_timeout_to_deadline(client_call_details.timeout), | |
client_call_details.metadata, | |
client_call_details.credentials, | |
client_call_details.wait_for_ready, | |
self._channel, | |
client_call_details.method, | |
request_serializer, | |
response_deserializer, | |
self._loop, | |
) | |
return self._last_returned_call_from_interceptors | |
client_call_details = ClientCallDetails( | |
method, timeout, metadata, credentials, wait_for_ready | |
) | |
return await _run_interceptor( | |
list(interceptors), client_call_details, request_iterator | |
) | |
def time_remaining(self) -> Optional[float]:
    """Unsupported here: always raises ``NotImplementedError``."""
    raise NotImplementedError
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """Final UnaryUnaryCall class finished with a response.

    The instance only stores an already available response. It reports
    itself as done and not cancelled, exposes an OK status with empty
    details and no metadata, and awaiting it yields the stored response.
    """

    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    def __await__(self):
        # ``yield from ()`` produces nothing; it is only there so that
        # __await__ is compiled as a generator, which then finishes
        # immediately with the stored response as the await result.
        yield from ()
        return self._response

    # --- Call state: fixed answers, the call is already finished. ---

    def cancel(self) -> bool:
        """Return False; this object never performs a cancellation."""
        return False

    def cancelled(self) -> bool:
        """Return False; the call is never reported as cancelled."""
        return False

    def done(self) -> bool:
        """Return True; the call is always reported as finished."""
        return True

    # --- Operations this object does not provide. ---

    def add_done_callback(self, unused_callback) -> None:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def time_remaining(self) -> Optional[float]:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError()

    # --- Status accessors: constant values. ---

    async def initial_metadata(self) -> Optional[Metadata]:
        """Return None; no initial metadata is held."""
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Return None; no trailing metadata is held."""
        return None

    async def code(self) -> grpc.StatusCode:
        """Return ``grpc.StatusCode.OK``."""
        return grpc.StatusCode.OK

    async def details(self) -> str:
        """Return an empty details string."""
        return ""

    async def debug_error_string(self) -> Optional[str]:
        """Return None; there is no error to describe."""
        return None

    async def wait_for_connection(self) -> None:
        """Return immediately without doing anything."""
        return None
class _StreamCallResponseIterator:
    """Pair a call with a substitute response iterator.

    Control and status operations are forwarded to the wrapped call, while
    asynchronous iteration is served by the supplied response iterator.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._call = call
        self._response_iterator = response_iterator

    def __aiter__(self):
        """Iterate over the substitute iterator, not the wrapped call."""
        return self._response_iterator.__aiter__()

    # --- Synchronous operations forwarded to the wrapped call. ---

    def cancel(self) -> bool:
        """Forward the cancellation request to the wrapped call."""
        return self._call.cancel()

    def cancelled(self) -> bool:
        """Report whether the wrapped call was cancelled."""
        return self._call.cancelled()

    def done(self) -> bool:
        """Report whether the wrapped call has finished."""
        return self._call.done()

    def add_done_callback(self, callback) -> None:
        """Register ``callback`` on the wrapped call."""
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        """Return the wrapped call's remaining time."""
        return self._call.time_remaining()

    # --- Asynchronous accessors forwarded to the wrapped call. ---

    async def initial_metadata(self) -> Optional[Metadata]:
        """Return the wrapped call's initial metadata."""
        return await self._call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Return the wrapped call's trailing metadata."""
        return await self._call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Return the wrapped call's status code."""
        return await self._call.code()

    async def details(self) -> str:
        """Return the wrapped call's status details."""
        return await self._call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Return the wrapped call's debug error string."""
        return await self._call.debug_error_string()

    async def wait_for_connection(self) -> None:
        """Await the wrapped call's ``wait_for_connection``."""
        return await self._call.wait_for_connection()
class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall class which uses an alternative response iterator."""

    async def read(self) -> Union[EOFType, ResponseType]:
        """Always raises ``NotImplementedError``.

        Behind the scenes everything goes through the async iterator, so
        this path should not be reached.
        """
        raise NotImplementedError()
class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall class which uses an alternative response iterator."""

    async def read(self) -> Union[EOFType, ResponseType]:
        """Always raises ``NotImplementedError``.

        Behind the scenes everything goes through the async iterator, so
        this path should not be reached.
        """
        raise NotImplementedError()

    async def write(self, request: RequestType) -> None:
        """Always raises ``NotImplementedError``.

        Behind the scenes everything goes through the async iterator
        provided by the InterceptedStreamStreamCall, so this path should
        not be reached.
        """
        raise NotImplementedError()

    async def done_writing(self) -> None:
        """Always raises ``NotImplementedError``.

        Behind the scenes everything goes through the async iterator
        provided by the InterceptedStreamStreamCall, so this path should
        not be reached.
        """
        raise NotImplementedError()

    @property
    def _done_writing_flag(self) -> bool:
        """Mirror the wrapped call's ``_done_writing_flag``.

        This has to be a property: the flag is read with plain attribute
        access (exactly as done on the wrapped call below). As a regular
        method, such a read would return an always-truthy bound method
        instead of a bool, including when the wrapped call is itself an
        instance of this class.
        """
        return self._call._done_writing_flag