File size: 2,570 Bytes
60e3a80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Optional, Set
import grpc
from tenacity import retry, stop_after_attempt, wait_exponential_jitter, retry_if_result
from opentelemetry.trace import Span


class RetryOnRpcErrorClientInterceptor(
    grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor
):
    """
    A gRPC client interceptor that retries RPCs on specific status codes.

    By default it retries on UNAVAILABLE and UNKNOWN. Attempts are spaced with
    tenacity's ``wait_exponential_jitter`` (initial wait 0.1s, jitter 0.1s)
    and stop after ``max_attempts`` total attempts. If the status is still
    retryable when attempts run out, the call object from the last attempt is
    returned unchanged. Callers therefore see the underlying gRPC error rather
    than a ``tenacity.RetryError``.

    This interceptor should be placed after the OpenTelemetry interceptor in
    the interceptor list.
    """

    # Total number of attempts, including the first call.
    max_attempts: int
    # Status codes (as reported by the returned call's ``code()``) that
    # trigger another attempt.
    retryable_status_codes: Set[grpc.StatusCode]

    def __init__(
        self,
        max_attempts: int = 5,
        retryable_status_codes: Optional[Set[grpc.StatusCode]] = None,
    ) -> None:
        """
        Args:
            max_attempts: Total number of attempts (first call included)
                before giving up.
            retryable_status_codes: Status codes that trigger a retry.
                ``None`` (the default) selects ``{UNAVAILABLE, UNKNOWN}``.
        """
        if retryable_status_codes is None:
            # Build a fresh set per instance. A set used directly as the
            # default argument would be created once and shared (mutably)
            # by every instance.
            retryable_status_codes = {
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.UNKNOWN,
            }
        self.max_attempts = max_attempts
        self.retryable_status_codes = retryable_status_codes

    def _intercept_call(self, continuation, client_call_details, request_or_iterator):
        """
        Invoke ``continuation``, re-invoking it while the returned call object
        reports a status code in ``self.retryable_status_codes``.

        Returns the call object from the final attempt. When attempts are
        exhausted on a retryable status, that last call object is returned
        as-is instead of tenacity raising ``RetryError``.
        """
        # Span covering the backoff wait between attempts. It is opened in
        # before_sleep and closed at the start of the next attempt.
        sleep_span: Optional[Span] = None

        def before_sleep(_):
            # NOTE(review): imported lazily, presumably to avoid an import
            # cycle with the telemetry package. Confirm.
            from chromadb.telemetry.opentelemetry import tracer

            nonlocal sleep_span
            if tracer is not None:
                sleep_span = tracer.start_span("Waiting to retry RPC")

        def return_last_result(retry_state):
            # Invoked by tenacity once the stop condition fires. Hand back the
            # final call object so its gRPC error reaches the caller
            # unchanged, instead of being wrapped in tenacity.RetryError.
            return retry_state.outcome.result()

        @retry(
            wait=wait_exponential_jitter(0.1, jitter=0.1),
            stop=stop_after_attempt(self.max_attempts),
            retry=retry_if_result(lambda x: x.code() in self.retryable_status_codes),
            before_sleep=before_sleep,
            retry_error_callback=return_last_result,
        )
        def wrapped(*args, **kwargs):
            nonlocal sleep_span
            # The backoff wait is over; close the span before_sleep opened.
            if sleep_span is not None:
                sleep_span.end()
                sleep_span = None
            return continuation(*args, **kwargs)

        return wrapped(client_call_details, request_or_iterator)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Retry a unary-request, unary-response RPC."""
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Retry a unary-request, streaming-response RPC."""
        # NOTE(review): the retry predicate calls ``code()`` on the returned
        # streaming call, which waits for the RPC to terminate before the
        # stream is handed to the caller. Confirm this is intended for
        # long-lived streams.
        return self._intercept_call(continuation, client_call_details, request)

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        """Retry a streaming-request, unary-response RPC."""
        # NOTE(review): the class does not subclass
        # grpc.StreamUnaryClientInterceptor, so grpc.intercept_channel will
        # presumably never route calls here. If enabled, a retry would reuse
        # a request iterator that an earlier attempt may already have consumed.
        return self._intercept_call(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        """Retry a streaming-request, streaming-response RPC."""
        # NOTE(review): same caveat as intercept_stream_unary. This class is
        # not a grpc.StreamStreamClientInterceptor, and a retry would reuse a
        # possibly consumed request iterator.
        return self._intercept_call(continuation, client_call_details, request_iterator)