|
import chromadb |
|
import logging |
|
from chromadb import Settings |
|
from chromadb.utils.batch_utils import create_batches |
|
|
|
from typing import Optional |
|
|
|
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult |
|
from open_webui.config import ( |
|
CHROMA_DATA_PATH, |
|
CHROMA_HTTP_HOST, |
|
CHROMA_HTTP_PORT, |
|
CHROMA_HTTP_HEADERS, |
|
CHROMA_HTTP_SSL, |
|
CHROMA_TENANT, |
|
CHROMA_DATABASE, |
|
CHROMA_CLIENT_AUTH_PROVIDER, |
|
CHROMA_CLIENT_AUTH_CREDENTIALS, |
|
) |
|
from open_webui.env import SRC_LOG_LEVELS |
|
|
|
# Module-level logger; its threshold comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
|
class ChromaClient:
    """Collection-level helper around a ChromaDB client.

    The underlying client is an HTTP client when ``CHROMA_HTTP_HOST`` is a
    non-empty string, otherwise a persistent client rooted at
    ``CHROMA_DATA_PATH``. Collections created through this class use the
    ``cosine`` HNSW space.
    """

    def __init__(self):
        """Build the underlying client from the ``CHROMA_*`` configuration."""
        settings_dict = {
            "allow_reset": True,
            "anonymized_telemetry": False,
        }
        # Auth settings are forwarded only when they are configured.
        if CHROMA_CLIENT_AUTH_PROVIDER is not None:
            settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
        if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
            settings_dict["chroma_client_auth_credentials"] = (
                CHROMA_CLIENT_AUTH_CREDENTIALS
            )

        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=Settings(**settings_dict),
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=Settings(**settings_dict),
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    @staticmethod
    def _wrap_get_result(result) -> GetResult:
        """Wrap one ``collection.get()`` payload in the nested-list ``GetResult`` shape."""
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def has_collection(self, collection_name: str) -> bool:
        """Return True when a collection named ``collection_name`` exists.

        ``list_collections()`` entries may be plain names or collection
        objects exposing a ``name`` attribute (this differs between chromadb
        releases), so both shapes are accepted.
        """
        return any(
            getattr(entry, "name", entry) == collection_name
            for entry in self.client.list_collections()
        )

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name``.

        Returns whatever the underlying client returns; errors raised by the
        client propagate to the caller.
        """
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Run a nearest-neighbour query for each vector in ``vectors``.

        Args:
            collection_name: Collection to query.
            vectors: Query embeddings, one inner list per query.
            limit: Maximum number of hits requested per query.

        Returns:
            A ``SearchResult`` holding ids, documents, metadatas and rescaled
            distances (one row per query vector), or ``None`` when the
            collection cannot be queried for any reason.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None

            result = collection.query(
                query_embeddings=vectors,
                n_results=limit,
            )

            # Map each raw distance d to (2 - d) / 2. Assuming cosine
            # distances in [0, 2] (collections are created with
            # hnsw:space=cosine), this yields a 0..1 score where larger means
            # closer. Every query row is rescaled so the distances stay
            # aligned with ids/documents/metadatas for multi-vector queries.
            distances = [
                [(2 - dist) / 2 for dist in row] for row in result["distances"]
            ]

            return SearchResult(
                ids=result["ids"],
                distances=distances,
                documents=result["documents"],
                metadatas=result["metadatas"],
            )
        except Exception as e:
            # Best-effort contract: callers receive None instead of an error.
            log.debug("Search on collection %s failed: %s", collection_name, e)
            return None

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch items of a collection that match a metadata filter.

        Args:
            collection_name: Collection to read from.
            filter: Passed to the collection as its ``where`` clause.
            limit: Optional maximum number of items to return.

        Returns:
            A ``GetResult`` with a single row of ids/documents/metadatas, or
            ``None`` when the collection cannot be read for any reason.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None

            result = collection.get(
                where=filter,
                limit=limit,
            )
            return self._wrap_get_result(result)
        except Exception as e:
            # Unlike a bare ``except:``, KeyboardInterrupt and SystemExit
            # are allowed to propagate.
            log.debug("Query on collection %s failed: %s", collection_name, e)
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch every item stored in ``collection_name``.

        Errors raised by the client (for example while looking up the
        collection) are not caught here and propagate to the caller.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None
        return self._wrap_get_result(collection.get())

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to the collection, creating the collection if needed.

        Each item supplies the ``id``, ``text``, ``vector`` and ``metadata``
        keys. The payload is split with ``create_batches`` and added one
        batch at a time.
        """
        collection = self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items``, creating the collection if needed.

        Each item supplies the ``id``, ``text``, ``vector`` and ``metadata``
        keys. The payload is split with ``create_batches`` exactly as in
        ``insert`` so that large payloads are not sent in a single call.
        """
        collection = self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        # Batches are passed positionally, the same way ``insert`` passes
        # them to ``collection.add``.
        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            collection.upsert(*batch)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete items from a collection by ids or by metadata filter.

        ``ids`` takes precedence over ``filter``; when both are empty nothing
        is deleted. Any error (such as a missing collection) is logged at
        debug level and otherwise ignored.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if collection:
                if ids:
                    collection.delete(ids=ids)
                elif filter:
                    collection.delete(where=filter)
        except Exception:
            # Deliberately best-effort: deleting from a collection that is
            # not there is treated as a no-op.
            log.debug(
                "Attempted to delete from non-existent collection %s. Ignoring.",
                collection_name,
            )

    def reset(self):
        """Reset the underlying client and return its result."""
        return self.client.reset()
|
|