Spaces:
Build error
Build error
import logging
import ssl
from typing import Optional

from elasticsearch import Elasticsearch, BadRequestError
from elasticsearch.helpers import bulk, scan

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    ELASTICSEARCH_INDEX_PREFIX,
    SSL_ASSERT_FINGERPRINT,
)
log = logging.getLogger(__name__)


class ElasticsearchClient:
    """Vector store client backed by Elasticsearch.

    Important:
        To reduce the number of indices, and because the embedding vector
        length of an index is fixed, no index is created per file/collection.
        The collection name is stored as a keyword field on every document
        instead, and documents are separated into different indices based
        only on their embedding length (``<prefix>_d<dimension>``).
    """

    def __init__(self):
        """Create the Elasticsearch client from the ``open_webui.config`` values."""
        self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
        # Basic auth is only passed when both a username and a password are set.
        basic_auth = (
            (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
            if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
            else None
        )
        self.client = Elasticsearch(
            hosts=[ELASTICSEARCH_URL],
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            basic_auth=basic_auth,
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

    # ------------------------------------------------------------------
    # Naming and query-building helpers
    # ------------------------------------------------------------------

    def _get_index_name(self, dimension: int) -> str:
        """Return the name of the index holding vectors of length *dimension*."""
        return f"{self.index_prefix}_d{dimension}"

    def _index_pattern(self) -> str:
        """Return the wildcard pattern matching every index owned by this client."""
        return f"{self.index_prefix}*"

    @staticmethod
    def _collection_term(collection_name: str) -> dict:
        """Return the term filter restricting a query to one collection."""
        return {"term": {"collection": collection_name}}

    @staticmethod
    def _document(collection_name: str, item) -> dict:
        """Return the stored document body for *item* inside *collection_name*."""
        return {
            "collection": collection_name,
            "vector": item["vector"],
            "text": item["text"],
            "metadata": item["metadata"],
        }

    # ------------------------------------------------------------------
    # Result conversion
    # ------------------------------------------------------------------

    @staticmethod
    def _hits_to_get_result(hits) -> Optional[GetResult]:
        """Convert a list of raw hits to a ``GetResult``; ``None`` when empty."""
        if not hits:
            return None
        ids = [hit["_id"] for hit in hits]
        documents = [hit["_source"].get("text") for hit in hits]
        metadatas = [hit["_source"].get("metadata") for hit in hits]
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def _scan_result_to_get_result(self, result) -> Optional[GetResult]:
        """Convert the hits collected from ``helpers.scan`` to a ``GetResult``."""
        return self._hits_to_get_result(result)

    def _result_to_get_result(self, result) -> Optional[GetResult]:
        """Convert a search response (``result["hits"]["hits"]``) to a ``GetResult``."""
        return self._hits_to_get_result(result["hits"]["hits"])

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert a search response to a ``SearchResult``.

        The ``_score`` of each hit is reported as its distance.
        """
        hits = result["hits"]["hits"]
        return SearchResult(
            ids=[[hit["_id"] for hit in hits]],
            distances=[[hit["_score"] for hit in hits]],
            documents=[[hit["_source"].get("text") for hit in hits]],
            metadatas=[[hit["_source"].get("metadata") for hit in hits]],
        )

    # ------------------------------------------------------------------
    # Index management
    # ------------------------------------------------------------------

    def _create_index(self, dimension: int):
        """Create the index for vectors of length *dimension*."""
        body = {
            "mappings": {
                # Any string field not listed below (e.g. metadata values)
                # is mapped as an exact-match keyword.
                "dynamic_templates": [
                    {
                        "strings": {
                            "match_mapping_type": "string",
                            "mapping": {"type": "keyword"},
                        }
                    }
                ],
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                },
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)

    def _has_index(self, dimension: int):
        """Return the (truthy/falsy) result of the index-exists check."""
        return self.client.indices.exists(
            index=self._get_index_name(dimension=dimension)
        )

    def get_or_create_index(self, dimension: int):
        """Create the index for *dimension* unless it already exists."""
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

    def _ensure_indices(self, items: list[VectorItem]) -> None:
        """Make sure an index exists for every vector length present in *items*."""
        for dimension in {len(item["vector"]) for item in items}:
            self.get_or_create_index(dimension=dimension)

    def _create_batches(self, items: list[VectorItem], batch_size=100):
        """Yield consecutive slices of *items* holding at most *batch_size* entries."""
        for start in range(0, len(items), batch_size):
            # Slicing clamps at the end of the list, so no explicit bound is needed.
            yield items[start : start + batch_size]

    # ------------------------------------------------------------------
    # Collection-level operations
    # ------------------------------------------------------------------

    def has_collection(self, collection_name) -> bool:
        """Return ``True`` when at least one document belongs to *collection_name*.

        Any failure of the count request is logged and reported as ``False``.
        """
        query_body = {
            "query": {"bool": {"filter": [self._collection_term(collection_name)]}}
        }
        try:
            result = self.client.count(index=self._index_pattern(), body=query_body)
            return result.body["count"] > 0
        except Exception:
            log.exception(
                "Failed to check Elasticsearch collection %r", collection_name
            )
            return False

    def delete_collection(self, collection_name: str):
        """Delete every document of *collection_name* from all managed indices."""
        query = {"query": self._collection_term(collection_name)}
        self.client.delete_by_query(index=self._index_pattern(), body=query)

    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        """Return the *limit* documents of the collection most similar to a vector.

        Only the first entry of *vectors* is used as the query vector. The
        index to search is selected from the length of that vector. Returns
        ``None`` when *vectors* is empty.
        """
        if not vectors:
            return None
        query_vector = vectors[0]
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {"filter": [self._collection_term(collection_name)]}
                    },
                    "script": {
                        # The 1.0 offset keeps the score non-negative, which
                        # script_score requires.
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {"vector": query_vector},
                    },
                }
            },
        }
        result = self.client.search(
            index=self._get_index_name(len(query_vector)), body=query
        )
        return self._result_to_search_result(result)

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Return documents of the collection whose metadata matches *filter*.

        Each ``field: value`` pair of *filter* is matched against
        ``metadata.<field>``, the same convention ``delete`` uses. At most
        *limit* documents are returned (10 when *limit* is falsy). Returns
        ``None`` when the collection is empty or missing, when nothing
        matches, or when the search request fails (the failure is logged).

        NOTE: the original author marked this method as only partially tested.
        """
        if not self.has_collection(collection_name):
            return None

        filters = [self._collection_term(collection_name)]
        filters.extend(
            {"term": {f"metadata.{field}": value}}
            for field, value in (filter or {}).items()
        )
        query_body = {
            "query": {"bool": {"filter": filters}},
            "_source": ["text", "metadata"],
        }
        size = limit if limit else 10
        try:
            result = self.client.search(
                index=self._index_pattern(),
                body=query_body,
                size=size,
            )
            return self._result_to_get_result(result)
        except Exception:
            log.exception(
                "Failed to query Elasticsearch collection %r", collection_name
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return all documents of *collection_name*, or ``None`` when there are none."""
        query = {
            "query": {"bool": {"filter": [self._collection_term(collection_name)]}},
            "_source": ["text", "metadata"],
        }
        # The hits are materialised because the conversion walks them more than once.
        results = list(scan(self.client, index=self._index_pattern(), query=query))
        return self._scan_result_to_get_result(results)

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Index *items* into *collection_name* in batches via the bulk helper.

        Every item is routed to the index matching the length of its own
        vector; missing indices are created first. An empty *items* list is
        a no-op.
        """
        if not items:
            return
        self._ensure_indices(items)
        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": self._get_index_name(dimension=len(item["vector"])),
                    "_id": item["id"],
                    "_source": self._document(collection_name, item),
                }
                for item in batch
            ]
            bulk(self.client, actions)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update *items* using bulk ``update`` actions with ``doc_as_upsert``.

        Every item is routed to the index matching the length of its own
        vector; missing indices are created first. An empty *items* list is
        a no-op.
        """
        if not items:
            return
        self._ensure_indices(items)
        for batch in self._create_batches(items):
            actions = [
                {
                    "_op_type": "update",
                    "_index": self._get_index_name(dimension=len(item["vector"])),
                    "_id": item["id"],
                    "doc": self._document(collection_name, item),
                    "doc_as_upsert": True,
                }
                for item in batch
            ]
            bulk(self.client, actions)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete documents of *collection_name* by document id or metadata filter.

        The selection logic follows the ChromaDB client: when *ids* is given
        it wins and *filter* is ignored; otherwise each ``field: value`` pair
        of *filter* is matched against ``metadata.<field>``. When neither is
        given, every document of the collection is deleted.
        """
        filters = [self._collection_term(collection_name)]
        if ids:
            filters.append({"terms": {"_id": ids}})
        elif filter:
            filters.extend(
                {"term": {f"metadata.{field}": value}}
                for field, value in filter.items()
            )
        query = {"query": {"bool": {"filter": filters}}}
        self.client.delete_by_query(index=self._index_pattern(), body=query)

    def reset(self):
        """Delete every index whose name starts with the configured prefix."""
        indices = self.client.indices.get(index=self._index_pattern())
        for index in indices:
            self.client.indices.delete(index=index)