Spaces:
Running
Running
# Copyright 2023 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations | |
import abc | |
import contextlib | |
import logging | |
import threading | |
from typing import Any, Generator, Generic, List, Optional, TypeVar | |
from grpc._cython import cygrpc as _cygrpc | |
from grpc._typing import ChannelArgumentType | |
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

_channel = Any  # _channel.py imports this module.

# Type variables for the two capsule kinds an ObservabilityPlugin is
# parameterized over.
ClientCallTracerCapsule = TypeVar("ClientCallTracerCapsule")
ServerCallTracerFactoryCapsule = TypeVar("ServerCallTracerFactoryCapsule")

# Re-entrant lock guarding access to the registered plugin below.
_plugin_lock: threading.RLock = threading.RLock()

# The currently registered plugin; None while no plugin is registered.
_OBSERVABILITY_PLUGIN: Optional["ObservabilityPlugin"] = None

# RPCs whose method name contains one of these service names are skipped
# when recording RPC latency.
_SERVICES_TO_EXCLUDE: List[bytes] = [
    b"google.monitoring.v3.MetricService",
    b"google.devtools.cloudtrace.v2.TraceService",
]
class ServerCallTracerFactory:
    """Wrapper around the address of a ServerCallTracerFactory.

    An instance is meant to be supplied as the value of the
    grpc.experimental.server_call_tracer_factory option. Converting the
    instance with ``int()`` hands back the wrapped address unchanged.
    """

    def __init__(self, address):
        # Kept exactly as given; no validation or conversion happens here.
        self._address = address

    def __int__(self):
        return self._address
class ObservabilityPlugin(
    Generic[ClientCallTracerCapsule, ServerCallTracerFactoryCapsule],
    metaclass=abc.ABCMeta,
):
    """Abstract base class for observability plugin.

    *This is a semi-private class that was intended for the exclusive use of
    the gRPC team.*

    The ClientCallTracerCapsule and ServerCallTracerFactoryCapsule created by
    this plugin should be injected into gRPC core using observability_init at
    the start of a program, before any channels/servers are built.

    Any future methods added to this interface cannot have the
    @abc.abstractmethod annotation.

    Attributes:
      _tracing_enabled: A bool indicating whether tracing is enabled.
      _stats_enabled: A bool indicating whether stats (metrics) is enabled.

    This base class does not itself store registered method names; see
    save_registered_method, which implementations are expected to override.
    """

    _tracing_enabled: bool = False
    _stats_enabled: bool = False

    def create_client_call_tracer(
        self, method_name: bytes, target: bytes
    ) -> ClientCallTracerCapsule:
        """Creates a ClientCallTracerCapsule.

        After register the plugin, if tracing or stats is enabled, this method
        will be called after a call was created, the ClientCallTracer created
        by this method will be saved to call context.

        The ClientCallTracer is an object which implements
        `grpc_core::ClientCallTracer` interface and wrapped in a PyCapsule
        using `client_call_tracer` as name.

        Args:
          method_name: The method name of the call in byte format.
          target: The channel target of the call in byte format.

        Returns:
          A PyCapsule which stores a ClientCallTracer object.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def delete_client_call_tracer(
        self, client_call_tracer: ClientCallTracerCapsule
    ) -> None:
        """Deletes the ClientCallTracer stored in ClientCallTracerCapsule.

        After register the plugin, if tracing or stats is enabled, this method
        will be called at the end of the call to destroy the ClientCallTracer.

        The ClientCallTracer is an object which implements
        `grpc_core::ClientCallTracer` interface and wrapped in a PyCapsule
        using `client_call_tracer` as name.

        Args:
          client_call_tracer: A PyCapsule which stores a ClientCallTracer
            object.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def save_trace_context(
        self, trace_id: str, span_id: str, is_sampled: bool
    ) -> None:
        """Saves the trace_id and span_id related to the current span.

        After register the plugin, if tracing is enabled, this method will be
        called after the server finished sending response.

        This method can be used to propagate census context.

        Args:
          trace_id: The identifier for the trace associated with the span as a
            32-character hexadecimal encoded string,
            e.g. 26ed0036f2eff2b7317bccce3e28d01f
          span_id: The identifier for the span as a 16-character hexadecimal
            encoded string. e.g. 113ec879e62583bc
          is_sampled: A bool indicating whether the span is sampled.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def create_server_call_tracer_factory(
        self,
        *,
        xds: bool = False,
    ) -> Optional[ServerCallTracerFactoryCapsule]:
        """Creates a ServerCallTracerFactoryCapsule.

        This method will be called at server initialization time to create a
        ServerCallTracerFactory, which will be registered to gRPC core.

        The ServerCallTracerFactory is an object which implements
        `grpc_core::ServerCallTracerFactory` interface and wrapped in a
        PyCapsule using `server_call_tracer_factory` as name.

        Args:
          xds: Whether the server is xds server.

        Returns:
          A PyCapsule which stores a ServerCallTracerFactory object. Or None if
          plugin decides not to create ServerCallTracerFactory.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def record_rpc_latency(
        self, method: str, target: str, rpc_latency: float, status_code: Any
    ) -> None:
        """Record the latency of the RPC.

        After register the plugin, if stats is enabled, this method will be
        called at the end of each RPC.

        Args:
          method: The fully-qualified name of the RPC method being invoked.
          target: The target name of the RPC method being invoked.
          rpc_latency: The latency for the RPC, equal to the time between
            when the client invokes the RPC and when the client receives the
            status. NOTE: maybe_record_rpc_latency in this module passes this
            value in milliseconds.
          status_code: An element of grpc.StatusCode in string format
            representing the final status for the RPC.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def set_tracing(self, enable: bool) -> None:
        """Enable or disable tracing.

        Args:
          enable: A bool indicating whether tracing should be enabled.
        """
        self._tracing_enabled = enable

    def set_stats(self, enable: bool) -> None:
        """Enable or disable stats(metrics).

        Args:
          enable: A bool indicating whether stats should be enabled.
        """
        self._stats_enabled = enable

    def save_registered_method(self, method_name: bytes) -> None:
        """Saves the method name to registered_method list.

        When exporting metrics, method name for unregistered methods will be
        replaced with 'other' by default.

        Args:
          method_name: The method name in bytes.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    # The three flags below are properties because both this class and the
    # module-level helpers read them as plain attributes (no call). As
    # ordinary methods they would evaluate to always-truthy bound methods and
    # the enable switches would be ignored.
    @property
    def tracing_enabled(self) -> bool:
        """Whether tracing is currently enabled."""
        return self._tracing_enabled

    @property
    def stats_enabled(self) -> bool:
        """Whether stats (metrics) is currently enabled."""
        return self._stats_enabled

    @property
    def observability_enabled(self) -> bool:
        """Whether tracing or stats is currently enabled."""
        return self.tracing_enabled or self.stats_enabled
@contextlib.contextmanager
def get_plugin() -> Generator[Optional[ObservabilityPlugin], None, None]:
    """Get the ObservabilityPlugin in _observability module.

    Intended to be used as a context manager::

        with get_plugin() as plugin:
            ...

    The module-level plugin lock is held for the whole ``with`` body. The
    lock is a ``threading.RLock``, so the same thread may re-acquire it
    while inside the body.

    Yields:
      The ObservabilityPlugin currently registered with the _observability
      module. Or None if no plugin exists at the time of calling this method.
    """
    with _plugin_lock:
        yield _OBSERVABILITY_PLUGIN
def set_plugin(observability_plugin: Optional[ObservabilityPlugin]) -> None:
    """Store (or clear) the module-wide ObservabilityPlugin.

    A falsy argument such as None clears the stored plugin. A truthy
    argument is rejected when a plugin is already stored.

    Args:
      observability_plugin: The ObservabilityPlugin to save.

    Raises:
      ValueError: If an ObservabilityPlugin was already registered at the
        time of calling this method.
    """
    global _OBSERVABILITY_PLUGIN  # pylint: disable=global-statement
    with _plugin_lock:
        # Evaluate the new plugin first, then the stored one, so the
        # short-circuit order matches a plain `a and b` check.
        would_overwrite = bool(observability_plugin) and bool(
            _OBSERVABILITY_PLUGIN
        )
        if would_overwrite:
            raise ValueError("observability_plugin was already set!")
        _OBSERVABILITY_PLUGIN = observability_plugin
def observability_init(observability_plugin: ObservabilityPlugin) -> None:
    """Register the given ObservabilityPlugin as the active plugin.

    Must be called at the start of a program, before any channels/servers
    are built. Registration is delegated to set_plugin.

    Args:
      observability_plugin: The ObservabilityPlugin to use.

    Raises:
      ValueError: If an ObservabilityPlugin was already registered at the
        time of calling this method.
    """
    set_plugin(observability_plugin)
def observability_deinit() -> None:
    """Tear down the observability context.

    Clears the registered ObservabilityPlugin first and then the
    ServerCallTracerFactory held by the Cython layer.

    This has to be called after leaving the observability context so that
    observability can be initialized again later.
    """
    # Step 1: drop the plugin reference held by this module.
    set_plugin(None)
    # Step 2: drop the server call tracer factory held by cygrpc.
    _cygrpc.clear_server_call_tracer_factory()
def delete_call_tracer(client_call_tracer_capsule: Any) -> None:
    """Destroy the ClientCallTracer held by the given capsule.

    Called at the end of a call. The ClientCallTracer is an object which
    implements the `grpc_core::ClientCallTracer` interface and is wrapped in
    a PyCapsule using `client_call_tracer` as the name.

    Nothing happens when no plugin is registered or when the plugin's
    observability_enabled flag is falsy.

    Args:
      client_call_tracer_capsule: A PyCapsule which stores a ClientCallTracer
        object.
    """
    with get_plugin() as plugin:
        if not plugin:
            return
        if plugin.observability_enabled:
            plugin.delete_client_call_tracer(client_call_tracer_capsule)
def maybe_record_rpc_latency(state: "_channel._RPCState") -> None:
    """Report an RPC's latency to the plugin when stats is enabled.

    Called at the end of each RPC. RPCs whose method name contains one of
    the services in _SERVICES_TO_EXCLUDE are never reported. The latency
    passed to the plugin is (rpc_end_time - rpc_start_time) * 1000.

    Args:
      state: a grpc._channel._RPCState object which contains the stats
        related to the RPC.
    """
    # TODO(xuanwn): use channel args to exclude those metrics.
    encoded_method = state.method.encode("utf8")
    if any(service in encoded_method for service in _SERVICES_TO_EXCLUDE):
        return

    with get_plugin() as plugin:
        if not (plugin and plugin.stats_enabled):
            return
        latency_ms = (state.rpc_end_time - state.rpc_start_time) * 1000
        plugin.record_rpc_latency(
            state.method, state.target, latency_ms, state.code
        )
def create_server_call_tracer_factory_option(xds: bool) -> ChannelArgumentType:
    """Build the channel argument carrying the server call tracer factory.

    Args:
      xds: Forwarded to cygrpc when asking for the factory address.

    Returns:
      A one-element tuple holding the
      ("grpc.experimental.server_call_tracer_factory", ServerCallTracerFactory)
      pair, or an empty tuple when no plugin is registered, the plugin's
      stats_enabled flag is falsy, or cygrpc returns a falsy address.
    """
    with get_plugin() as plugin:
        if not plugin or not plugin.stats_enabled:
            return ()

        factory_address = _cygrpc.get_server_call_tracer_factory_address(
            plugin, xds
        )
        if not factory_address:
            return ()

        option = (
            "grpc.experimental.server_call_tracer_factory",
            ServerCallTracerFactory(factory_address),
        )
        return (option,)