Spaces:
Build error
Build error
import sys | |
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor | |
import grpc | |
import time | |
from chromadb.ingest import ( | |
Producer, | |
Consumer, | |
ConsumerCallbackFn, | |
) | |
from chromadb.proto.convert import to_proto_submit | |
from chromadb.proto.logservice_pb2 import PushLogsRequest, PullLogsRequest, LogRecord | |
from chromadb.proto.logservice_pb2_grpc import LogServiceStub | |
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor | |
from chromadb.types import ( | |
OperationRecord, | |
SeqId, | |
) | |
from chromadb.config import System | |
from chromadb.telemetry.opentelemetry import ( | |
OpenTelemetryClient, | |
OpenTelemetryGranularity, | |
add_attributes_to_current_span, | |
trace_method, | |
) | |
from overrides import override | |
from typing import Sequence, Optional, cast | |
from uuid import UUID | |
import logging | |
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class LogService(Producer, Consumer):
    """
    Distributed Chroma Log Service.

    Producer/Consumer implementation that talks to a remote log service over
    gRPC. Records are written with the ``PushLogs`` RPC and read back with
    ``PullLogs``. The consumer-side subscription methods are no-ops here.
    """

    # Stub for the remote LogService gRPC API; created in start().
    _log_service_stub: LogServiceStub
    # Deadline passed as ``timeout=`` to every stub call.
    _request_timeout_seconds: int
    # Intercepted gRPC channel; only assigned once start() has run.
    _channel: grpc.Channel
    _log_service_url: str
    _log_service_port: int

    def __init__(self, system: System):
        """Read connection settings from ``system`` and resolve dependencies.

        No gRPC channel is opened here; that happens in :meth:`start`.
        """
        self._log_service_url = system.settings.require("chroma_logservice_host")
        self._log_service_port = system.settings.require("chroma_logservice_port")
        self._request_timeout_seconds = system.settings.require(
            "chroma_logservice_request_timeout_seconds"
        )
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        super().__init__(system)

    def start(self) -> None:
        """Open the gRPC channel, wrap it with interceptors, and build the stub."""
        # Plaintext (non-TLS) channel to host:port.
        self._channel = grpc.insecure_channel(
            f"{self._log_service_url}:{self._log_service_port}",
        )
        # Every call on the stub goes through the OpenTelemetry interceptor and
        # RetryOnRpcErrorClientInterceptor (defined in chromadb.proto.utils).
        interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
        self._channel = grpc.intercept_channel(self._channel, *interceptors)
        self._log_service_stub = LogServiceStub(self._channel)  # type: ignore
        super().start()

    def stop(self) -> None:
        """Close the gRPC channel (if one was opened) and stop the component."""
        # ``_channel`` is only a class-level annotation until start() assigns
        # it, so a never-started instance would raise AttributeError here.
        channel = getattr(self, "_channel", None)
        if channel is not None:
            channel.close()
        super().stop()

    def reset_state(self) -> None:
        """Delegate to the base-class reset; this class holds no extra state to reset."""
        super().reset_state()

    def delete_log(self, collection_id: UUID) -> None:
        """Not supported by this implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented")

    def purge_log(self, collection_id: UUID) -> None:
        """Not supported by this implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented")

    def submit_embedding(
        self, collection_id: UUID, embedding: OperationRecord
    ) -> SeqId:
        """Submit a single record via :meth:`submit_embeddings`.

        NOTE: the value returned is currently the record count reported by the
        log service for this one-record push, not a real sequence id.

        Raises:
            RuntimeError: if the component is not running.
        """
        if not self._running:
            raise RuntimeError("Component not running")
        return self.submit_embeddings(collection_id, [embedding])[0]

    def submit_embeddings(
        self, collection_id: UUID, embeddings: Sequence[OperationRecord]
    ) -> Sequence[SeqId]:
        """Convert ``embeddings`` to protos and push them in a single request.

        Returns ``[]`` for empty input. Otherwise returns a one-element list
        holding the ``record_count`` reported by the log service. That is NOT
        one SeqId per record (see the TODO below).

        Raises:
            RuntimeError: if the component is not running.
        """
        # Check first so we never log or trace a submission that won't happen.
        if not self._running:
            raise RuntimeError("Component not running")
        logger.info(
            "Submitting %s embeddings to log for collection %s",
            len(embeddings),
            collection_id,
        )
        add_attributes_to_current_span({"records_count": len(embeddings)})
        if not embeddings:
            return []

        protos_to_submit = [to_proto_submit(record) for record in embeddings]
        record_count = self.push_logs(
            collection_id,
            # push_logs receives proto records; the cast only satisfies its
            # annotation.
            cast(Sequence[OperationRecord], protos_to_submit),
        )
        # TODO: This returns the pushed record count, which is incorrect; it
        # should be one SeqId per record, but push_logs only exposes
        # ``record_count`` from the response.
        return [record_count]

    def subscribe(
        self,
        collection_id: UUID,
        consume_fn: ConsumerCallbackFn,
        start: Optional[SeqId] = None,
        end: Optional[SeqId] = None,
        id: Optional[UUID] = None,
    ) -> UUID:
        """No-op for the log service: ``consume_fn`` is never invoked.

        Returns the nil UUID (``UUID(int=0)``) as a placeholder subscription id.
        """
        logger.info("Subscribing to log for %s, noop for logservice", collection_id)
        return UUID(int=0)

    def unsubscribe(self, subscription_id: UUID) -> None:
        """No-op for the log service; only logs the call."""
        logger.info("Unsubscribing from %s, noop for logservice", subscription_id)

    def min_seqid(self) -> SeqId:
        """Return the smallest sequence id reported by this consumer (always 0)."""
        return 0

    def max_seqid(self) -> SeqId:
        """Return the largest sequence id reported by this consumer (``sys.maxsize``)."""
        return sys.maxsize

    def max_batch_size(self) -> int:
        """Return the fixed maximum batch size of 100 records."""
        return 100

    def push_logs(self, collection_id: UUID, records: Sequence[OperationRecord]) -> int:
        """Send ``records`` to the log for ``collection_id`` in one PushLogs RPC.

        ``records`` are expected to already be proto records, as produced by
        ``to_proto_submit`` in :meth:`submit_embeddings`. Returns the
        ``record_count`` field of the response.
        """
        request = PushLogsRequest(collection_id=str(collection_id), records=records)
        response = self._log_service_stub.PushLogs(
            request, timeout=self._request_timeout_seconds
        )
        return response.record_count  # type: ignore

    def pull_logs(
        self, collection_id: UUID, start_offset: int, batch_size: int
    ) -> Sequence[LogRecord]:
        """Issue a PullLogs RPC for ``collection_id`` from ``start_offset``.

        ``end_timestamp`` is set to the current time in nanoseconds. This
        presumably bounds the read to records written before this call;
        confirm against the log service. Returns the ``records`` field of the
        response.
        """
        request = PullLogsRequest(
            collection_id=str(collection_id),
            start_from_offset=start_offset,
            batch_size=batch_size,
            end_timestamp=time.time_ns(),
        )
        response = self._log_service_stub.PullLogs(
            request, timeout=self._request_timeout_seconds
        )
        return response.records  # type: ignore