Spaces:
Running
Running
# Copyright 2017 gRPC authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Implementation of gRPC Python interceptors.""" | |
import collections | |
import sys | |
import types | |
from typing import Any, Callable, Optional, Sequence, Tuple, Union | |
import grpc | |
from ._typing import DeserializingFunction | |
from ._typing import DoneCallbackType | |
from ._typing import MetadataType | |
from ._typing import RequestIterableType | |
from ._typing import SerializingFunction | |
class _ServicePipeline(object):
    """Runs a fixed chain of server interceptors ahead of a final thunk.

    Every interceptor is handed a continuation which, when called with the
    handler call details, passes control to the next interceptor. Once the
    chain is exhausted the original ``thunk`` is invoked instead.
    """

    interceptors: Tuple[grpc.ServerInterceptor, ...]

    def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
        # Stored as a tuple so the chain is fixed once the pipeline exists.
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk: Callable, index: int) -> Callable:
        """Build the callable that resumes the chain at position ``index``."""

        def resume(context):
            return self._intercept_at(thunk, index, context)

        return resume

    def _intercept_at(
        self, thunk: Callable, index: int, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Run the interceptor at ``index``, or ``thunk`` past the last one."""
        if index >= len(self.interceptors):
            return thunk(context)
        current = self.interceptors[index]
        next_step = self._continuation(thunk, index + 1)
        return current.intercept_service(next_step, context)

    def execute(
        self, thunk: Callable, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Run the whole chain, starting from the first interceptor."""
        return self._intercept_at(thunk, 0, context)
def service_pipeline(
    interceptors: Optional[Sequence[grpc.ServerInterceptor]],
) -> Optional[_ServicePipeline]:
    """Wrap ``interceptors`` in a pipeline, or return None when there are none.

    Both ``None`` and an empty sequence yield ``None``.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails",
        "method timeout metadata credentials wait_for_ready compression",
    ),
    grpc.ClientCallDetails,
):
    """Named tuple of the per-call settings passed to client interceptors."""
def _unwrap_client_call_details(
    call_details: grpc.ClientCallDetails,
    default_details: grpc.ClientCallDetails,
) -> Tuple[
    str, float, MetadataType, grpc.CallCredentials, bool, grpc.Compression
]:
    """Read the six call settings, falling back field by field to defaults.

    For each field, the value comes from ``call_details`` unless accessing
    it raises ``AttributeError``, in which case it is taken from
    ``default_details``.

    Returns:
      A tuple ``(method, timeout, metadata, credentials, wait_for_ready,
      compression)``.
    """

    def pick(field):
        # Fallback is triggered only by a missing attribute; an attribute
        # that is present but None is used as-is.
        try:
            return getattr(call_details, field)
        except AttributeError:
            return getattr(default_details, field)

    return (
        pick("method"),
        pick("timeout"),
        pick("metadata"),
        pick("credentials"),
        pick("wait_for_ready"),
        pick("compression"),
    )
class _FailureOutcome(
    grpc.RpcError, grpc.Future, grpc.Call
):  # pylint: disable=too-many-ancestors
    """Already-finished outcome that wraps an exception.

    It reports itself as done and inactive with an INTERNAL status code.
    Asking for the result, or iterating, raises the stored exception.
    """

    _exception: Exception
    _traceback: types.TracebackType

    def __init__(self, exception: Exception, traceback: types.TracebackType):
        super().__init__()
        self._exception = exception
        self._traceback = traceback

    # --- grpc.Call surface -------------------------------------------------

    def initial_metadata(self) -> Optional[MetadataType]:
        """Return None; no metadata is available."""
        return None

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return None; no metadata is available."""
        return None

    def code(self) -> Optional[grpc.StatusCode]:
        """Return ``grpc.StatusCode.INTERNAL``."""
        return grpc.StatusCode.INTERNAL

    def details(self) -> Optional[str]:
        """Return a fixed description of the failure."""
        return "Exception raised while intercepting the RPC"

    def is_active(self) -> bool:
        """Return False."""
        return False

    def time_remaining(self) -> Optional[float]:
        """Return None."""
        return None

    def add_callback(self, unused_callback) -> bool:
        """Ignore the callback and return False."""
        return False

    # --- grpc.Future surface -----------------------------------------------

    def cancel(self) -> bool:
        """Return False; nothing is cancelled."""
        return False

    def cancelled(self) -> bool:
        """Return False."""
        return False

    def running(self) -> bool:
        """Return False."""
        return False

    def done(self) -> bool:
        """Return True."""
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Raise the stored exception; the timeout is ignored."""
        raise self._exception

    def exception(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[Exception]:
        """Return the stored exception; the timeout is ignored."""
        return self._exception

    def traceback(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        """Return the stored traceback; the timeout is ignored."""
        return self._traceback

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Call ``fn`` with this outcome immediately."""
        fn(self)

    # --- iterator surface --------------------------------------------------

    def __iter__(self):
        return self

    def __next__(self):
        # Iterating raises the stored exception on the first step.
        raise self._exception

    def next(self):
        """Alias that delegates to ``__next__``."""
        return self.__next__()
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Finished outcome pairing a response with the call that produced it.

    ``grpc.Call`` queries are forwarded to the wrapped call. The
    ``grpc.Future`` queries describe a completed, successful future whose
    result is the stored response.
    """

    _response: Any
    _call: grpc.Call

    def __init__(self, response: Any, call: grpc.Call):
        self._response = response
        self._call = call

    # --- grpc.Call surface: delegated to the wrapped call ------------------

    def initial_metadata(self) -> Optional[MetadataType]:
        """Return the wrapped call's initial metadata."""
        return self._call.initial_metadata()

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return the wrapped call's trailing metadata."""
        return self._call.trailing_metadata()

    def code(self) -> Optional[grpc.StatusCode]:
        """Return the wrapped call's status code."""
        return self._call.code()

    def details(self) -> Optional[str]:
        """Return the wrapped call's details."""
        return self._call.details()

    def is_active(self) -> bool:
        """Return whether the wrapped call reports itself active."""
        return self._call.is_active()

    def time_remaining(self) -> Optional[float]:
        """Return the wrapped call's remaining time."""
        return self._call.time_remaining()

    def cancel(self) -> bool:
        """Forward the cancellation to the wrapped call."""
        return self._call.cancel()

    def add_callback(self, callback) -> bool:
        """Register ``callback`` on the wrapped call."""
        return self._call.add_callback(callback)

    # --- grpc.Future surface: fixed answers --------------------------------

    def cancelled(self) -> bool:
        """Return False."""
        return False

    def running(self) -> bool:
        """Return False."""
        return False

    def done(self) -> bool:
        """Return True."""
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Return the stored response; the timeout is ignored."""
        return self._response

    def exception(self, ignored_timeout: Optional[float] = None):
        """Return None; the timeout is ignored."""
        return None

    def traceback(self, ignored_timeout: Optional[float] = None):
        """Return None; the timeout is ignored."""
        return None

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Call ``fn`` with this outcome immediately."""
        fn(self)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes invocations via an interceptor.

    ``thunk`` is called with a method name and returns the underlying
    multi-callable. It is called from the continuation handed to the
    interceptor, so the interceptor may substitute a different method name.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC via the interceptor and return only the response."""
        response, _ = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC via the interceptor; return ``(response, call)``.

        The first element is ``call.result()`` of whatever the interceptor
        returned, so a failed outcome raises from here.
        """
        details = _ClientCallDetails(
            method=self._method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

        def continuation(new_details, request):
            (
                rpc_method,
                rpc_timeout,
                rpc_metadata,
                rpc_credentials,
                rpc_wait_for_ready,
                rpc_compression,
            ) = _unwrap_client_call_details(new_details, details)
            # Failures are returned as outcome objects rather than raised,
            # so the interceptor always receives something to inspect.
            try:
                response, inner_call = self._thunk(rpc_method).with_call(
                    request,
                    timeout=rpc_timeout,
                    metadata=rpc_metadata,
                    credentials=rpc_credentials,
                    wait_for_ready=rpc_wait_for_ready,
                    compression=rpc_compression,
                )
                return _UnaryOutcome(response, inner_call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                _, _, trace = sys.exc_info()
                return _FailureOutcome(exception, trace)

        outcome = self._interceptor.intercept_unary_unary(
            continuation, details, request
        )
        return outcome.result(), outcome

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC via the interceptor; return ``(response, call)``."""
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Start the RPC via the interceptor and return what it returns.

        If the interceptor raises, a ``_FailureOutcome`` wrapping the
        exception is returned instead.
        """
        details = _ClientCallDetails(
            method=self._method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

        def continuation(new_details, request):
            (
                rpc_method,
                rpc_timeout,
                rpc_metadata,
                rpc_credentials,
                rpc_wait_for_ready,
                rpc_compression,
            ) = _unwrap_client_call_details(new_details, details)
            return self._thunk(rpc_method).future(
                request,
                timeout=rpc_timeout,
                metadata=rpc_metadata,
                credentials=rpc_credentials,
                wait_for_ready=rpc_wait_for_ready,
                compression=rpc_compression,
            )

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(exception, trace)
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes invocations via an interceptor.

    ``thunk`` is called with a method name and returns the underlying
    multi-callable. It is called from the continuation handed to the
    interceptor.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Invoke the RPC via the interceptor and return what it returns.

        If the interceptor raises, a ``_FailureOutcome`` wrapping the
        exception is returned instead.
        """
        details = _ClientCallDetails(
            method=self._method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

        def continuation(new_details, request):
            (
                rpc_method,
                rpc_timeout,
                rpc_metadata,
                rpc_credentials,
                rpc_wait_for_ready,
                rpc_compression,
            ) = _unwrap_client_call_details(new_details, details)
            underlying = self._thunk(rpc_method)
            return underlying(
                request,
                timeout=rpc_timeout,
                metadata=rpc_metadata,
                credentials=rpc_credentials,
                wait_for_ready=rpc_wait_for_ready,
                compression=rpc_compression,
            )

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(exception, trace)
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes invocations via an interceptor.

    ``thunk`` is called with a method name and returns the underlying
    multi-callable. It is called from the continuation handed to the
    interceptor.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC via the interceptor and return only the response."""
        response, _ = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC via the interceptor; return ``(response, call)``.

        The first element is ``call.result()`` of whatever the interceptor
        returned, so a failed outcome raises from here.
        """
        details = _ClientCallDetails(
            method=self._method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

        def continuation(new_details, request_iterator):
            (
                rpc_method,
                rpc_timeout,
                rpc_metadata,
                rpc_credentials,
                rpc_wait_for_ready,
                rpc_compression,
            ) = _unwrap_client_call_details(new_details, details)
            # Failures are returned as outcome objects rather than raised,
            # so the interceptor always receives something to inspect.
            try:
                response, inner_call = self._thunk(rpc_method).with_call(
                    request_iterator,
                    timeout=rpc_timeout,
                    metadata=rpc_metadata,
                    credentials=rpc_credentials,
                    wait_for_ready=rpc_wait_for_ready,
                    compression=rpc_compression,
                )
                return _UnaryOutcome(response, inner_call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                _, _, trace = sys.exc_info()
                return _FailureOutcome(exception, trace)

        outcome = self._interceptor.intercept_stream_unary(
            continuation, details, request_iterator
        )
        return outcome.result(), outcome

    def with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC via the interceptor; return ``(response, call)``."""
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Start the RPC via the interceptor and return what it returns.

        If the interceptor raises, a ``_FailureOutcome`` wrapping the
        exception is returned instead.
        """
        details = _ClientCallDetails(
            method=self._method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

        def continuation(new_details, request_iterator):
            (
                rpc_method,
                rpc_timeout,
                rpc_metadata,
                rpc_credentials,
                rpc_wait_for_ready,
                rpc_compression,
            ) = _unwrap_client_call_details(new_details, details)
            return self._thunk(rpc_method).future(
                request_iterator,
                timeout=rpc_timeout,
                metadata=rpc_metadata,
                credentials=rpc_credentials,
                wait_for_ready=rpc_wait_for_ready,
                compression=rpc_compression,
            )

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(exception, trace)
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes invocations via an interceptor.

    ``thunk`` is called with a method name and returns the underlying
    multi-callable. It is called from the continuation handed to the
    interceptor.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Invoke the RPC via the interceptor and return what it returns.

        If the interceptor raises, a ``_FailureOutcome`` wrapping the
        exception is returned instead.
        """
        details = _ClientCallDetails(
            method=self._method,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

        def continuation(new_details, request_iterator):
            (
                rpc_method,
                rpc_timeout,
                rpc_metadata,
                rpc_credentials,
                rpc_wait_for_ready,
                rpc_compression,
            ) = _unwrap_client_call_details(new_details, details)
            underlying = self._thunk(rpc_method)
            return underlying(
                request_iterator,
                timeout=rpc_timeout,
                metadata=rpc_metadata,
                credentials=rpc_credentials,
                wait_for_ready=rpc_wait_for_ready,
                compression=rpc_compression,
            )

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            _, _, trace = sys.exc_info()
            return _FailureOutcome(exception, trace)
class _Channel(grpc.Channel):
    """Channel wrapper that applies a single client interceptor.

    Each factory method checks whether the interceptor is an instance of
    the matching interceptor class. If so it returns an intercepting
    multi-callable; otherwise it returns the wrapped channel's own
    multi-callable unchanged.
    """

    _channel: grpc.Channel
    _interceptor: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ]

    def __init__(
        self,
        channel: grpc.Channel,
        interceptor: Union[
            grpc.UnaryUnaryClientInterceptor,
            grpc.UnaryStreamClientInterceptor,
            grpc.StreamStreamClientInterceptor,
            grpc.StreamUnaryClientInterceptor,
        ],
    ):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(
        self, callback: Callable, try_to_connect: Optional[bool] = False
    ):
        """Forward the subscription to the wrapped channel."""
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback: Callable):
        """Forward the unsubscription to the wrapped channel."""
        self._channel.unsubscribe(callback)

    # pylint: disable=arguments-differ
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryUnaryMultiCallable:
        """Create a unary-unary multi-callable, intercepted when applicable."""

        def make_callable(name):
            # pytype: disable=wrong-arg-count
            return self._channel.unary_unary(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )
            # pytype: enable=wrong-arg-count

        if not isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return make_callable(method)
        return _UnaryUnaryMultiCallable(
            make_callable, method, self._interceptor
        )

    # pylint: disable=arguments-differ
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryStreamMultiCallable:
        """Create a unary-stream multi-callable, intercepted when applicable."""

        def make_callable(name):
            # pytype: disable=wrong-arg-count
            return self._channel.unary_stream(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )
            # pytype: enable=wrong-arg-count

        if not isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return make_callable(method)
        return _UnaryStreamMultiCallable(
            make_callable, method, self._interceptor
        )

    # pylint: disable=arguments-differ
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamUnaryMultiCallable:
        """Create a stream-unary multi-callable, intercepted when applicable."""

        def make_callable(name):
            # pytype: disable=wrong-arg-count
            return self._channel.stream_unary(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )
            # pytype: enable=wrong-arg-count

        if not isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return make_callable(method)
        return _StreamUnaryMultiCallable(
            make_callable, method, self._interceptor
        )

    # pylint: disable=arguments-differ
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamStreamMultiCallable:
        """Create a stream-stream multi-callable, intercepted when applicable."""

        def make_callable(name):
            # pytype: disable=wrong-arg-count
            return self._channel.stream_stream(
                name,
                request_serializer,
                response_deserializer,
                _registered_method,
            )
            # pytype: enable=wrong-arg-count

        if not isinstance(
            self._interceptor, grpc.StreamStreamClientInterceptor
        ):
            return make_callable(method)
        return _StreamStreamMultiCallable(
            make_callable, method, self._interceptor
        )

    def _close(self):
        """Close the wrapped channel."""
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Returning False lets any in-flight exception propagate.
        return False

    def close(self):
        """Close the wrapped channel."""
        self._channel.close()
def intercept_channel(
    channel: grpc.Channel,
    *interceptors: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ],
) -> grpc.Channel:
    """Wrap ``channel`` with each of ``interceptors``.

    Interceptors are applied in reverse order, so the first one supplied
    ends up as the outermost wrapper. With no interceptors the original
    channel is returned unchanged.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more client interceptors. Each must be an
        instance of at least one of the four gRPC client interceptor
        classes.

    Returns:
      The outermost wrapping channel, or ``channel`` itself when no
      interceptors were given.

    Raises:
      TypeError: If any element of ``interceptors`` is not a client
        interceptor.
    """
    supported = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # ``interceptors`` is already a tuple, so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, supported):
            raise TypeError(
                "interceptor must be "
                "grpc.UnaryUnaryClientInterceptor or "
                "grpc.UnaryStreamClientInterceptor or "
                "grpc.StreamUnaryClientInterceptor or "
                "grpc.StreamStreamClientInterceptor"
            )
        channel = _Channel(channel, interceptor)
    return channel