himanshud2611's picture
Upload folder using huggingface_hub
60e3a80 verified
from typing import (
TYPE_CHECKING,
Dict,
Generic,
Optional,
Tuple,
Any,
TypeVar,
Union,
cast,
)
import numpy as np
from uuid import UUID
import chromadb.utils.embedding_functions as ef
from chromadb.api.types import (
URI,
CollectionMetadata,
DataLoader,
Embedding,
Embeddings,
PyEmbedding,
Embeddable,
GetResult,
Include,
Loadable,
Metadata,
Metadatas,
Document,
Documents,
Image,
Images,
QueryResult,
URIs,
IDs,
EmbeddingFunction,
ID,
OneOrMany,
maybe_cast_one_to_many_ids,
maybe_cast_one_to_many_embedding,
maybe_cast_one_to_many_metadata,
maybe_cast_one_to_many_document,
maybe_cast_one_to_many_image,
maybe_cast_one_to_many_uri,
validate_ids,
validate_include,
validate_metadata,
validate_metadatas,
validate_embeddings,
validate_embedding_function,
validate_n_results,
validate_where,
validate_where_document,
)
# TODO: We should rename the types in chromadb.types to be Models where
# appropriate. This will help to distinguish between manipulation objects
# which are essentially API views. And the actual data models which are
# stored / retrieved / transmitted.
from chromadb.types import Collection as CollectionModel, Where, WhereDocument
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from chromadb.api import ServerAPI, AsyncServerAPI
ClientT = TypeVar("ClientT", "ServerAPI", "AsyncServerAPI")
class CollectionCommon(Generic[ClientT]):
_model: CollectionModel
_client: ClientT
_embedding_function: Optional[EmbeddingFunction[Embeddable]]
_data_loader: Optional[DataLoader[Loadable]]
def __init__(
self,
client: ClientT,
model: CollectionModel,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
):
"""Initializes a new instance of the Collection class."""
self._client = client
self._model = model
# Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol
if embedding_function is not None:
validate_embedding_function(embedding_function)
self._embedding_function = embedding_function
self._data_loader = data_loader
# Expose the model properties as read-only properties on the Collection class
@property
def id(self) -> UUID:
return self._model.id
@property
def name(self) -> str:
return self._model.name
@property
def configuration_json(self) -> Dict[str, Any]:
return self._model.configuration_json
@property
def metadata(self) -> CollectionMetadata:
return cast(CollectionMetadata, self._model.metadata)
@property
def tenant(self) -> str:
return self._model.tenant
@property
def database(self) -> str:
return self._model.database
def __eq__(self, other: object) -> bool:
if not isinstance(other, CollectionCommon):
return False
id_match = self.id == other.id
name_match = self.name == other.name
configuration_match = self.configuration_json == other.configuration_json
metadata_match = self.metadata == other.metadata
tenant_match = self.tenant == other.tenant
database_match = self.database == other.database
embedding_function_match = self._embedding_function == other._embedding_function
data_loader_match = self._data_loader == other._data_loader
return (
id_match
and name_match
and configuration_match
and metadata_match
and tenant_match
and database_match
and embedding_function_match
and data_loader_match
)
def __repr__(self) -> str:
return f"Collection(id={self.id}, name={self.name})"
def get_model(self) -> CollectionModel:
return self._model
def _validate_embedding_set(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
require_embeddings_or_data: bool = True,
) -> Tuple[
IDs,
Optional[Embeddings],
Optional[Metadatas],
Optional[Documents],
Optional[Images],
Optional[URIs],
]:
valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids))
valid_embeddings = (
validate_embeddings(
self._normalize_embeddings(maybe_cast_one_to_many_embedding(embeddings))
)
if embeddings is not None
else None
)
valid_metadatas = (
validate_metadatas(maybe_cast_one_to_many_metadata(metadatas))
if metadatas is not None
else None
)
valid_documents = (
maybe_cast_one_to_many_document(documents)
if documents is not None
else None
)
valid_images = (
maybe_cast_one_to_many_image(images) if images is not None else None
)
valid_uris = maybe_cast_one_to_many_uri(uris) if uris is not None else None
# Check that one of embeddings or ducuments or images is provided
if require_embeddings_or_data:
if (
valid_embeddings is None
and valid_documents is None
and valid_images is None
and valid_uris is None
):
raise ValueError(
"You must provide embeddings, documents, images, or uris."
)
# Only one of documents or images can be provided
if valid_documents is not None and valid_images is not None:
raise ValueError("You can only provide documents or images, not both.")
# Check that, if they're provided, the lengths of the arrays match the length of ids
if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids):
raise ValueError(
f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}"
)
if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids):
raise ValueError(
f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
)
if valid_documents is not None and len(valid_documents) != len(valid_ids):
raise ValueError(
f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}"
)
if valid_images is not None and len(valid_images) != len(valid_ids):
raise ValueError(
f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}"
)
if valid_uris is not None and len(valid_uris) != len(valid_ids):
raise ValueError(
f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}"
)
return (
valid_ids,
valid_embeddings,
valid_metadatas,
valid_documents,
valid_images,
valid_uris,
)
def _validate_and_prepare_embedding_set(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]],
uris: Optional[OneOrMany[URI]],
) -> Tuple[
IDs,
Embeddings,
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
]:
(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_embedding_set(
ids, embeddings, metadatas, documents, images, uris
)
# We need to compute the embeddings if they're not provided
if embeddings is None:
# At this point, we know that one of documents or images are provided from the validation above
if documents is not None:
embeddings = self._embed(input=documents)
elif images is not None:
embeddings = self._embed(input=images)
else:
if uris is None:
raise ValueError(
"You must provide either embeddings, documents, images, or uris."
)
if self._data_loader is None:
raise ValueError(
"You must set a data loader on the collection if loading from URIs."
)
embeddings = self._embed(self._data_loader(uris))
return ids, embeddings, metadatas, documents, uris
def _validate_and_prepare_get_request(
self,
ids: Optional[OneOrMany[ID]],
where: Optional[Where],
where_document: Optional[WhereDocument],
include: Include,
) -> Tuple[Optional[IDs], Optional[Where], Optional[WhereDocument], Include,]:
valid_where = validate_where(where) if where else None
valid_where_document = (
validate_where_document(where_document) if where_document else None
)
valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
valid_include = validate_include(include, allow_distances=False)
if "data" in include and self._data_loader is None:
raise ValueError(
"You must set a data loader on the collection if loading from URIs."
)
# We need to include uris in the result from the API to load datas
if "data" in include and "uris" not in include:
valid_include.append("uris") # type: ignore[arg-type]
return valid_ids, valid_where, valid_where_document, valid_include
def _transform_peek_response(self, response: GetResult) -> GetResult:
if response["embeddings"] is not None:
response["embeddings"] = np.array(response["embeddings"])
return response
def _transform_get_response(
self, response: GetResult, include: Include
) -> GetResult:
if (
"data" in include
and self._data_loader is not None
and response["uris"] is not None
):
response["data"] = self._data_loader(response["uris"])
if "embeddings" in include:
response["embeddings"] = np.array(response["embeddings"])
# Remove URIs from the result if they weren't requested
if "uris" not in include:
response["uris"] = None
return response
def _validate_and_prepare_query_request(
self,
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
],
query_texts: Optional[OneOrMany[Document]],
query_images: Optional[OneOrMany[Image]],
query_uris: Optional[OneOrMany[URI]],
n_results: int,
where: Optional[Where],
where_document: Optional[WhereDocument],
include: Include,
) -> Tuple[Embeddings, int, Where, WhereDocument,]:
# Users must provide only one of query_embeddings, query_texts, query_images, or query_uris
if not (
(query_embeddings is not None)
^ (query_texts is not None)
^ (query_images is not None)
^ (query_uris is not None)
):
raise ValueError(
"You must provide one of query_embeddings, query_texts, query_images, or query_uris."
)
valid_where = validate_where(where) if where else {}
valid_where_document = (
validate_where_document(where_document) if where_document else {}
)
valid_query_embeddings = (
validate_embeddings(
self._normalize_embeddings(
maybe_cast_one_to_many_embedding(query_embeddings)
)
)
if query_embeddings is not None
else None
)
valid_query_texts = (
maybe_cast_one_to_many_document(query_texts)
if query_texts is not None
else None
)
valid_query_images = (
maybe_cast_one_to_many_image(query_images)
if query_images is not None
else None
)
valid_query_uris = (
maybe_cast_one_to_many_uri(query_uris) if query_uris is not None else None
)
valid_include = validate_include(include, allow_distances=True)
valid_n_results = validate_n_results(n_results)
# If query_embeddings are not provided, we need to compute them from the inputs
if valid_query_embeddings is None:
if query_texts is not None:
valid_query_embeddings = self._embed(input=valid_query_texts)
elif query_images is not None:
valid_query_embeddings = self._embed(input=valid_query_images)
else:
if valid_query_uris is None:
raise ValueError(
"You must provide either query_embeddings, query_texts, query_images, or query_uris."
)
if self._data_loader is None:
raise ValueError(
"You must set a data loader on the collection if loading from URIs."
)
valid_query_embeddings = self._embed(
self._data_loader(valid_query_uris)
)
if "data" in include and "uris" not in include:
valid_include.append("uris") # type: ignore[arg-type]
return (
valid_query_embeddings,
valid_n_results,
valid_where,
valid_where_document,
)
def _transform_query_response(
self, response: QueryResult, include: Include
) -> QueryResult:
if (
"data" in include
and self._data_loader is not None
and response["uris"] is not None
):
response["data"] = [self._data_loader(uris) for uris in response["uris"]]
if "embeddings" in include and response["embeddings"] is not None:
response["embeddings"] = [
np.array(embedding) for embedding in response["embeddings"]
]
# Remove URIs from the result if they weren't requested
if "uris" not in include:
response["uris"] = None
return response
def _validate_modify_request(self, metadata: Optional[CollectionMetadata]) -> None:
if metadata is not None:
validate_metadata(metadata)
if "hnsw:space" in metadata:
raise ValueError(
"Changing the distance function of a collection once it is created is not supported currently."
)
def _update_model_after_modify_success(
self, name: Optional[str], metadata: Optional[CollectionMetadata]
) -> None:
if name:
self._model["name"] = name
if metadata:
self._model["metadata"] = metadata
def _validate_and_prepare_update_request(
self,
ids: OneOrMany[ID],
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]],
uris: Optional[OneOrMany[URI]],
) -> Tuple[
IDs,
Embeddings,
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
]:
(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_embedding_set(
ids,
embeddings,
metadatas,
documents,
images,
uris,
require_embeddings_or_data=False,
)
if embeddings is None:
if documents is not None:
embeddings = self._embed(input=documents)
elif images is not None:
embeddings = self._embed(input=images)
return ids, cast(Embeddings, embeddings), metadatas, documents, uris
def _validate_and_prepare_upsert_request(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]],
uris: Optional[OneOrMany[URI]],
) -> Tuple[
IDs,
Embeddings,
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
]:
(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_embedding_set(
ids, embeddings, metadatas, documents, images, uris
)
if embeddings is None:
if documents is not None:
embeddings = self._embed(input=documents)
else:
embeddings = self._embed(input=images)
return ids, embeddings, metadatas, documents, uris
def _validate_and_prepare_delete_request(
self,
ids: Optional[IDs],
where: Optional[Where],
where_document: Optional[WhereDocument],
) -> Tuple[Optional[IDs], Optional[Where], Optional[WhereDocument]]:
ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
where = validate_where(where) if where else None
where_document = (
validate_where_document(where_document) if where_document else None
)
return (ids, where, where_document)
@staticmethod
def _normalize_embeddings(
embeddings: Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
) -> Embeddings:
return cast(Embeddings, [np.array(embedding) for embedding in embeddings])
def _embed(self, input: Any) -> Embeddings:
if self._embedding_function is None:
raise ValueError(
"You must provide an embedding function to compute embeddings."
"https://docs.trychroma.com/guides/embeddings"
)
return self._embedding_function(input=input)