# Spaces: Build error
from typing import Optional, Set

import grpc
from opentelemetry.trace import Span
from tenacity import (
    retry,
    retry_if_result,
    stop_after_attempt,
    wait_exponential_jitter,
)
class RetryOnRpcErrorClientInterceptor(
    grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor
):
    """Retry gRPC calls whose final status code is in a configurable set.

    By default, calls that finish with ``UNAVAILABLE`` or ``UNKNOWN`` are
    retried with exponential back-off plus jitter, up to ``max_attempts``
    total attempts. When the attempts run out, the call object from the last
    attempt is returned unchanged, so the caller observes the ordinary gRPC
    status/error rather than a ``tenacity.RetryError``.

    This interceptor should be placed after the OpenTelemetry interceptor in
    the interceptor list.

    NOTE(review): only the unary-unary and unary-stream interceptor base
    classes are subclassed, so ``grpc.intercept_channel`` presumably never
    routes client-streaming calls to ``intercept_stream_*``. Enabling that
    would also require a replayable request iterator (a retried attempt would
    otherwise see an already-consumed iterator) — confirm before changing.
    """

    # Codes retried when the caller does not pass ``retryable_status_codes``.
    _DEFAULT_RETRYABLE_STATUS_CODES = frozenset(
        {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN}
    )

    max_attempts: int
    retryable_status_codes: Set[grpc.StatusCode]

    def __init__(
        self,
        max_attempts: int = 5,
        retryable_status_codes: Optional[Set[grpc.StatusCode]] = None,
    ) -> None:
        """Configure the retry policy.

        Args:
            max_attempts: Total number of attempts per RPC, including the
                first call.
            retryable_status_codes: Status codes that trigger another attempt.
                ``None`` (the default) means ``{UNAVAILABLE, UNKNOWN}``. The
                given collection is copied, so later changes to the caller's
                set do not affect this interceptor.
        """
        self.max_attempts = max_attempts
        if retryable_status_codes is None:
            retryable_status_codes = self._DEFAULT_RETRYABLE_STATUS_CODES
        # Copy so that instances never share (and accidentally mutate) one set.
        self.retryable_status_codes = set(retryable_status_codes)

    def _intercept_call(self, continuation, client_call_details, request_or_iterator):
        """Invoke ``continuation``, retrying while the call ends with a retryable code.

        Returns:
            The call object produced by the last attempt (successful, failed
            with a non-retryable code, or failed after ``max_attempts``).
        Raises:
            Whatever ``continuation`` itself raises; exceptions are not retried.
        """
        # Span covering the back-off sleep between two attempts: opened in
        # ``before_sleep`` and closed as soon as the next attempt starts.
        sleep_span: Optional[Span] = None

        def before_sleep(_retry_state) -> None:
            # Imported lazily, presumably so the module-level ``tracer`` is read
            # at call time (after telemetry is configured) and to avoid an
            # import cycle — confirm.
            from chromadb.telemetry.opentelemetry import tracer

            nonlocal sleep_span
            if tracer is not None:
                sleep_span = tracer.start_span("Waiting to retry RPC")

        def is_retryable(call) -> bool:
            # ``Call.code()`` blocks until the RPC has terminated.
            return call.code() in self.retryable_status_codes

        @retry(
            wait=wait_exponential_jitter(initial=0.1, jitter=0.1),
            stop=stop_after_attempt(self.max_attempts),
            retry=retry_if_result(is_retryable),
            before_sleep=before_sleep,
            # Out of attempts: return the last call instead of raising
            # tenacity.RetryError, so callers see the normal gRPC failure.
            retry_error_callback=lambda retry_state: retry_state.outcome.result(),
        )
        def wrapped(*args, **kwargs):
            nonlocal sleep_span
            if sleep_span is not None:
                sleep_span.end()
                sleep_span = None
            return continuation(*args, **kwargs)

        return wrapped(client_call_details, request_or_iterator)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Retry a unary-unary call according to the configured policy."""
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Retry a unary-stream call according to the configured policy."""
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        """Retry a stream-unary call (see the class NOTE on reachability)."""
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        """Retry a stream-stream call (see the class NOTE on reachability)."""
        return self._intercept_call(continuation, client_call_details, request_iterator)