Spaces:
Sleeping
Sleeping
import chromadb | |
from chromadb import Settings | |
from chromadb.utils.batch_utils import create_batches | |
from typing import Optional | |
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult | |
from open_webui.config import ( | |
CHROMA_DATA_PATH, | |
CHROMA_HTTP_HOST, | |
CHROMA_HTTP_PORT, | |
CHROMA_HTTP_HEADERS, | |
CHROMA_HTTP_SSL, | |
CHROMA_TENANT, | |
CHROMA_DATABASE, | |
) | |
class ChromaClient:
    """Vector-store adapter backed by a Chroma client.

    An HTTP client is used when ``CHROMA_HTTP_HOST`` is a non-empty string;
    otherwise a persistent client rooted at ``CHROMA_DATA_PATH`` is opened.
    Both are bound to ``CHROMA_TENANT`` / ``CHROMA_DATABASE``.
    """

    def __init__(self):
        # Same settings for both client flavours: resets are allowed (needed
        # by ``reset``) and anonymized telemetry is switched off.
        settings = Settings(allow_reset=True, anonymized_telemetry=False)

        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=settings,
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=settings,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named ``collection_name`` exists.

        ``list_collections()`` yields Collection objects on some Chroma
        releases and bare collection names (strings) on others, so both
        shapes are accepted here.
        """
        for entry in self.client.list_collections():
            # Test for ``str`` first: on releases that return names, touching
            # ``.name`` on the returned string wrapper raises instead of
            # returning the name.
            existing_name = entry if isinstance(entry, str) else entry.name
            if existing_name == collection_name:
                return True
        return False

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name``.

        Returns whatever the underlying client's ``delete_collection`` returns.
        """
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Query the collection for the nearest neighbours of ``vectors``.

        Args:
            collection_name: Name of the collection to query.
            vectors: Query embeddings, one per query.
            limit: Number of results requested per query (``n_results``).

        Returns:
            A ``SearchResult`` built from the ``ids``, ``distances``,
            ``documents`` and ``metadatas`` of the query response, or ``None``
            if the client hands back a falsy collection.

        NOTE(review): Chroma's ``get_collection`` presumably raises when the
        collection does not exist, so callers should not rely on ``None`` to
        signal a missing collection - confirm against the pinned version.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None

        result = collection.query(query_embeddings=vectors, n_results=limit)
        return SearchResult(
            ids=result["ids"],
            distances=result["distances"],
            documents=result["documents"],
            metadatas=result["metadatas"],
        )

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch every item stored in the collection.

        Returns:
            A ``GetResult`` whose ``ids``, ``documents`` and ``metadatas`` are
            each wrapped in a single-element outer list, or ``None`` if the
            client hands back a falsy collection.

        NOTE(review): as with ``search``, a missing collection presumably
        raises inside ``get_collection`` rather than yielding ``None``.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None

        result = collection.get()
        # Each field gains one extra list level, exactly as before.
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def _iter_batches(self, items: list[VectorItem]):
        """Split ``items`` into batches sized for the underlying client.

        Each batch comes from ``create_batches`` and is passed positionally to
        ``collection.add`` / ``collection.upsert`` by the callers.
        """
        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]
        return create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        )

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to the collection, creating the collection if needed.

        Items are written in batches produced by ``create_batches``.
        """
        collection = self.client.get_or_create_collection(name=collection_name)
        for batch in self._iter_batches(items):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items``, creating the collection if needed.

        Items are written in batches, mirroring ``insert``, so a large
        payload is not sent as one oversized request.
        """
        collection = self.client.get_or_create_collection(name=collection_name)
        for batch in self._iter_batches(items):
            collection.upsert(*batch)

    def delete(self, collection_name: str, ids: list[str]):
        """Delete the items with the given ``ids`` from the collection."""
        collection = self.client.get_collection(name=collection_name)
        if collection:
            collection.delete(ids=ids)

    def reset(self):
        """Reset the database, deleting all collections and item entries.

        Returns whatever the underlying client's ``reset`` returns.
        """
        return self.client.reset()