kopyl's picture
Upload folder using huggingface_hub
096c926 verified
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""
import collections
import sys
import types
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import grpc
from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import SerializingFunction
class _ServicePipeline(object):
    """Runs a chain of server interceptors in front of a handler thunk.

    Interceptors are invoked in the order they were supplied. Each one is
    handed a continuation that, when called, invokes the next interceptor in
    the chain; once the chain is exhausted the original thunk is called.
    """

    # Variable-length tuple: __init__ stores however many interceptors the
    # caller provided, not exactly one.
    interceptors: Tuple[grpc.ServerInterceptor, ...]

    def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
        """Snapshot ``interceptors`` into an immutable tuple."""
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk: Callable, index: int) -> Callable:
        """Return a callable that resumes the chain at position ``index``."""
        return lambda context: self._intercept_at(thunk, index, context)

    def _intercept_at(
            self, thunk: Callable, index: int,
            context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
        """Invoke the interceptor at ``index``, or ``thunk`` past the end.

        Args:
          thunk: The callable to run after every interceptor has been visited.
          index: Position of the interceptor to run next.
          context: The handler call details passed along the chain.

        Returns:
          Whatever the interceptor (or, past the end, the thunk) returns.
        """
        if index < len(self.interceptors):
            interceptor = self.interceptors[index]
            # The interceptor decides whether/when to call the continuation,
            # which resumes the chain at the following position.
            thunk = self._continuation(thunk, index + 1)
            return interceptor.intercept_service(thunk, context)
        else:
            return thunk(context)

    def execute(self, thunk: Callable,
                context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
        """Run the whole chain, starting from the first interceptor."""
        return self._intercept_at(thunk, 0, context)
def service_pipeline(
    interceptors: Optional[Sequence[grpc.ServerInterceptor]]
) -> Optional[_ServicePipeline]:
    """Build a ``_ServicePipeline``, or return None when there is nothing to run.

    Args:
      interceptors: Server interceptors to chain; may be None or empty.

    Returns:
      A ``_ServicePipeline`` wrapping ``interceptors`` when the argument is
      truthy, otherwise None.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
class _ClientCallDetails(
        collections.namedtuple('_ClientCallDetails', [
            'method',
            'timeout',
            'metadata',
            'credentials',
            'wait_for_ready',
            'compression',
        ]), grpc.ClientCallDetails):
    """Named tuple of per-call settings that is also a grpc.ClientCallDetails.

    Fields, in order: method, timeout, metadata, credentials, wait_for_ready,
    compression.
    """
def _unwrap_client_call_details(
    call_details: grpc.ClientCallDetails,
    default_details: grpc.ClientCallDetails
) -> Tuple[str, float, MetadataType, grpc.CallCredentials, bool,
           grpc.Compression]:
    """Flatten call details into a tuple, filling gaps from the defaults.

    Each field is read from ``call_details``; when that raises
    AttributeError, the same field is read from ``default_details`` instead.

    Args:
      call_details: The details object to read from first; it may lack some
        of the attributes.
      default_details: The fallback consulted for any attribute missing on
        ``call_details``.

    Returns:
      A 6-tuple ``(method, timeout, metadata, credentials, wait_for_ready,
      compression)``.
    """

    def resolve(field_name):
        # The default is only touched when the primary lookup fails, so the
        # fallback object is never read for attributes that are present.
        try:
            return getattr(call_details, field_name)
        except AttributeError:
            return getattr(default_details, field_name)

    method = resolve('method')
    timeout = resolve('timeout')
    metadata = resolve('metadata')
    credentials = resolve('credentials')
    wait_for_ready = resolve('wait_for_ready')
    compression = resolve('compression')
    return method, timeout, metadata, credentials, wait_for_ready, compression
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):  # pylint: disable=too-many-ancestors
    """An already-finished call/future that represents a captured exception.

    It reports ``StatusCode.INTERNAL``, is always done, and re-raises the
    stored exception from ``result()`` and from iteration.
    """

    _exception: Exception
    _traceback: types.TracebackType

    def __init__(self, exception: Exception, traceback: types.TracebackType):
        super().__init__()
        self._exception = exception
        self._traceback = traceback

    # --- grpc.Call / grpc.RpcContext surface -------------------------------

    def initial_metadata(self) -> Optional[MetadataType]:
        """No metadata is available for a failed interception."""
        return None

    def trailing_metadata(self) -> Optional[MetadataType]:
        """No metadata is available for a failed interception."""
        return None

    def code(self) -> Optional[grpc.StatusCode]:
        """Always report INTERNAL."""
        return grpc.StatusCode.INTERNAL

    def details(self) -> Optional[str]:
        """Return a fixed description of the failure."""
        return 'Exception raised while intercepting the RPC'

    def is_active(self) -> bool:
        return False

    def time_remaining(self) -> Optional[float]:
        return None

    def add_callback(self, unused_callback) -> bool:
        """Decline the callback; always returns False."""
        return False

    # --- grpc.Future surface -----------------------------------------------

    def cancel(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Raise the stored exception; the timeout is ignored."""
        raise self._exception

    def exception(
            self,
            ignored_timeout: Optional[float] = None) -> Optional[Exception]:
        """Return the stored exception; the timeout is ignored."""
        return self._exception

    def traceback(
        self,
        ignored_timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        """Return the stored traceback; the timeout is ignored."""
        return self._traceback

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Invoke ``fn`` immediately with this object."""
        fn(self)

    # --- iterator surface --------------------------------------------------

    def __iter__(self):
        return self

    def __next__(self):
        """Raise the stored exception instead of yielding a response."""
        raise self._exception

    def next(self):
        return self.__next__()
class _UnaryOutcome(grpc.Call, grpc.Future):
    """A completed future that pairs a response with its underlying call.

    Call-related queries are forwarded to the wrapped call; future-state
    queries return fixed "finished successfully" answers.
    """

    _response: Any
    _call: grpc.Call

    def __init__(self, response: Any, call: grpc.Call):
        self._call = call
        self._response = response

    # --- forwarded to the wrapped call -------------------------------------

    def initial_metadata(self) -> Optional[MetadataType]:
        return self._call.initial_metadata()

    def trailing_metadata(self) -> Optional[MetadataType]:
        return self._call.trailing_metadata()

    def code(self) -> Optional[grpc.StatusCode]:
        return self._call.code()

    def details(self) -> Optional[str]:
        return self._call.details()

    def is_active(self) -> bool:
        return self._call.is_active()

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    def cancel(self) -> bool:
        return self._call.cancel()

    def add_callback(self, callback) -> bool:
        return self._call.add_callback(callback)

    # --- fixed answers for future state ------------------------------------

    def cancelled(self) -> bool:
        return False

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Return the stored response; the timeout is ignored."""
        return self._response

    def exception(self, ignored_timeout: Optional[float] = None):
        """Always None; the timeout is ignored."""
        return None

    def traceback(self, ignored_timeout: Optional[float] = None):
        """Always None; the timeout is ignored."""
        return None

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Invoke ``fn`` immediately with this object."""
        fn(self)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes calls through an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable; the
    interceptor receives a continuation that performs the real invocation
    with whatever call details it passes on.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryUnaryClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.UnaryUnaryClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request: Any,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None) -> Any:
        """Invoke the RPC and return only the response."""
        response_and_call = self._with_call(request,
                                            timeout=timeout,
                                            metadata=metadata,
                                            credentials=credentials,
                                            wait_for_ready=wait_for_ready,
                                            compression=compression)
        return response_and_call[0]

    def _with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC via the interceptor; return (response, call)."""
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            # Anything the interceptor left unset falls back to the details
            # this invocation started with.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, client_call_details)
            call_options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            try:
                response, call = self._thunk(target_method).with_call(
                    request, **call_options)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # The RpcError itself is handed back as the outcome.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_unary_unary(
            continuation, client_call_details, request)
        return outcome.result(), outcome

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Public form of ``_with_call``; returns (response, call)."""
        return self._with_call(request,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request: Any,
               timeout: Optional[float] = None,
               metadata: Optional[MetadataType] = None,
               credentials: Optional[grpc.CallCredentials] = None,
               wait_for_ready: Optional[bool] = None,
               compression: Optional[grpc.Compression] = None) -> Any:
        """Invoke the RPC via the interceptor and return its result object.

        Any exception raised by the interceptor is captured and returned as a
        ``_FailureOutcome`` rather than propagated.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, client_call_details)
            call_options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            return self._thunk(target_method).future(request, **call_options)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes calls through an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryStreamClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.UnaryStreamClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request: Any,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None):
        """Invoke the RPC via the interceptor and return its result object.

        Any exception raised by the interceptor is captured and returned as a
        ``_FailureOutcome`` rather than propagated.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request):
            # Anything the interceptor left unset falls back to the details
            # this invocation started with.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, client_call_details)
            call_options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            return self._thunk(target_method)(request, **call_options)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, client_call_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes calls through an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable; the
    interceptor receives a continuation that performs the real invocation
    with whatever call details it passes on.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamUnaryClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.StreamUnaryClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator: RequestIterableType,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None) -> Any:
        """Invoke the RPC and return only the response."""
        response_and_call = self._with_call(request_iterator,
                                            timeout=timeout,
                                            metadata=metadata,
                                            credentials=credentials,
                                            wait_for_ready=wait_for_ready,
                                            compression=compression)
        return response_and_call[0]

    def _with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC via the interceptor; return (response, call)."""
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            # Anything the interceptor left unset falls back to the details
            # this invocation started with.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, client_call_details)
            call_options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            try:
                response, call = self._thunk(target_method).with_call(
                    request_iterator, **call_options)
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # The RpcError itself is handed back as the outcome.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_stream_unary(
            continuation, client_call_details, request_iterator)
        return outcome.result(), outcome

    def with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> Tuple[Any, grpc.Call]:
        """Public form of ``_with_call``; returns (response, call)."""
        return self._with_call(request_iterator,
                               timeout=timeout,
                               metadata=metadata,
                               credentials=credentials,
                               wait_for_ready=wait_for_ready,
                               compression=compression)

    def future(self,
               request_iterator: RequestIterableType,
               timeout: Optional[float] = None,
               metadata: Optional[MetadataType] = None,
               credentials: Optional[grpc.CallCredentials] = None,
               wait_for_ready: Optional[bool] = None,
               compression: Optional[grpc.Compression] = None) -> Any:
        """Invoke the RPC via the interceptor and return its result object.

        Any exception raised by the interceptor is captured and returned as a
        ``_FailureOutcome`` rather than propagated.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, client_call_details)
            call_options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            return self._thunk(target_method).future(request_iterator,
                                                     **call_options)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes calls through an interceptor.

    ``thunk`` maps a method name to the underlying multi-callable.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamStreamClientInterceptor

    def __init__(self, thunk: Callable, method: str,
                 interceptor: grpc.StreamStreamClientInterceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator: RequestIterableType,
                 timeout: Optional[float] = None,
                 metadata: Optional[MetadataType] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None):
        """Invoke the RPC via the interceptor and return its result object.

        Any exception raised by the interceptor is captured and returned as a
        ``_FailureOutcome`` rather than propagated.
        """
        client_call_details = _ClientCallDetails(self._method, timeout,
                                                 metadata, credentials,
                                                 wait_for_ready, compression)

        def continuation(new_details, request_iterator):
            # Anything the interceptor left unset falls back to the details
            # this invocation started with.
            (target_method, call_timeout, call_metadata, call_credentials,
             call_wait_for_ready,
             call_compression) = _unwrap_client_call_details(
                 new_details, client_call_details)
            call_options = {
                'timeout': call_timeout,
                'metadata': call_metadata,
                'credentials': call_credentials,
                'wait_for_ready': call_wait_for_ready,
                'compression': call_compression,
            }
            return self._thunk(target_method)(request_iterator, **call_options)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, client_call_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
class _Channel(grpc.Channel):
    """Channel wrapper that applies one client interceptor.

    For each RPC arity, the wrapped channel's multi-callable is returned
    directly unless the interceptor is an instance of the matching
    interceptor class, in which case an intercepting multi-callable is
    returned instead.
    """

    _channel: grpc.Channel
    _interceptor: Union[grpc.UnaryUnaryClientInterceptor,
                        grpc.UnaryStreamClientInterceptor,
                        grpc.StreamStreamClientInterceptor,
                        grpc.StreamUnaryClientInterceptor]

    def __init__(self, channel: grpc.Channel,
                 interceptor: Union[grpc.UnaryUnaryClientInterceptor,
                                    grpc.UnaryStreamClientInterceptor,
                                    grpc.StreamStreamClientInterceptor,
                                    grpc.StreamUnaryClientInterceptor]):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self,
                  callback: Callable,
                  try_to_connect: Optional[bool] = False):
        """Forward the subscription to the wrapped channel."""
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback: Callable):
        """Forward the unsubscription to the wrapped channel."""
        self._channel.unsubscribe(callback)

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.UnaryUnaryMultiCallable:
        """Create a unary-unary multi-callable, intercepted when applicable."""

        def thunk(m):
            # Deferred so the interceptor may redirect to another method.
            return self._channel.unary_unary(m, request_serializer,
                                             response_deserializer)

        if not isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return thunk(method)
        return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.UnaryStreamMultiCallable:
        """Create a unary-stream multi-callable, intercepted when applicable."""

        def thunk(m):
            return self._channel.unary_stream(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.UnaryStreamClientInterceptor):
            return thunk(method)
        return _UnaryStreamMultiCallable(thunk, method, self._interceptor)

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.StreamUnaryMultiCallable:
        """Create a stream-unary multi-callable, intercepted when applicable."""

        def thunk(m):
            return self._channel.stream_unary(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamUnaryClientInterceptor):
            return thunk(method)
        return _StreamUnaryMultiCallable(thunk, method, self._interceptor)

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> grpc.StreamStreamMultiCallable:
        """Create a stream-stream multi-callable, intercepted when applicable."""

        def thunk(m):
            return self._channel.stream_stream(m, request_serializer,
                                               response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamStreamClientInterceptor):
            return thunk(method)
        return _StreamStreamMultiCallable(thunk, method, self._interceptor)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the wrapped channel; exceptions are not suppressed."""
        self._close()
        return False

    def close(self):
        """Close the wrapped channel."""
        self._channel.close()
def intercept_channel(
    channel: grpc.Channel,
    *interceptors: Union[grpc.UnaryUnaryClientInterceptor,
                         grpc.UnaryStreamClientInterceptor,
                         grpc.StreamStreamClientInterceptor,
                         grpc.StreamUnaryClientInterceptor]
) -> grpc.Channel:
    """Wrap ``channel`` so that calls pass through the given interceptors.

    The interceptors are applied in reverse order, so the first interceptor
    given ends up as the outermost wrapper.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more client interceptors; each must be an
        instance of at least one of the four client interceptor classes.

    Returns:
      The wrapped channel, or ``channel`` itself when no interceptors are
      given.

    Raises:
      TypeError: If any interceptor is not an instance of a supported client
        interceptor class.
    """
    supported_types = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # ``interceptors`` is already a tuple, so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, supported_types):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel