# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
import threading
from typing import Callable, Optional, Type

import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc._typing import MetadataType

_LOGGER = logging.getLogger(__name__)


class _AuthMetadataContext(
    collections.namedtuple(
        "AuthMetadataContext",
        (
            "service_url",
            "method_name",
        ),
    ),
    grpc.AuthMetadataContext,
):
    """A ``(service_url, method_name)`` namedtuple passed to metadata plugins.

    Also subclasses ``grpc.AuthMetadataContext`` so plugins receive an object
    of the public interface type.
    """


class _CallbackState(object):
    """Bookkeeping shared by one plugin invocation and its callback.

    ``lock`` guards both ``called`` and ``exception`` so that exactly one of
    "the callback forwarded a result" or "the plugin raised and an internal
    error was forwarded" takes effect.
    """

    def __init__(self):
        self.lock = threading.Lock()
        # True once the callback has forwarded a result.
        self.called: bool = False
        # The exception raised by the wrapped plugin, if it raised.
        self.exception: Optional[BaseException] = None


class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """One-shot callback handed to the application's metadata plugin.

    Forwards the plugin's result to the underlying low-level ``callback`` at
    most once. Invoking it again, or after the plugin itself raised, raises
    ``RuntimeError``.
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        self._state = state
        self._callback = callback

    def __call__(
        self, metadata: MetadataType, error: Optional[BaseException]
    ) -> None:
        """Report the plugin's result.

        Args:
          metadata: Metadata to attach to the call; not forwarded when
            ``error`` is not None.
          error: None on success, otherwise the exception describing why
            metadata could not be produced. Its ``str()`` is forwarded as the
            error details.

        Raises:
          RuntimeError: If the plugin already raised an exception, or if this
            callback was already invoked.
        """
        with self._state.lock:
            if self._state.exception is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        self._state.exception
                    )
                )
            if self._state.called:
                raise RuntimeError(
                    "AuthMetadataPluginCallback invoked more than once!"
                )
            self._state.called = True
        # Forward outside the lock: the state transition above already
        # guarantees this point is reached at most once per invocation.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(error))
            )


class _Plugin(object):
    """Adapts an application ``grpc.AuthMetadataPlugin`` for cygrpc.

    Each call builds the public context and callback objects, invokes the
    application plugin, and reports any exception the plugin raises as
    ``StatusCode.internal`` (unless the plugin already reported a result).
    """

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        self._metadata_plugin = metadata_plugin
        # NOTE(review): ``_stored_ctx`` is not read anywhere in this module;
        # presumably the cygrpc credentials layer reads it, so keep the name.
        self._stored_ctx = None

        try:
            import contextvars  # pylint: disable=wrong-import-position

            # The plugin may be invoked on a thread created by Core, which will not
            # have the context propagated. This context is stored and installed in
            # the thread invoking the plugin.
            self._stored_ctx = contextvars.copy_context()
        except ImportError:
            # Support versions predating contextvars.
            pass

    def __call__(
        self, service_url: str, method_name: str, callback: Callable
    ) -> None:
        """Invoke the wrapped plugin for one call.

        Args:
          service_url: Service URL; passed through ``_common.decode`` before
            being placed in the plugin's context.
          method_name: Method name; decoded the same way.
          callback: Low-level callable invoked as
            ``callback(metadata, status_code, details)``.
        """
        context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name)
        )
        callback_state = _CallbackState()
        try:
            self._metadata_plugin(
                context, _AuthMetadataPluginCallback(callback_state, callback)
            )
        except Exception as exception:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin,
            )
            with callback_state.lock:
                callback_state.exception = exception
                if callback_state.called:
                    # A result was already forwarded before the plugin raised;
                    # do not forward a second one.
                    return
            callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(exception))
            )


def metadata_plugin_call_credentials(
    metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str]
) -> grpc.CallCredentials:
    """Wrap ``metadata_plugin`` in ``grpc.CallCredentials``.

    Args:
      metadata_plugin: Application plugin invoked to produce call metadata.
      name: Name for the credentials. When None, the plugin's ``__name__`` is
        used, falling back to the name of its class.

    Returns:
      A ``grpc.CallCredentials`` wrapping a
      ``cygrpc.MetadataPluginCallCredentials``.
    """
    if name is None:
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            effective_name = metadata_plugin.__class__.__name__
    else:
        effective_name = name
    return grpc.CallCredentials(
        cygrpc.MetadataPluginCallCredentials(
            _Plugin(metadata_plugin), _common.encode(effective_name)
        )
    )