Spaces:
Running
Running
File size: 5,816 Bytes
8437908 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import logging
from typing import Optional

import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches

from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    CHROMA_DATA_PATH,
    CHROMA_HTTP_HOST,
    CHROMA_HTTP_PORT,
    CHROMA_HTTP_HEADERS,
    CHROMA_HTTP_SSL,
    CHROMA_TENANT,
    CHROMA_DATABASE,
)
log = logging.getLogger(__name__)


class ChromaClient:
    """Wrapper around a ChromaDB client exposing collection-level vector operations.

    The underlying client is stored on ``self.client`` and is either a remote
    ``chromadb.HttpClient`` or a local ``chromadb.PersistentClient``, chosen
    from the ``CHROMA_*`` configuration values at construction time.
    """

    def __init__(self):
        """Create the underlying ChromaDB client.

        A non-empty ``CHROMA_HTTP_HOST`` selects the HTTP client; otherwise a
        persistent client rooted at ``CHROMA_DATA_PATH`` is used. Both are
        configured with ``allow_reset=True`` (presumably so that ``reset()``
        is permitted by the backend) and with anonymized telemetry disabled.
        """
        settings = Settings(allow_reset=True, anonymized_telemetry=False)
        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=settings,
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=settings,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    @staticmethod
    def _as_get_result(result) -> GetResult:
        """Convert the mapping returned by ``collection.get`` into a ``GetResult``."""
        # Each field is wrapped in an outer list, exactly as the callers of
        # this class have always received it (presumably to mirror the nested
        # per-query layout used by SearchResult -- confirm against GetResult).
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def _batches(self, items: list[VectorItem]):
        """Split ``items`` into column-wise batches via ``create_batches``.

        Each batch is an ``(ids, embeddings, metadatas, documents)`` tuple,
        which matches the positional parameter order of ``Collection.add``
        and ``Collection.upsert``.
        """
        return create_batches(
            api=self.client,
            ids=[item["id"] for item in items],
            embeddings=[item["vector"] for item in items],
            metadatas=[item["metadata"] for item in items],
            documents=[item["text"] for item in items],
        )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named ``collection_name`` exists."""
        # Depending on the installed chromadb version, list_collections()
        # yields either Collection objects (with a ``name`` attribute) or the
        # plain collection names; accept both.
        return any(
            getattr(entry, "name", entry) == collection_name
            for entry in self.client.list_collections()
        )

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name``.

        Returns whatever the underlying client's ``delete_collection`` returns.
        """
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Return up to ``limit`` nearest neighbours for each query vector.

        Args:
            collection_name: Name of the collection to search.
            vectors: Query embeddings, one inner list per query.
            limit: Maximum number of results requested per query vector.

        Returns:
            A ``SearchResult`` built from the ``ids``, ``distances``,
            ``documents`` and ``metadatas`` of the backend response, or
            ``None`` if the lookup fails for any reason. Failures are logged
            rather than raised.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None
            result = collection.query(
                query_embeddings=vectors,
                n_results=limit,
            )
            return SearchResult(
                ids=result["ids"],
                distances=result["distances"],
                documents=result["documents"],
                metadatas=result["metadatas"],
            )
        except Exception:
            # Best-effort by design: callers receive None, but the failure is
            # recorded instead of being silently discarded.
            log.exception("Chroma search failed for collection %r", collection_name)
            return None

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch the items of a collection that match a metadata filter.

        Args:
            collection_name: Name of the collection to read from.
            filter: Passed to the backend as the ``where`` clause.
            limit: Optional maximum number of items to return.

        Returns:
            A ``GetResult`` with the matching items, or ``None`` if the
            lookup fails for any reason. Failures are logged rather than
            raised.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None
            result = collection.get(
                where=filter,
                limit=limit,
            )
            return self._as_get_result(result)
        except Exception:
            log.exception("Chroma query failed for collection %r", collection_name)
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch every item stored in the collection.

        Unlike ``search`` and ``query``, errors raised by the backend (for
        example while fetching the collection) are NOT caught here and
        propagate to the caller.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None
        return self._as_get_result(collection.get())

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to the collection, creating the collection if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys. Items are sent in batches produced by
        ``create_batches``.
        """
        collection = self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )
        for batch in self._batches(items):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items`` in the collection, creating it if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys. Items are sent in batches, the same way
        ``insert`` does, so that a large payload is not submitted as one
        oversized request.
        """
        collection = self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )
        for batch in self._batches(items):
            collection.upsert(*batch)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete items from the collection by ids or by metadata filter.

        ``ids`` takes precedence: ``filter`` is only used when ``ids`` is
        empty or ``None``. If neither is provided, nothing is deleted.
        Errors raised by the backend propagate to the caller.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return
        if ids:
            collection.delete(ids=ids)
        elif filter:
            collection.delete(where=filter)

    def reset(self):
        """Reset the database, deleting all collections and their entries.

        Returns whatever the underlying client's ``reset`` returns.
        """
        return self.client.reset()
|