Spaces:
Running
Running
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal utilities for gRPC Python.""" | |
import collections | |
import logging | |
import threading | |
import time | |
from typing import Callable, Dict, Optional, Sequence | |
import grpc # pytype: disable=pyi-error | |
from grpc import _common # pytype: disable=pyi-error | |
from grpc._typing import DoneCallbackType | |
# Logger for this module, keyed by the module's import name.
_LOGGER = logging.getLogger(__name__)

# Message logged when a "done" callback registered on a connectivity future
# raises an exception.
_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
    "Exception calling connectivity future "
    '"done" callback!'
)
class RpcMethodHandler(
    collections.namedtuple(
        "_RpcMethodHandler",
        [
            "request_streaming",
            "response_streaming",
            "request_deserializer",
            "response_serializer",
            "unary_unary",
            "unary_stream",
            "stream_unary",
            "stream_stream",
        ],
    ),
    grpc.RpcMethodHandler,
):
    """Tuple-backed implementation of ``grpc.RpcMethodHandler``.

    Instances are constructed positionally or by keyword from the eight
    fields listed above; the meaning of each field is defined by the
    ``grpc.RpcMethodHandler`` interface.
    """
class DictionaryGenericHandler(grpc.ServiceRpcHandler):
    """Service handler backed by a dictionary of method handlers.

    The handlers are stored under the names produced by
    ``_common.fully_qualified_method(service, method)`` and looked up by the
    ``method`` attribute of the incoming call details.
    """

    _name: str
    _method_handlers: Dict[str, grpc.RpcMethodHandler]

    def __init__(
        self, service: str, method_handlers: Dict[str, grpc.RpcMethodHandler]
    ):
        """Index ``method_handlers`` by fully-qualified method name.

        Args:
          service: The service name; also returned by ``service_name``.
          method_handlers: Mapping from bare method name to its handler.
        """
        self._name = service
        qualified_handlers = {}
        for method_name, handler in method_handlers.items():
            qualified_name = _common.fully_qualified_method(
                service, method_name
            )
            qualified_handlers[qualified_name] = handler
        self._method_handlers = qualified_handlers

    def service_name(self) -> str:
        """Return the service name given at construction."""
        return self._name

    def service(
        self, handler_call_details: grpc.HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        """Return the handler registered for the called method, if any.

        Args:
          handler_call_details: Call details whose ``method`` attribute is
            used as the lookup key.

        Returns:
          The matching handler, or None when no handler is registered under
          that method name.
        """
        handlers = self._method_handlers
        return handlers.get(
            handler_call_details.method
        )  # pytype: disable=attribute-error
class _ChannelReadyFuture(grpc.Future):
    """A ``grpc.Future`` that matures when its channel reports READY.

    The future has two terminal states: *matured* (a READY connectivity was
    delivered to ``_update``) and *cancelled* (``cancel`` was called first).
    It carries no value: ``result``, ``exception`` and ``traceback`` only
    block until a terminal state is reached and then return None.

    All state is guarded by ``_condition``. Done callbacks are invoked after
    the lock has been released.
    """

    _condition: threading.Condition
    _channel: grpc.Channel
    _matured: bool
    _cancelled: bool
    # A list while the future is pending; set to None once it is terminal.
    _done_callbacks: Sequence[Callable]

    def __init__(self, channel: grpc.Channel):
        """Create a pending future for ``channel`` (not yet subscribed)."""
        self._condition = threading.Condition()
        self._channel = channel
        self._matured = False
        self._cancelled = False
        self._done_callbacks = []

    def _block(self, timeout: Optional[float]) -> None:
        """Block until the future is terminal or ``timeout`` seconds elapse.

        Args:
          timeout: Maximum number of seconds to wait, or None to wait
            indefinitely.

        Raises:
          grpc.FutureCancelledError: If the future is cancelled.
          grpc.FutureTimeoutError: If the deadline passes first.
        """
        # A monotonic clock keeps the deadline immune to wall-clock
        # adjustments made while waiting.
        until = None if timeout is None else time.monotonic() + timeout
        with self._condition:
            while True:
                if self._cancelled:
                    raise grpc.FutureCancelledError()
                if self._matured:
                    return
                if until is None:
                    self._condition.wait()
                    continue
                remaining = until - time.monotonic()
                if remaining < 0:
                    raise grpc.FutureTimeoutError()
                self._condition.wait(timeout=remaining)

    def _invoke_done_callbacks(self, done_callbacks: Sequence[Callable]) -> None:
        """Call each done callback with this future, logging any exception."""
        for done_callback in done_callbacks:
            try:
                done_callback(self)
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE)

    def _update(self, connectivity: Optional[grpc.ChannelConnectivity]) -> None:
        """Connectivity subscription callback; matures the future on READY.

        Notifications are ignored when the future is already terminal or
        when ``connectivity`` is anything other than READY.
        """
        with self._condition:
            # The `_matured` check makes a repeated READY notification a
            # no-op; `_done_callbacks` is None after the first one, so
            # re-entering the transition would raise TypeError.
            if (
                self._cancelled
                or self._matured
                or connectivity is not grpc.ChannelConnectivity.READY
            ):
                return
            self._matured = True
            self._channel.unsubscribe(self._update)
            self._condition.notify_all()
            done_callbacks = tuple(self._done_callbacks)
            self._done_callbacks = None
        # Run callbacks outside the lock.
        self._invoke_done_callbacks(done_callbacks)

    def cancel(self) -> bool:
        """Cancel the future unless it has already matured.

        Returns:
          False if the future had already matured; True if it is cancelled,
          whether by this call or by an earlier one.
        """
        with self._condition:
            if self._matured:
                return False
            if self._cancelled:
                # Already cancelled: callbacks have been handed off and
                # `_done_callbacks` is None, so there is nothing left to do.
                return True
            self._cancelled = True
            self._channel.unsubscribe(self._update)
            self._condition.notify_all()
            done_callbacks = tuple(self._done_callbacks)
            self._done_callbacks = None
        # Run callbacks outside the lock.
        self._invoke_done_callbacks(done_callbacks)
        return True

    def cancelled(self) -> bool:
        """Return True if the future was cancelled."""
        with self._condition:
            return self._cancelled

    def running(self) -> bool:
        """Return True while the future is neither cancelled nor matured."""
        with self._condition:
            return not self._cancelled and not self._matured

    def done(self) -> bool:
        """Return True once the future is cancelled or matured."""
        with self._condition:
            return self._cancelled or self._matured

    def result(self, timeout: Optional[float] = None) -> None:
        """Block until terminal; see ``_block`` for the exceptions raised."""
        self._block(timeout)

    def exception(self, timeout: Optional[float] = None) -> None:
        """Block until terminal; see ``_block`` for the exceptions raised."""
        self._block(timeout)

    def traceback(self, timeout: Optional[float] = None) -> None:
        """Block until terminal; see ``_block`` for the exceptions raised."""
        self._block(timeout)

    def add_done_callback(self, fn: DoneCallbackType):
        """Register ``fn`` to be called with this future when it is terminal.

        If the future is already terminal, ``fn`` is called immediately in
        the calling thread, outside the lock.
        """
        with self._condition:
            if not self._cancelled and not self._matured:
                self._done_callbacks.append(fn)
                return
        fn(self)

    def start(self):
        """Subscribe to the channel's connectivity, requesting a connection."""
        with self._condition:
            self._channel.subscribe(self._update, try_to_connect=True)

    def __del__(self):
        """Unsubscribe from the channel if the future is still pending."""
        with self._condition:
            if not self._cancelled and not self._matured:
                self._channel.unsubscribe(self._update)
def channel_ready_future(channel: grpc.Channel) -> _ChannelReadyFuture:
    """Create and start a ``_ChannelReadyFuture`` for ``channel``.

    Args:
      channel: The channel whose connectivity the future subscribes to.

    Returns:
      The future, already started (subscribed to the channel).
    """
    future = _ChannelReadyFuture(channel)
    future.start()
    return future
def first_version_is_lower(version1: str, version2: str) -> bool:
    """
    Compares two versions in the format '1.60.1' or '1.60.1.dev0'.

    This method will be used in all stubs generated by grpcio-tools to check
    whether the stub version is compatible with the runtime grpcio.

    Only the first three dot-separated components are compared numerically.
    If they are all equal, the version with fewer components is considered
    lower (so '1.60.1' is lower than '1.60.1.dev0').

    Args:
        version1: The first version string.
        version2: The second version string.

    Returns:
        True if version1 is lower, False otherwise. Also False when a
        compared component is not an integer or is missing (fewer than
        three components), since no ordering can be established.
    """
    version1_list = version1.split(".")
    version2_list = version2.split(".")

    try:
        for i in range(3):
            component1 = int(version1_list[i])
            component2 = int(version2_list[i])
            if component1 != component2:
                return component1 < component2
    except (ValueError, IndexError):
        # ValueError: a component could not be converted to int.
        # IndexError: a version has fewer than three components.
        return False

    # The version without dev0 will be considered lower.
    return len(version1_list) < len(version2_list)