# src/vectorstores/chroma_vectorstore.py
import logging
from typing import Any, Callable, Dict, List, Optional

import chromadb
from chromadb.config import Settings

from .base_vectorstore import BaseVectorStore


class ChromaVectorStore(BaseVectorStore):
def __init__(
self,
embedding_function: Callable[[List[str]], List[List[float]]],
persist_directory: str = './chroma_db',
collection_name: str = "documents",
client_settings: Optional[Dict[str, Any]] = None
):
"""
Initialize Chroma Vector Store
Args:
embedding_function (Callable): Function to generate embeddings
persist_directory (str): Directory to persist the vector store
collection_name (str): Name of the collection to use
client_settings (Optional[Dict[str, Any]]): Additional settings for ChromaDB client
"""
        try:
            # Pass the storage path as `path`; newer chromadb releases take
            # the location from this argument and may override a
            # persist_directory set only in Settings.
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(**(client_settings or {}))
            )
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                # Use cosine distance for the HNSW index by default
                metadata={"hnsw:space": "cosine"}
            )
            self.embedding_function = embedding_function
        except Exception as e:
            logging.error(f"Error initializing ChromaDB: {str(e)}")
            raise
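
    # Illustrative usage (not part of this module): one way a caller might
    # wire in an embedding function. SentenceTransformer is an assumption
    # here, not a dependency of this class:
    #
    #     from sentence_transformers import SentenceTransformer
    #     model = SentenceTransformer("all-MiniLM-L6-v2")
    #     store = ChromaVectorStore(
    #         embedding_function=lambda texts: model.encode(texts).tolist(),
    #         persist_directory="./chroma_db",
    #     )
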
def add_documents(
self,
documents: List[str],
embeddings: Optional[List[List[float]]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None
) -> None:
"""
Add documents to the vector store
Args:
documents (List[str]): List of document texts
embeddings (Optional[List[List[float]]]): Pre-computed embeddings
metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document
ids (Optional[List[str]]): Custom IDs for the documents
"""
        try:
            if not documents:
                logging.warning("No documents provided to add_documents")
                return

            if embeddings is None:
                embeddings = self.embedding_function(documents)

            if len(documents) != len(embeddings):
                raise ValueError(
                    "Number of documents and embeddings must match")

            # Use provided IDs or fall back to positional ones. Note that
            # positional IDs ("doc_0", "doc_1", ...) collide across calls,
            # so callers adding in batches should pass explicit ids.
            doc_ids = ids if ids is not None else [
                f"doc_{i}" for i in range(len(documents))
            ]

            add_params = {
                "documents": documents,
                "embeddings": embeddings,
                "ids": doc_ids,
            }

            # Only include metadatas if provided
            if metadatas is not None:
                if len(metadatas) != len(documents):
                    raise ValueError(
                        "Number of documents and metadatas must match")
                add_params["metadatas"] = metadatas

            self.collection.add(**add_params)
        except Exception as e:
            logging.error(f"Error adding documents to ChromaDB: {str(e)}")
            raise
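
    # Illustrative call; the "document_id" and "chunk_index" metadata keys
    # are the conventions the rest of this class relies on:
    #
    #     store.add_documents(
    #         documents=["chunk one text", "chunk two text"],
    #         metadatas=[
    #             {"document_id": "doc-abc", "chunk_index": 0},
    #             {"document_id": "doc-abc", "chunk_index": 1},
    #         ],
    #         ids=["doc-abc-chunk-0", "doc-abc-chunk-1"],
    #     )
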
def similarity_search(
self,
query_embedding: List[float],
top_k: int = 3,
**kwargs
) -> List[Dict[str, Any]]:
"""
Perform similarity search with improved matching
"""
try:
# Increase n_results to get more potential matches
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=10, # Get more initial results
include=['documents', 'metadatas', 'distances']
)
if not results or 'documents' not in results or not results['documents']:
logging.warning("No results found in similarity search")
return []
formatted_results = []
documents = results['documents'][0] # First query's results
metadatas = results['metadatas'][0] if results.get('metadatas') else [
None] * len(documents)
distances = results['distances'][0] if results.get('distances') else [
None] * len(documents)
# Process all results
for doc, meta, dist in zip(documents, metadatas, distances):
# Convert distance to similarity score (1 is most similar, 0 is least)
similarity_score = 1.0 - \
(dist or 0.0) if dist is not None else None
# More permissive threshold and include all results for filtering
if similarity_score is not None and similarity_score > 0.2: # Lower threshold
formatted_results.append({
'text': doc,
'metadata': meta or {},
'score': similarity_score
})
# Sort by score and get top_k results
formatted_results.sort(key=lambda x: x['score'] or 0, reverse=True)
# Check if results are from same document and get consecutive chunks
if formatted_results:
first_doc_id = formatted_results[0]['metadata'].get(
'document_id')
all_chunks_same_doc = []
# Get all chunks from the same document
for result in formatted_results:
if result['metadata'].get('document_id') == first_doc_id:
all_chunks_same_doc.append(result)
# Sort chunks by their index to maintain document flow
all_chunks_same_doc.sort(
key=lambda x: x['metadata'].get('chunk_index', 0)
)
# Return either all chunks from same document or top_k results
if len(all_chunks_same_doc) > 0:
return all_chunks_same_doc[:top_k]
return formatted_results[:top_k]
except Exception as e:
logging.error(
f"Error performing similarity search in ChromaDB: {str(e)}")
raise
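
    # Illustrative query; the embedder is whatever callable was passed to
    # __init__, and the question text is a placeholder:
    #
    #     query_vec = store.embedding_function(["What is the refund policy?"])[0]
    #     for hit in store.similarity_search(query_vec, top_k=3):
    #         print(hit['score'], hit['text'][:80])
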
def get_all_documents(
self,
include_embeddings: bool = False
) -> List[Dict[str, Any]]:
"""
Retrieve all documents from the vector store
Args:
include_embeddings (bool): Whether to include embeddings in the response
Returns:
List[Dict[str, Any]]: List of documents with their IDs and optionally embeddings
"""
        try:
            include = ["documents", "metadatas"]
            if include_embeddings:
                include.append("embeddings")

            results = self.collection.get(include=include)

            if not results or 'documents' not in results:
                return []

            metadatas = results.get('metadatas') or [
                None] * len(results['documents'])

            documents = []
            for i, text in enumerate(results['documents']):
                doc = {
                    'id': str(i),  # Fallback: positional ID
                    'text': text,
                }
                # `is not None` avoids ambiguous truth tests if the client
                # returns embeddings as arrays
                if include_embeddings and results.get('embeddings') is not None:
                    doc['embedding'] = results['embeddings'][i]
                if metadatas[i]:
                    doc['metadata'] = metadatas[i]
                    # Prefer the document_id stored in metadata when available
                    if 'document_id' in metadatas[i]:
                        doc['id'] = metadatas[i]['document_id']
                documents.append(doc)
            return documents
        except Exception as e:
            logging.error(
                f"Error retrieving documents from ChromaDB: {str(e)}")
            raise
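
    # Illustrative: dump the whole collection, e.g. for debugging or
    # re-indexing:
    #
    #     for doc in store.get_all_documents():
    #         print(doc['id'], doc['text'][:60])
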
def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
"""
Retrieve all chunks for a specific document
Args:
document_id (str): ID of the document to retrieve chunks for
Returns:
List[Dict[str, Any]]: List of document chunks with their metadata
"""
        try:
            results = self.collection.get(
                where={"document_id": document_id},
                include=["documents", "metadatas"]
            )

            if not results or 'documents' not in results:
                return []

            metadatas = results.get('metadatas') or [
                None] * len(results['documents'])
            chunks = [
                {'text': text, 'metadata': metadatas[i]}
                for i, text in enumerate(results['documents'])
            ]

            # Sort by chunk_index if available; metadata may be None
            chunks.sort(
                key=lambda x: (x['metadata'] or {}).get('chunk_index', 0))
            return chunks
        except Exception as e:
            logging.error(f"Error retrieving document chunks: {str(e)}")
            raise
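
    # Illustrative: reassemble a document from its ordered chunks:
    #
    #     chunks = store.get_document_chunks("doc-abc")
    #     full_text = "\n".join(c['text'] for c in chunks)
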
def delete_document(self, document_id: str) -> None:
"""
Delete all chunks associated with a document_id
Args:
document_id (str): ID of the document to delete
"""
        try:
            # Fetch the chunks associated with this document_id; `get`
            # always returns the matching ids alongside the requested fields.
            results = self.collection.get(
                where={"document_id": document_id},
                include=["metadatas"]
            )

            if not results or not results.get('ids'):
                logging.warning(f"No document found with ID: {document_id}")
                return

            # Delete by the IDs ChromaDB actually stored rather than
            # reconstructing them from a naming convention.
            self.collection.delete(ids=results['ids'])
except Exception as e:
logging.error(
f"Error deleting document {document_id} from ChromaDB: {str(e)}")
raise
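

if __name__ == "__main__":
    # Minimal smoke test, runnable without a real embedding model. The
    # hash-based embedder below is a deterministic stand-in for a real
    # embedding function (e.g. a sentence-transformers model); it is
    # useless for semantic quality but exercises the store end to end.
    import hashlib

    def _toy_embed(texts: List[str]) -> List[List[float]]:
        # Map each text to a fixed-size pseudo-embedding derived from
        # its SHA-256 digest.
        return [
            [b / 255.0 for b in hashlib.sha256(t.encode("utf-8")).digest()[:16]]
            for t in texts
        ]

    store = ChromaVectorStore(_toy_embed, persist_directory="./chroma_demo")
    store.delete_document("demo")  # clear leftovers from a previous run
    store.add_documents(
        documents=["alpha chunk", "beta chunk"],
        metadatas=[
            {"document_id": "demo", "chunk_index": 0},
            {"document_id": "demo", "chunk_index": 1},
        ],
        ids=["demo-chunk-0", "demo-chunk-1"],
    )
    for hit in store.similarity_search(_toy_embed(["alpha chunk"])[0], top_k=2):
        print(f"{hit['score']:.3f}  {hit['text']}")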