Spaces:
Running
Running
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections | |
import logging | |
import threading | |
from typing import Callable, Optional, Type | |
import grpc | |
from grpc import _common | |
from grpc._cython import cygrpc | |
from grpc._typing import MetadataType | |
# Module logger; used to report exceptions raised by application-supplied
# metadata plugins.
_LOGGER = logging.getLogger(name=__name__)
class _AuthMetadataContext(
    collections.namedtuple("AuthMetadataContext", "service_url method_name"),
    grpc.AuthMetadataContext,
):
    """Immutable ``grpc.AuthMetadataContext`` handed to metadata plugins.

    Fields:
      service_url: URL of the service being called into.
      method_name: Name of the RPC method being invoked.
    """
class _CallbackState:
    """Mutable bookkeeping for a single metadata-plugin invocation.

    In this module, ``called`` and ``exception`` are only read or written
    while ``lock`` is held.
    """

    def __init__(self) -> None:
        # Exception raised by the wrapped plugin, or None if it has not raised.
        self.exception: Optional[BaseException] = None
        # Set once the plugin callback has accepted a result.
        self.called: bool = False
        # Serializes access to ``called`` and ``exception``.
        self.lock = threading.Lock()
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Callback given to an application ``AuthMetadataPlugin``.

    Forwards the plugin's result to the Core-supplied ``callback`` at most
    once. It refuses (with ``RuntimeError``) any invocation made after the
    first one, or after the plugin itself has raised.
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        self._state = state
        self._callback = callback

    def __call__(
        self, metadata: MetadataType, error: Optional[BaseException]
    ) -> None:
        """Report the plugin's outcome to Core.

        Args:
          metadata: Metadata to attach to the call. Ignored when ``error``
            is not None.
          error: Exception instance describing a failure, or None on success.

        Raises:
          RuntimeError: If the plugin already raised an exception, or if
            this callback was already invoked.
        """
        with self._state.lock:
            if self._state.exception is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        self._state.exception
                    )
                )
            if self._state.called:
                raise RuntimeError(
                    "AuthMetadataPluginCallback invoked more than once!"
                )
            self._state.called = True
        # The lock guards only the bookkeeping above; Core's callback is
        # invoked after it has been released.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(error))
            )
class _Plugin(object):
    """Adapts an application ``grpc.AuthMetadataPlugin`` to Core's plugin call.

    Core calls an instance with the raw service URL, the method name and a
    completion ``callback``. The wrapped plugin receives an
    ``_AuthMetadataContext`` and an ``_AuthMetadataPluginCallback``.
    """

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        self._metadata_plugin = metadata_plugin
        try:
            import contextvars  # pylint: disable=wrong-import-position
        except ImportError:
            # Support versions predating contextvars.
            self._stored_ctx = None
        else:
            # The plugin may be invoked on a thread created by Core, which will
            # not have the context propagated. This context is stored and
            # installed in the thread invoking the plugin.
            # NOTE(review): not read in this module; presumably consumed by
            # cygrpc — keep the attribute name stable.
            self._stored_ctx = contextvars.copy_context()

    def __call__(self, service_url: str, method_name: str, callback: Callable):
        """Run the wrapped plugin once for a single RPC.

        Args:
          service_url: Service URL from Core, decoded with ``_common.decode``.
          method_name: Method name from Core, decoded with ``_common.decode``.
          callback: Core completion callback, invoked here as
            ``callback(metadata, status_code, details)``.
        """
        auth_context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name)
        )
        shared_state = _CallbackState()
        try:
            plugin_callback = _AuthMetadataPluginCallback(shared_state, callback)
            self._metadata_plugin(auth_context, plugin_callback)
        except Exception as exception:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin,
            )
            with shared_state.lock:
                # Recording the exception makes any later callback call fail.
                shared_state.exception = exception
                result_already_reported = shared_state.called
            if not result_already_reported:
                # The plugin raised before reporting a result: fail the call.
                callback(
                    None,
                    cygrpc.StatusCode.internal,
                    _common.encode(str(exception)),
                )
def metadata_plugin_call_credentials(
    metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str]
) -> grpc.CallCredentials:
    """Wrap an ``AuthMetadataPlugin`` as ``grpc.CallCredentials``.

    Args:
      metadata_plugin: Application plugin that supplies per-call metadata.
      name: Name for the credentials. If None, the plugin's ``__name__`` is
        used, falling back to the name of its class.

    Returns:
      ``grpc.CallCredentials`` backed by
      ``cygrpc.MetadataPluginCallCredentials``.
    """
    effective_name = name
    if effective_name is None:
        # Functions expose ``__name__``; callable instances usually do not,
        # so their class name is used instead.
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            effective_name = metadata_plugin.__class__.__name__
    wrapped_plugin = _Plugin(metadata_plugin)
    encoded_name = _common.encode(effective_name)
    native_credentials = cygrpc.MetadataPluginCallCredentials(
        wrapped_plugin, encoded_name
    )
    return grpc.CallCredentials(native_credentials)