|
from elasticsearch import Elasticsearch, BadRequestError |
|
from typing import Optional |
|
import ssl |
|
from elasticsearch.helpers import bulk, scan |
|
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult |
|
from open_webui.config import ( |
|
ELASTICSEARCH_URL, |
|
ELASTICSEARCH_CA_CERTS, |
|
ELASTICSEARCH_API_KEY, |
|
ELASTICSEARCH_USERNAME, |
|
ELASTICSEARCH_PASSWORD, |
|
ELASTICSEARCH_CLOUD_ID, |
|
ELASTICSEARCH_INDEX_PREFIX, |
|
SSL_ASSERT_FINGERPRINT, |
|
) |
|
|
|
|
|
class ElasticsearchClient: |
|
""" |
|
Important: |
|
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating |
|
an index for each file but store it as a text field, while seperating to different index |
|
baesd on the embedding length. |
|
""" |
|
|
|
def __init__(self): |
|
self.index_prefix = ELASTICSEARCH_INDEX_PREFIX |
|
self.client = Elasticsearch( |
|
hosts=[ELASTICSEARCH_URL], |
|
ca_certs=ELASTICSEARCH_CA_CERTS, |
|
api_key=ELASTICSEARCH_API_KEY, |
|
cloud_id=ELASTICSEARCH_CLOUD_ID, |
|
basic_auth=( |
|
(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD) |
|
if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD |
|
else None |
|
), |
|
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT, |
|
) |
|
|
|
|
|
def _get_index_name(self, dimension: int) -> str: |
|
return f"{self.index_prefix}_d{str(dimension)}" |
|
|
|
|
|
def _scan_result_to_get_result(self, result) -> GetResult: |
|
if not result: |
|
return None |
|
ids = [] |
|
documents = [] |
|
metadatas = [] |
|
|
|
for hit in result: |
|
ids.append(hit["_id"]) |
|
documents.append(hit["_source"].get("text")) |
|
metadatas.append(hit["_source"].get("metadata")) |
|
|
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) |
|
|
|
|
|
def _result_to_get_result(self, result) -> GetResult: |
|
if not result["hits"]["hits"]: |
|
return None |
|
ids = [] |
|
documents = [] |
|
metadatas = [] |
|
|
|
for hit in result["hits"]["hits"]: |
|
ids.append(hit["_id"]) |
|
documents.append(hit["_source"].get("text")) |
|
metadatas.append(hit["_source"].get("metadata")) |
|
|
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) |
|
|
|
|
|
def _result_to_search_result(self, result) -> SearchResult: |
|
ids = [] |
|
distances = [] |
|
documents = [] |
|
metadatas = [] |
|
|
|
for hit in result["hits"]["hits"]: |
|
ids.append(hit["_id"]) |
|
distances.append(hit["_score"]) |
|
documents.append(hit["_source"].get("text")) |
|
metadatas.append(hit["_source"].get("metadata")) |
|
|
|
return SearchResult( |
|
ids=[ids], |
|
distances=[distances], |
|
documents=[documents], |
|
metadatas=[metadatas], |
|
) |
|
|
|
|
|
def _create_index(self, dimension: int): |
|
body = { |
|
"mappings": { |
|
"dynamic_templates": [ |
|
{ |
|
"strings": { |
|
"match_mapping_type": "string", |
|
"mapping": {"type": "keyword"}, |
|
} |
|
} |
|
], |
|
"properties": { |
|
"collection": {"type": "keyword"}, |
|
"id": {"type": "keyword"}, |
|
"vector": { |
|
"type": "dense_vector", |
|
"dims": dimension, |
|
"index": True, |
|
"similarity": "cosine", |
|
}, |
|
"text": {"type": "text"}, |
|
"metadata": {"type": "object"}, |
|
}, |
|
} |
|
} |
|
self.client.indices.create(index=self._get_index_name(dimension), body=body) |
|
|
|
|
|
|
|
def _create_batches(self, items: list[VectorItem], batch_size=100): |
|
for i in range(0, len(items), batch_size): |
|
yield items[i : min(i + batch_size, len(items))] |
|
|
|
|
|
def has_collection(self, collection_name) -> bool: |
|
query_body = {"query": {"bool": {"filter": []}}} |
|
query_body["query"]["bool"]["filter"].append( |
|
{"term": {"collection": collection_name}} |
|
) |
|
|
|
try: |
|
result = self.client.count(index=f"{self.index_prefix}*", body=query_body) |
|
|
|
return result.body["count"] > 0 |
|
except Exception as e: |
|
return None |
|
|
|
def delete_collection(self, collection_name: str): |
|
query = {"query": {"term": {"collection": collection_name}}} |
|
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query) |
|
|
|
|
|
def search( |
|
self, collection_name: str, vectors: list[list[float]], limit: int |
|
) -> Optional[SearchResult]: |
|
query = { |
|
"size": limit, |
|
"_source": ["text", "metadata"], |
|
"query": { |
|
"script_score": { |
|
"query": { |
|
"bool": {"filter": [{"term": {"collection": collection_name}}]} |
|
}, |
|
"script": { |
|
"source": "cosineSimilarity(params.vector, 'vector') + 1.0", |
|
"params": { |
|
"vector": vectors[0] |
|
}, |
|
}, |
|
} |
|
}, |
|
} |
|
|
|
result = self.client.search( |
|
index=self._get_index_name(len(vectors[0])), body=query |
|
) |
|
|
|
return self._result_to_search_result(result) |
|
|
|
|
|
def query( |
|
self, collection_name: str, filter: dict, limit: Optional[int] = None |
|
) -> Optional[GetResult]: |
|
if not self.has_collection(collection_name): |
|
return None |
|
|
|
query_body = { |
|
"query": {"bool": {"filter": []}}, |
|
"_source": ["text", "metadata"], |
|
} |
|
|
|
for field, value in filter.items(): |
|
query_body["query"]["bool"]["filter"].append({"term": {field: value}}) |
|
query_body["query"]["bool"]["filter"].append( |
|
{"term": {"collection": collection_name}} |
|
) |
|
size = limit if limit else 10 |
|
|
|
try: |
|
result = self.client.search( |
|
index=f"{self.index_prefix}*", |
|
body=query_body, |
|
size=size, |
|
) |
|
|
|
return self._result_to_get_result(result) |
|
|
|
except Exception as e: |
|
return None |
|
|
|
|
|
def _has_index(self, dimension: int): |
|
return self.client.indices.exists( |
|
index=self._get_index_name(dimension=dimension) |
|
) |
|
|
|
def get_or_create_index(self, dimension: int): |
|
if not self._has_index(dimension=dimension): |
|
self._create_index(dimension=dimension) |
|
|
|
|
|
def get(self, collection_name: str) -> Optional[GetResult]: |
|
|
|
query = { |
|
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}, |
|
"_source": ["text", "metadata"], |
|
} |
|
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query)) |
|
|
|
return self._scan_result_to_get_result(results) |
|
|
|
|
|
def insert(self, collection_name: str, items: list[VectorItem]): |
|
if not self._has_index(dimension=len(items[0]["vector"])): |
|
self._create_index(dimension=len(items[0]["vector"])) |
|
|
|
for batch in self._create_batches(items): |
|
actions = [ |
|
{ |
|
"_index": self._get_index_name(dimension=len(items[0]["vector"])), |
|
"_id": item["id"], |
|
"_source": { |
|
"collection": collection_name, |
|
"vector": item["vector"], |
|
"text": item["text"], |
|
"metadata": item["metadata"], |
|
}, |
|
} |
|
for item in batch |
|
] |
|
bulk(self.client, actions) |
|
|
|
|
|
def upsert(self, collection_name: str, items: list[VectorItem]): |
|
if not self._has_index(dimension=len(items[0]["vector"])): |
|
self._create_index(dimension=len(items[0]["vector"])) |
|
for batch in self._create_batches(items): |
|
actions = [ |
|
{ |
|
"_op_type": "update", |
|
"_index": self._get_index_name(dimension=len(item["vector"])), |
|
"_id": item["id"], |
|
"doc": { |
|
"collection": collection_name, |
|
"vector": item["vector"], |
|
"text": item["text"], |
|
"metadata": item["metadata"], |
|
}, |
|
"doc_as_upsert": True, |
|
} |
|
for item in batch |
|
] |
|
bulk(self.client, actions) |
|
|
|
|
|
def delete( |
|
self, |
|
collection_name: str, |
|
ids: Optional[list[str]] = None, |
|
filter: Optional[dict] = None, |
|
): |
|
|
|
query = { |
|
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}} |
|
} |
|
|
|
if ids: |
|
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}}) |
|
elif filter: |
|
for field, value in filter.items(): |
|
query["query"]["bool"]["filter"].append( |
|
{"term": {f"metadata.{field}": value}} |
|
) |
|
|
|
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query) |
|
|
|
def reset(self): |
|
indices = self.client.indices.get(index=f"{self.index_prefix}*") |
|
for index in indices: |
|
self.client.indices.delete(index=index) |
|
|