from typing import Optional, Set

import grpc
from tenacity import retry, stop_after_attempt, wait_exponential_jitter, retry_if_result
from opentelemetry.trace import Span


class RetryOnRpcErrorClientInterceptor(
    grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor
):
    """
    A gRPC client interceptor that retries RPCs on specific status codes.
    By default, it retries on UNAVAILABLE and UNKNOWN status codes.

    This interceptor should be placed after the OpenTelemetry interceptor in
    the interceptor list.

    Attempts are spaced with exponential backoff plus jitter
    (``wait_exponential_jitter(0.1, jitter=0.1)``) and stop after
    ``max_attempts`` attempts in total. If every attempt ends with a retryable
    status, the call object of the last attempt is returned to gRPC, so callers
    see gRPC's own ``RpcError`` rather than a ``tenacity.RetryError``.
    """

    # Total number of attempts per RPC (the first call included).
    max_attempts: int
    # Status codes whose occurrence triggers another attempt.
    retryable_status_codes: Set[grpc.StatusCode]

    def __init__(
        self,
        max_attempts: int = 5,
        retryable_status_codes: Optional[Set[grpc.StatusCode]] = None,
    ) -> None:
        """
        Args:
            max_attempts: total attempts per RPC, including the first one.
            retryable_status_codes: status codes to retry on. ``None`` (the
                default) means ``{UNAVAILABLE, UNKNOWN}``. The default set is
                built per instance so instances never share one mutable
                default object.
        """
        if retryable_status_codes is None:
            retryable_status_codes = {
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.UNKNOWN,
            }
        self.max_attempts = max_attempts
        self.retryable_status_codes = retryable_status_codes

    def _intercept_call(self, continuation, client_call_details, request_or_iterator):
        """Invoke ``continuation`` and retry while the call ends with a retryable code.

        Args:
            continuation: gRPC-supplied callable that performs the actual RPC
                and returns a call object exposing ``code()``.
            client_call_details: call details forwarded unchanged to
                ``continuation``.
            request_or_iterator: request message (or request iterator)
                forwarded unchanged to ``continuation``.

        Returns:
            The call object from the first non-retryable attempt, or from the
            last attempt once ``max_attempts`` is exhausted.
        """
        sleep_span: Optional[Span] = None

        def before_sleep(_):
            # Imported lazily, presumably to avoid an import cycle and/or to
            # read the tracer after telemetry setup — TODO confirm.
            from chromadb.telemetry.opentelemetry import tracer

            nonlocal sleep_span
            if tracer is not None:
                # Opened before each backoff sleep; closed at the start of the
                # next attempt in `wrapped`.
                sleep_span = tracer.start_span("Waiting to retry RPC")

        def is_retryable(call) -> bool:
            # NOTE(review): gRPC documents Call.code() as blocking until the
            # RPC terminates; for a streaming response this presumably waits
            # for the whole stream — confirm this is acceptable.
            return call.code() in self.retryable_status_codes

        @retry(
            wait=wait_exponential_jitter(0.1, jitter=0.1),
            stop=stop_after_attempt(self.max_attempts),
            retry=retry_if_result(is_retryable),
            before_sleep=before_sleep,
            # Without this, tenacity raises RetryError once attempts run out,
            # which escapes gRPC as a non-RpcError. Returning the last call
            # object lets gRPC raise its usual RpcError from it instead.
            retry_error_callback=lambda retry_state: retry_state.outcome.result(),
        )
        def wrapped(*args, **kwargs):
            nonlocal sleep_span
            # Close the backoff span opened by before_sleep, if any.
            if sleep_span is not None:
                sleep_span.end()
                sleep_span = None
            return continuation(*args, **kwargs)

        return wrapped(client_call_details, request_or_iterator)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Retry a unary-request, unary-response RPC."""
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Retry a unary-request, streaming-response RPC."""
        return self._intercept_call(continuation, client_call_details, request)

    # NOTE(review): the class does not subclass StreamUnaryClientInterceptor /
    # StreamStreamClientInterceptor, so grpc.intercept_channel presumably never
    # dispatches to the two methods below. If they are wired up, note that a
    # retry passes the same (possibly already consumed) request iterator again.
    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        """Retry a streaming-request, unary-response RPC (see note above)."""
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        """Retry a streaming-request, streaming-response RPC (see note above)."""
        return self._intercept_call(continuation, client_call_details, request_iterator)