Spaces:
Build error
Build error
File size: 2,570 Bytes
60e3a80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from typing import Optional, Set
import grpc
from tenacity import retry, stop_after_attempt, wait_exponential_jitter, retry_if_result
from opentelemetry.trace import Span
class RetryOnRpcErrorClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor
):
"""
A gRPC client interceptor that retries RPCs on specific status codes. By default, it retries on UNAVAILABLE and UNKNOWN status codes.
This interceptor should be placed after the OpenTelemetry interceptor in the interceptor list.
"""
max_attempts: int
retryable_status_codes: Set[grpc.StatusCode]
def __init__(
self,
max_attempts: int = 5,
retryable_status_codes: Set[grpc.StatusCode] = set(
[grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
),
) -> None:
self.max_attempts = max_attempts
self.retryable_status_codes = retryable_status_codes
def _intercept_call(self, continuation, client_call_details, request_or_iterator):
sleep_span: Optional[Span] = None
def before_sleep(_):
from chromadb.telemetry.opentelemetry import tracer
nonlocal sleep_span
if tracer is not None:
sleep_span = tracer.start_span("Waiting to retry RPC")
@retry(
wait=wait_exponential_jitter(0.1, jitter=0.1),
stop=stop_after_attempt(self.max_attempts),
retry=retry_if_result(lambda x: x.code() in self.retryable_status_codes),
before_sleep=before_sleep,
)
def wrapped(*args, **kwargs):
nonlocal sleep_span
if sleep_span is not None:
sleep_span.end()
sleep_span = None
return continuation(*args, **kwargs)
return wrapped(client_call_details, request_or_iterator)
def intercept_unary_unary(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request)
def intercept_unary_stream(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request)
def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(continuation, client_call_details, request_iterator)
def intercept_stream_stream(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(continuation, client_call_details, request_iterator)
|