# chatbot-backend / src / vectorstores / chroma_vectorstore.py
# TalatMasood's picture
# Update knowledge upload api and linked chromadb to mongodb
# d161383
# raw
# history blame
# 7.83 kB
# src/vectorstores/chroma_vectorstore.py
import logging
import uuid
from typing import List, Callable, Any, Dict, Optional

import chromadb
from chromadb.config import Settings

from .base_vectorstore import BaseVectorStore
class ChromaVectorStore(BaseVectorStore):
    """Vector store backed by a persistent ChromaDB collection.

    Documents are stored as text plus embedding (and optional metadata) in a
    single collection configured for cosine distance. Chunks that belong to
    the same logical document are linked through the ``document_id`` metadata
    key, which ``get_document_chunks`` and ``delete_document`` filter on.
    """

    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = './chroma_db',
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Chroma Vector Store

        Args:
            embedding_function (Callable): Function to generate embeddings
            persist_directory (str): Directory to persist the vector store
            collection_name (str): Name of the collection to use
            client_settings (Optional[Dict[str, Any]]): Additional settings
                for ChromaDB client. A ``persist_directory`` entry here is
                overridden by the ``persist_directory`` argument.

        Raises:
            Exception: Any error raised while creating the client or the
                collection is logged and re-raised unchanged.
        """
        try:
            # Merge into one dict so a 'persist_directory' key inside
            # client_settings cannot trigger a duplicate-keyword TypeError;
            # the explicit argument is authoritative.
            settings_kwargs = dict(client_settings or {})
            settings_kwargs["persist_directory"] = persist_directory
            settings = Settings(**settings_kwargs)

            # PersistentClient takes its storage location from `path`
            # (default "./chroma") rather than from the settings object, so
            # the directory must be passed explicitly or it is ignored.
            # NOTE: deployments created before this fix may have their data
            # under "./chroma"; point persist_directory there to keep it.
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=settings
            )
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}  # Using cosine similarity by default
            )
            self.embedding_function = embedding_function
        except Exception as e:
            logging.error(f"Error initializing ChromaDB: {str(e)}")
            raise

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Add documents to the vector store

        Args:
            documents (List[str]): List of document texts
            embeddings (Optional[List[List[float]]]): Pre-computed embeddings;
                computed with ``self.embedding_function`` when omitted or empty
            metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document
            ids (Optional[List[str]]): Custom IDs for the documents; unique
                IDs are generated when omitted

        Raises:
            ValueError: If the number of embeddings, metadatas or ids does
                not match the number of documents.
            Exception: Any error from the collection is logged and re-raised.
        """
        try:
            if not documents:
                logging.warning("No documents provided to add_documents")
                return

            # Explicit None/len checks instead of truthiness so array-like
            # embeddings do not raise on an ambiguous boolean test.
            if embeddings is None or len(embeddings) == 0:
                embeddings = self.embedding_function(documents)
            if len(documents) != len(embeddings):
                raise ValueError("Number of documents and embeddings must match")

            if ids is None:
                # Positional IDs (doc_0, doc_1, ...) would repeat on every
                # call and collide with earlier batches, so generate IDs
                # that are unique across calls.
                doc_ids = [f"doc_{uuid.uuid4().hex}" for _ in documents]
            else:
                if len(ids) != len(documents):
                    raise ValueError("Number of documents and ids must match")
                doc_ids = ids

            add_params = {
                "documents": documents,
                "embeddings": embeddings,
                "ids": doc_ids
            }
            # Only include metadatas if provided
            if metadatas is not None:
                if len(metadatas) != len(documents):
                    raise ValueError("Number of documents and metadatas must match")
                add_params["metadatas"] = metadatas

            self.collection.add(**add_params)
        except Exception as e:
            logging.error(f"Error adding documents to ChromaDB: {str(e)}")
            raise

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3,
        **kwargs
    ) -> List[str]:
        """
        Perform similarity search

        Args:
            query_embedding (List[float]): Embedding of the query
            top_k (int): Number of top similar documents to retrieve
            **kwargs: Additional search parameters forwarded to the
                collection's ``query`` call

        Returns:
            List[str]: List of most similar documents; empty when the query
            yields no documents

        Raises:
            Exception: Any error from the collection is logged and re-raised.
        """
        try:
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k,
                **kwargs
            )
            # The result holds one list of documents per query embedding.
            # Guard against a missing, None or empty value before taking
            # the list for our single query.
            per_query_documents = results.get('documents') if results else None
            if not per_query_documents:
                return []
            return per_query_documents[0] or []
        except Exception as e:
            logging.error(f"Error performing similarity search in ChromaDB: {str(e)}")
            raise

    def get_all_documents(
        self,
        include_embeddings: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Retrieve all documents from the vector store

        Args:
            include_embeddings (bool): Also return each document's embedding

        Returns:
            List[Dict[str, Any]]: One dict per stored entry with keys ``id``
            and ``text``, plus ``embedding`` (when requested) and ``metadata``
            (when non-empty). ``id`` is the metadata's ``document_id`` when
            present, otherwise the entry's position in the result as a string.

        Raises:
            Exception: Any error from the collection is logged and re-raised.
        """
        try:
            include = ["documents", "metadatas"]
            if include_embeddings:
                include.append("embeddings")
            results = self.collection.get(include=include)

            texts = results.get('documents') if results else None
            if not texts:
                return []

            metadatas = results.get('metadatas')
            # Compare against None rather than testing truthiness: the
            # embeddings value may be array-like.
            embeddings = results.get('embeddings') if include_embeddings else None

            documents = []
            for position, text in enumerate(texts):
                doc = {
                    'id': str(position),  # Positional fallback ID
                    'text': text,
                }
                if embeddings is not None:
                    doc['embedding'] = embeddings[position]
                metadata = metadatas[position] if metadatas is not None else None
                if metadata:
                    doc['metadata'] = metadata
                    # Use document_id from metadata if available
                    if 'document_id' in metadata:
                        doc['id'] = metadata['document_id']
                documents.append(doc)
            return documents
        except Exception as e:
            logging.error(f"Error retrieving documents from ChromaDB: {str(e)}")
            raise

    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all chunks for a specific document

        Args:
            document_id (str): Value of the ``document_id`` metadata key to
                filter on

        Returns:
            List[Dict[str, Any]]: Dicts with keys ``text`` and ``metadata``,
            sorted by the metadata's ``chunk_index`` (treated as 0 when the
            metadata or the key is missing)

        Raises:
            Exception: Any error from the collection is logged and re-raised.
        """
        try:
            results = self.collection.get(
                where={"document_id": document_id},
                include=["documents", "metadatas"]
            )
            texts = results.get('documents') if results else None
            if not texts:
                return []

            metadatas = results.get('metadatas')
            chunks = [
                {
                    'text': text,
                    'metadata': metadatas[position] if metadatas else None
                }
                for position, text in enumerate(texts)
            ]
            # 'metadata' is always present but may be None, so fall back to
            # an empty dict before looking up chunk_index.
            chunks.sort(
                key=lambda chunk: (chunk['metadata'] or {}).get('chunk_index', 0)
            )
            return chunks
        except Exception as e:
            logging.error(f"Error retrieving document chunks: {str(e)}")
            raise

    def delete_document(self, document_id: str) -> None:
        """
        Delete all chunks associated with a document_id

        Args:
            document_id (str): Value of the ``document_id`` metadata key
                whose chunks should be removed

        Raises:
            Exception: Any error from the collection is logged and re-raised.
        """
        try:
            # Get all chunks with the given document_id
            results = self.collection.get(
                where={"document_id": document_id},
                include=["metadatas"]
            )
            # Delete by the IDs the store actually returned instead of
            # reconstructing them from an assumed naming pattern, so chunks
            # are removed regardless of how their IDs were assigned.
            chunk_ids = results.get('ids') if results else None
            if not chunk_ids:
                logging.warning(f"No document found with ID: {document_id}")
                return
            self.collection.delete(ids=chunk_ids)
        except Exception as e:
            logging.error(f"Error deleting document {document_id} from ChromaDB: {str(e)}")
            raise