|
from concurrent import futures |
|
from queue import Queue |
|
from threading import Thread |
|
from typing import Generator, Tuple |
|
import grpc |
|
import pytest |
|
|
|
from chromadb.proto.convert import to_proto_submit |
|
from chromadb.proto.logservice_pb2 import PushLogsRequest, PushLogsResponse |
|
from chromadb.proto.logservice_pb2_grpc import ( |
|
LogServiceServicer, |
|
LogServiceStub, |
|
add_LogServiceServicer_to_server, |
|
) |
|
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor |
|
from chromadb.types import Operation, OperationRecord |
|
|
|
|
|
class FlakyLogServiceServicer(LogServiceServicer):
    """Log-service servicer that rejects a fixed number of calls before succeeding.

    Every ``PushLogs`` call, failed or not, is recorded on ``received_requests``
    so that a test can inspect what the server actually saw.
    """

    # Remaining number of PushLogs calls that will be answered with UNAVAILABLE.
    num_requests_to_fail: int

    # Receives one {"status": ..., "request": ...} dict per handled call.
    received_requests: Queue

    def __init__(self, num_requests_to_fail: int, received_requests: Queue) -> None:
        """Store the failure budget and the queue used to report handled calls."""
        super().__init__()
        self.num_requests_to_fail = num_requests_to_fail
        self.received_requests = received_requests

    def PushLogs(
        self, request: PushLogsRequest, context: grpc.ServicerContext
    ) -> PushLogsResponse:
        """Handle one push: fail while the budget lasts, succeed afterwards.

        Args:
            request: The incoming push request; it is echoed into the queue entry.
            context: The RPC context, used to set the UNAVAILABLE status on failure.

        Returns:
            An empty response for a rejected call, or a response with
            ``record_count=1`` for an accepted one.
        """
        budget_exhausted = self.num_requests_to_fail <= 0
        if budget_exhausted:
            self.received_requests.put({"status": "success", "request": request})
            return PushLogsResponse(record_count=1)

        # Spend one unit of the failure budget and report the call as unavailable.
        self.num_requests_to_fail -= 1
        context.set_code(grpc.StatusCode.UNAVAILABLE)
        context.set_details("Service unavailable")
        self.received_requests.put({"status": "failed", "request": request})
        return PushLogsResponse()
|
|
|
|
|
def start_server(
    num_requests_to_fail: int,
    received_requests: Queue,
    started_queue: Queue,
    stop_queue: Queue,
) -> None:
    """Run a flaky log-service gRPC server until a stop signal arrives.

    Intended to be used as a thread target: it blocks until something is put
    on ``stop_queue``.

    Args:
        num_requests_to_fail: Number of initial ``PushLogs`` calls the servicer
            rejects before it starts succeeding.
        received_requests: Queue the servicer reports every handled call to.
        started_queue: Receives a single item once the server has been started.
        stop_queue: Putting any item here shuts the server down.
    """
    # The executor has a single worker thread for running RPC handlers.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    add_LogServiceServicer_to_server(
        FlakyLogServiceServicer(num_requests_to_fail, received_requests), server
    )
    server.add_insecure_port("[::]:50051")
    server.start()

    try:
        started_queue.put(1)
        stop_queue.get()
    finally:
        # stop() only initiates shutdown and returns an event; wait on it so
        # that the thread does not exit (and a caller's join() does not return)
        # before the server has completely stopped.
        server.stop(0).wait()
|
|
|
|
|
class LogServiceRetryClient:
    """Small log-service client whose channel is wrapped with the retry interceptor."""

    # Stub bound to the intercepted channel; all calls go through it.
    stub: LogServiceStub

    def __init__(self, grpc_url: str) -> None:
        """Open an insecure channel to ``grpc_url`` and build the stub on top of it."""
        raw_channel = grpc.insecure_channel(grpc_url)
        intercepted_channel = grpc.intercept_channel(
            raw_channel, RetryOnRpcErrorClientInterceptor()
        )
        self.stub = LogServiceStub(intercepted_channel)

    def push_log(self, collection_id: str, record: OperationRecord) -> None:
        """Send ``record`` as a single-record ``PushLogs`` request.

        Args:
            collection_id: Collection identifier placed on the request.
            record: Operation record, converted with ``to_proto_submit`` before sending.
        """
        self.stub.PushLogs(
            PushLogsRequest(
                collection_id=collection_id,
                records=[to_proto_submit(record)],
            )
        )
|
|
|
|
|
# How many PushLogs calls the flaky test server rejects before it accepts one;
# the test expects exactly this many failed attempts followed by one success.
NUM_REQUESTS_TO_FAIL: int = 3
|
|
|
|
|
@pytest.fixture()
def client_for_flaky_server_and_received_requests() -> (
    Generator[Tuple[LogServiceRetryClient, Queue], None, None]
):
    """Start the flaky server in a background thread and yield a client for it.

    Yields:
        A ``(client, received_requests)`` pair: the retrying client connected to
        the server, and the queue on which the server reports each handled call.

    Raises:
        queue.Empty: If the server thread does not signal that it has started
            within the startup timeout.
    """
    received_requests: Queue = Queue()
    started_queue: Queue = Queue()
    stop_queue: Queue = Queue()

    server_thread = Thread(
        target=start_server,
        args=(NUM_REQUESTS_TO_FAIL, received_requests, started_queue, stop_queue),
    )
    server_thread.start()

    try:
        # Bounded wait: if start_server dies before signalling (for example
        # because the port cannot be bound), fail instead of blocking forever.
        started_queue.get(timeout=30)

        yield LogServiceRetryClient("localhost:50051"), received_requests
    finally:
        # Always release the server thread, including when startup or client
        # construction failed; the put is harmless if the thread already exited.
        stop_queue.put(1)
        server_thread.join()
|
|
|
|
|
def test_retry_interceptor(
    client_for_flaky_server_and_received_requests: Tuple[LogServiceRetryClient, Queue]
) -> None:
    """Check that one push is retried until the flaky server accepts it.

    The server is expected to see ``NUM_REQUESTS_TO_FAIL`` failed attempts
    followed by exactly one successful attempt.
    """
    # Use the client provided by the fixture rather than opening a second
    # channel, so the fixture's client is what the test actually exercises.
    client, received_requests = client_for_flaky_server_and_received_requests

    client.push_log(
        "test",
        OperationRecord(
            id="1",
            embedding=None,
            encoding=None,
            metadata=None,
            operation=Operation.ADD,
        ),
    )

    # The servicer enqueues an entry before returning each response, so once
    # push_log has returned the queue should hold every attempt the server saw.
    requests = []
    while not received_requests.empty():
        requests.append(received_requests.get())

    assert len(requests) == NUM_REQUESTS_TO_FAIL + 1
    assert all(r["status"] == "failed" for r in requests[:NUM_REQUESTS_TO_FAIL])
    assert requests[NUM_REQUESTS_TO_FAIL]["status"] == "success"
|
|