kopyl's picture
Upload folder using huggingface_hub
096c926 verified
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
import sys
from typing import Any, Iterable, List, Optional, Sequence
import grpc
from grpc import _common
from grpc import _compression
from grpc import _grpcio_metadata
from grpc._cython import cygrpc
from . import _base_call
from . import _base_channel
from ._call import StreamStreamCall
from ._call import StreamUnaryCall
from ._call import UnaryStreamCall
from ._call import UnaryUnaryCall
from ._interceptor import ClientInterceptor
from ._interceptor import InterceptedStreamStreamCall
from ._interceptor import InterceptedStreamUnaryCall
from ._interceptor import InterceptedUnaryStreamCall
from ._interceptor import InterceptedUnaryUnaryCall
from ._interceptor import StreamStreamClientInterceptor
from ._interceptor import StreamUnaryClientInterceptor
from ._interceptor import UnaryStreamClientInterceptor
from ._interceptor import UnaryUnaryClientInterceptor
from ._metadata import Metadata
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import RequestIterableType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline
# User-agent value this module supplies for the channel's
# ``primary_user_agent_string`` argument.
_USER_AGENT = f'grpc-python-asyncio/{_grpcio_metadata.__version__}'
# ``asyncio.all_tasks`` exists from Python 3.7 on; older interpreters expose
# the equivalent as ``asyncio.Task.all_tasks``.  The full (major, minor) pair
# is compared so the decision does not hinge on the minor number alone.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the event loop's tasks through the pre-3.7 API."""
        return asyncio.Task.all_tasks()
else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the unfinished tasks of the currently running event loop."""
        return asyncio.all_tasks()
def _augment_channel_arguments(base_options: ChannelArgumentType,
                               compression: Optional[grpc.Compression]):
    """Extend the caller's channel options with compression and user agent.

    Args:
      base_options: Channel arguments supplied by the caller.
      compression: Optional channel-wide compression setting; it is turned
        into channel arguments by ``_compression.create_channel_option``.

    Returns:
      A tuple holding ``base_options`` first, then the compression
      argument(s), and finally the primary user-agent argument.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return tuple(base_options) + compression_options + (user_agent_option,)
class _BaseMultiCallable:
    """Common state holder shared by every multi-callable flavour.

    The constructor only records the objects needed to start a call later;
    the concrete subclasses implement ``__call__``.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the collaborators used when a call is created.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The encoded method name.
          request_serializer: Function applied to outgoing requests.
          response_deserializer: Function applied to incoming responses.
          interceptors: Interceptors to run for this kind of call, if any.
          references: Objects stored on the instance as ``_references``.
          loop: The event loop handed to the created calls.
        """
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references
        self._loop = loop

    @staticmethod
    def _init_metadata(
            metadata: Optional[Metadata] = None,
            compression: Optional[grpc.Compression] = None) -> Metadata:
        """Build the metadata object to use for a single call.

        Args:
          metadata: Caller supplied metadata; a falsy value is replaced by a
            fresh ``Metadata()``.
          compression: Optional per-call compression; when truthy the metadata
            is rebuilt from ``_compression.augment_metadata``.

        Returns:
          The ``Metadata`` instance to hand to the call.
        """
        call_metadata = metadata if metadata else Metadata()
        if not compression:
            return call_metadata
        return Metadata(
            *_compression.augment_metadata(call_metadata, compression))
class UnaryUnaryMultiCallable(_BaseMultiCallable,
                              _base_channel.UnaryUnaryMultiCallable):
    """Multi-callable that starts unary-unary calls."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryUnaryCall:
        """Create the call object for one unary-unary invocation.

        Args:
          request: The request value of the RPC.
          timeout: Optional duration in seconds allowed for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag forwarded to the call.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedUnaryUnaryCall`` when interceptors are registered,
          otherwise a plain ``UnaryUnaryCall``.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            # The intercepted call is given the raw timeout, not a deadline.
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, call_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        deadline = _timeout_to_deadline(timeout)
        return UnaryUnaryCall(request, deadline, call_metadata, credentials,
                              wait_for_ready, self._channel, self._method,
                              self._request_serializer,
                              self._response_deserializer, self._loop)
class UnaryStreamMultiCallable(_BaseMultiCallable,
                               _base_channel.UnaryStreamMultiCallable):
    """Multi-callable that starts unary-stream calls."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryStreamCall:
        """Create the call object for one unary-stream invocation.

        Args:
          request: The request value of the RPC.
          timeout: Optional duration in seconds allowed for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag forwarded to the call.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedUnaryStreamCall`` when interceptors are registered,
          otherwise a plain ``UnaryStreamCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # A direct call takes an absolute deadline.
            call = UnaryStreamCall(request, _timeout_to_deadline(timeout),
                                   metadata, credentials, wait_for_ready,
                                   self._channel, self._method,
                                   self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Hand over the relative timeout, as the unary-unary path does.
            # NOTE(review): this assumes the intercepted call converts the
            # timeout to a deadline itself, so passing a deadline here would
            # convert it twice - confirm against ._interceptor.
            call = InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, metadata, credentials,
                wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
class StreamUnaryMultiCallable(_BaseMultiCallable,
                               _base_channel.StreamUnaryMultiCallable):
    """Multi-callable that starts stream-unary calls."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamUnaryCall:
        """Create the call object for one stream-unary invocation.

        Args:
          request_iterator: Optional iterable providing the request values.
          timeout: Optional duration in seconds allowed for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag forwarded to the call.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedStreamUnaryCall`` when interceptors are registered,
          otherwise a plain ``StreamUnaryCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # A direct call takes an absolute deadline.
            call = StreamUnaryCall(request_iterator,
                                   _timeout_to_deadline(timeout), metadata,
                                   credentials, wait_for_ready, self._channel,
                                   self._method, self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Hand over the relative timeout, as the unary-unary path does.
            # NOTE(review): this assumes the intercepted call converts the
            # timeout to a deadline itself, so passing a deadline here would
            # convert it twice - confirm against ._interceptor.
            call = InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
class StreamStreamMultiCallable(_BaseMultiCallable,
                                _base_channel.StreamStreamMultiCallable):
    """Multi-callable that starts stream-stream calls."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamStreamCall:
        """Create the call object for one stream-stream invocation.

        Args:
          request_iterator: Optional iterable providing the request values.
          timeout: Optional duration in seconds allowed for the RPC.
          metadata: Optional metadata to send with the call.
          credentials: Optional call credentials.
          wait_for_ready: Optional wait-for-ready flag forwarded to the call.
          compression: Optional per-call compression setting.

        Returns:
          An ``InterceptedStreamStreamCall`` when interceptors are registered,
          otherwise a plain ``StreamStreamCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # A direct call takes an absolute deadline.
            call = StreamStreamCall(request_iterator,
                                    _timeout_to_deadline(timeout), metadata,
                                    credentials, wait_for_ready, self._channel,
                                    self._method, self._request_serializer,
                                    self._response_deserializer, self._loop)
        else:
            # Hand over the relative timeout, as the unary-unary path does.
            # NOTE(review): this assumes the intercepted call converts the
            # timeout to a deadline itself, so passing a deadline here would
            # convert it twice - confirm against ._interceptor.
            call = InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
class Channel(_base_channel.Channel):
    """Asyncio channel wrapping a ``cygrpc.AioChannel``.

    Interceptors given to the constructor are sorted into one list per call
    arity; the multi-callables created by this channel receive the list that
    matches their arity.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(self, target: str, options: ChannelArgumentType,
                 credentials: Optional[grpc.ChannelCredentials],
                 compression: Optional[grpc.Compression],
                 interceptors: Optional[Sequence[ClientInterceptor]]):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # Each interceptor is registered under the first matching
                # class only.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor) +
                        "{} or ".format(UnaryUnaryClientInterceptor.__name__) +
                        "{} or ".format(UnaryStreamClientInterceptor.__name__) +
                        "{} or ".format(StreamUnaryClientInterceptor.__name__) +
                        "{}. ".format(StreamStreamClientInterceptor.__name__))

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression), credentials,
            self._loop)

    async def __aenter__(self):
        """Enter the async context manager and return this channel."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leave the async context manager, closing without a grace period."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Close the channel and cancel the calls that belong to it.

        The calls are discovered by inspecting the frame returned by
        ``get_stack(limit=1)`` for every known task and looking for a local
        ``self`` that is a call object bound to this channel.

        Args:
          grace: Optional number of seconds to wait for the discovered call
            tasks before cancelling the calls; a falsy value skips the wait.

        Raises:
          cygrpc.InternalError: If a call object exposes neither ``_channel``
            nor ``_cython_call``.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if 'frame' in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get('self')
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, '_channel'):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, '_cython_call'):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f'Unrecognized call object: {candidate}')

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel, optionally waiting ``grace`` seconds for calls."""
        await self._close(grace)

    def __del__(self):
        """Close the underlying channel if it exists and is still open."""
        # `_channel` is missing when `__init__` raised before creating it.
        if hasattr(self, '_channel') and not self._channel.closed():
            self._channel.close()

    def get_state(self,
                  try_to_connect: bool = False) -> grpc.ChannelConnectivity:
        """Return the current connectivity state of the channel.

        Args:
          try_to_connect: Forwarded to the Cython channel's
            ``check_connectivity_state``.

        Returns:
          The ``grpc.ChannelConnectivity`` mapped from the Cython state.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait until the Cython channel reports a connectivity change.

        Args:
          last_observed_state: The state the caller observed last.
        """
        # Await outside of the assert: `python -O` strips assert statements
        # together with any call embedded in the asserted expression, which
        # would turn this coroutine into a no-op.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None)
        assert state_changed

    async def channel_ready(self) -> None:
        """Return once the channel state is ``READY``."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryUnaryMultiCallable:
        """Create a unary-unary multi-callable for ``method``."""
        return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
                                       request_serializer,
                                       response_deserializer,
                                       self._unary_unary_interceptors, [self],
                                       self._loop)

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryStreamMultiCallable:
        """Create a unary-stream multi-callable for ``method``."""
        return UnaryStreamMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._unary_stream_interceptors, [self],
                                        self._loop)

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamUnaryMultiCallable:
        """Create a stream-unary multi-callable for ``method``."""
        return StreamUnaryMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._stream_unary_interceptors, [self],
                                        self._loop)

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamStreamMultiCallable:
        """Create a stream-stream multi-callable for ``method``."""
        return StreamStreamMultiCallable(self._channel, _common.encode(method),
                                         request_serializer,
                                         response_deserializer,
                                         self._stream_stream_interceptors,
                                         [self], self._loop)
def insecure_channel(
        target: str,
        options: Optional[ChannelArgumentType] = None,
        compression: Optional[grpc.Compression] = None,
        interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Build an asynchronous Channel that carries no channel credentials.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      A Channel.
    """
    # A missing options argument is passed on as an empty tuple.
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)
def secure_channel(target: str,
                   credentials: grpc.ChannelCredentials,
                   options: Optional[ChannelArgumentType] = None,
                   compression: Optional[grpc.Compression] = None,
                   interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Build an asynchronous Channel that uses the given channel credentials.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    # A missing options argument is passed on as an empty tuple.
    channel_options = options if options is not None else ()
    # The channel is given the object stored in the wrapper's `_credentials`.
    wrapped_credentials = credentials._credentials
    return Channel(target, channel_options, wrapped_credentials, compression,
                   interceptors)