# src/vectorstores/chroma_vectorstore.py
import logging
from typing import Any, Callable, Dict, List, Optional

import chromadb
from chromadb.config import Settings

from .base_vectorstore import BaseVectorStore


class ChromaVectorStore(BaseVectorStore):
    """Chroma-backed vector store with chunk-aware storage, search, and deletion."""

    def __init__(
self,
embedding_function: Callable[[List[str]], List[List[float]]],
persist_directory: str = './chroma_db',
collection_name: str = "documents",
client_settings: Optional[Dict[str, Any]] = None
):
"""
        Initialize the Chroma vector store.

        Args:
embedding_function (Callable): Function to generate embeddings
persist_directory (str): Directory to persist the vector store
collection_name (str): Name of the collection to use
client_settings (Optional[Dict[str, Any]]): Additional settings for ChromaDB client
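
        Example (illustrative; `my_embed_fn` is a placeholder for your own
        embedding function, not something this module provides):
            store = ChromaVectorStore(
                embedding_function=my_embed_fn,
                persist_directory='./chroma_db',
                collection_name='documents'
            )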
"""
        try:
            # Recent chromadb versions take the storage path directly on
            # PersistentClient and override a persist_directory set only
            # on Settings, so pass the path explicitly.
            settings = Settings(**(client_settings or {}))
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=settings
            )
self.collection = self.client.get_or_create_collection(
name=collection_name,
# Using cosine similarity by default
metadata={"hnsw:space": "cosine"}
)
self.embedding_function = embedding_function
except Exception as e:
logging.error(f"Error initializing ChromaDB: {str(e)}")
            raise

    def add_documents(
self,
documents: List[str],
embeddings: Optional[List[List[float]]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None
) -> None:
"""
        Add documents to the vector store.

        Args:
documents (List[str]): List of document texts
embeddings (Optional[List[List[float]]]): Pre-computed embeddings
metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document
ids (Optional[List[str]]): Custom IDs for the documents
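
        Example (illustrative; assumes an initialized `store`):
            store.add_documents(
                documents=["chunk one", "chunk two"],
                metadatas=[{"document_id": "doc-1", "chunk_index": 0},
                           {"document_id": "doc-1", "chunk_index": 1}],
                ids=["doc-1-chunk-0", "doc-1-chunk-1"]
            )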
"""
try:
if not documents:
logging.warning("No documents provided to add_documents")
return
if not embeddings:
embeddings = self.embedding_function(documents)
if len(documents) != len(embeddings):
raise ValueError(
"Number of documents and embeddings must match")
# Use provided IDs or generate them
doc_ids = ids if ids is not None else [
f"doc_{i}" for i in range(len(documents))]
# Prepare add parameters
add_params = {
"documents": documents,
"embeddings": embeddings,
"ids": doc_ids
}
# Only include metadatas if provided
if metadatas is not None:
if len(metadatas) != len(documents):
raise ValueError(
"Number of documents and metadatas must match")
add_params["metadatas"] = metadatas
self.collection.add(**add_params)
except Exception as e:
logging.error(f"Error adding documents to ChromaDB: {str(e)}")
            raise

    def similarity_search(
self,
query_embedding: List[float],
top_k: int = 3,
**kwargs
) -> List[Dict[str, Any]]:
"""Perform similarity search with improved chunk handling"""
try:
# Get more initial results to account for sequential chunks
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=max(top_k * 2, 10),
include=['documents', 'metadatas', 'distances']
)
            if (not results or not results.get('documents')
                    or not results['documents'][0]):
                return []
formatted_results = []
documents = results['documents'][0]
metadatas = results['metadatas'][0]
distances = results['distances'][0]
# Group chunks by document_id
doc_chunks = {}
            for doc, meta, dist in zip(documents, metadatas, distances):
                meta = meta or {}  # Chroma may return None for a chunk's metadata
                doc_id = meta.get('document_id')
                chunk_index = meta.get('chunk_index', 0)
                if doc_id not in doc_chunks:
                    doc_chunks[doc_id] = []
                doc_chunks[doc_id].append({
                    'text': doc,
                    'metadata': meta,
                    # Cosine distance is 1 - similarity, so invert it
                    'score': 1.0 - dist,
                    'chunk_index': chunk_index
                })
# Process each document's chunks
for doc_id, chunks in doc_chunks.items():
# Sort chunks by index
chunks.sort(key=lambda x: x['chunk_index'])
# Find sequences of chunks with good scores
good_sequences = []
current_sequence = []
                for chunk in chunks:
                    if chunk['score'] > 0.3:  # Adjust threshold as needed
                        is_consecutive = (
                            not current_sequence
                            or chunk['chunk_index']
                            == current_sequence[-1]['chunk_index'] + 1
                        )
                        if is_consecutive:
                            current_sequence.append(chunk)
                        else:
                            good_sequences.append(current_sequence)
                            current_sequence = [chunk]
                    else:
                        if current_sequence:
                            good_sequences.append(current_sequence)
                        current_sequence = []
if current_sequence:
good_sequences.append(current_sequence)
# Add best sequences to results
for sequence in good_sequences:
avg_score = sum(c['score']
for c in sequence) / len(sequence)
combined_text = ' '.join(c['text'] for c in sequence)
formatted_results.append({
'text': combined_text,
'metadata': sequence[0]['metadata'],
'score': avg_score
})
# Sort by score and return top_k
formatted_results.sort(key=lambda x: x['score'], reverse=True)
return formatted_results[:top_k]
except Exception as e:
logging.error(f"Error in similarity search: {str(e)}")
            raise

    def get_all_documents(
self,
include_embeddings: bool = False
) -> List[Dict[str, Any]]:
"""
        Retrieve all documents from the vector store.

        Args:
            include_embeddings (bool): Whether to include embeddings in the response

        Returns:
List[Dict[str, Any]]: List of documents with their IDs and optionally embeddings
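
        Example (illustrative; assumes an initialized `store`):
            docs = store.get_all_documents(include_embeddings=False)
            for doc in docs:
                print(doc['id'], doc['text'][:50])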
"""
try:
include = ["documents", "metadatas"]
if include_embeddings:
include.append("embeddings")
results = self.collection.get(
include=include
)
if not results or 'documents' not in results:
return []
            documents = []
            for i in range(len(results['documents'])):
                doc = {
                    # get() always returns ids, so prefer Chroma's own
                    'id': results['ids'][i] if results.get('ids') else str(i),
                    'text': results['documents'][i],
                }
if include_embeddings and 'embeddings' in results:
doc['embedding'] = results['embeddings'][i]
if 'metadatas' in results and results['metadatas'][i]:
doc['metadata'] = results['metadatas'][i]
# Use document_id from metadata if available
if 'document_id' in results['metadatas'][i]:
doc['id'] = results['metadatas'][i]['document_id']
documents.append(doc)
return documents
except Exception as e:
logging.error(
f"Error retrieving documents from ChromaDB: {str(e)}")
            raise

    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
"""
        Retrieve all chunks for a specific document.

        Args:
            document_id (str): ID of the document to retrieve chunks for

        Returns:
List[Dict[str, Any]]: List of document chunks with their metadata
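
        Example (illustrative; assumes an initialized `store`):
            chunks = store.get_document_chunks("doc-1")
            full_text = ' '.join(c['text'] for c in chunks)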
"""
try:
results = self.collection.get(
where={"document_id": document_id},
include=["documents", "metadatas"]
)
if not results or 'documents' not in results:
return []
chunks = []
for i in range(len(results['documents'])):
chunk = {
'text': results['documents'][i],
'metadata': results['metadatas'][i] if results.get('metadatas') else None
}
chunks.append(chunk)
            # Sort by chunk_index if available (metadata may be None)
            chunks.sort(
                key=lambda x: (x.get('metadata') or {}).get('chunk_index', 0))
return chunks
except Exception as e:
logging.error(f"Error retrieving document chunks: {str(e)}")
            raise

    def delete_document(self, document_id: str) -> None:
"""
        Delete all chunks associated with a document_id.

        Args:
document_id (str): ID of the document to delete
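
        Example (illustrative; assumes an initialized `store`):
            store.delete_document("doc-1")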
"""
        try:
            # Look up every chunk stored under this document_id
            results = self.collection.get(
                where={"document_id": document_id},
                include=["metadatas"]
            )
            if not results or not results.get('ids'):
                logging.warning(f"No document found with ID: {document_id}")
                return
            # Delete by the IDs Chroma actually returned rather than
            # reconstructing them from an assumed naming convention
            self.collection.delete(ids=results['ids'])
except Exception as e:
logging.error(
f"Error deleting document {document_id} from ChromaDB: {str(e)}")
raise
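

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): `toy_embed` is a stand-in
    # for a real embedding model and is not part of this module.
    def toy_embed(texts: List[str]) -> List[List[float]]:
        # Deterministic toy vectors; swap in a real embedding model
        return [[float(len(t)), float(sum(map(ord, t)) % 97), 1.0] for t in texts]

    store = ChromaVectorStore(embedding_function=toy_embed)
    store.add_documents(
        documents=["Alpha chunk text.", "Beta chunk text."],
        metadatas=[
            {"document_id": "doc-1", "chunk_index": 0},
            {"document_id": "doc-1", "chunk_index": 1},
        ],
        ids=["doc-1-chunk-0", "doc-1-chunk-1"],
    )
    query_embedding = toy_embed(["alpha"])[0]
    for hit in store.similarity_search(query_embedding, top_k=2):
        print(f"{hit['score']:.3f} {hit['text']}")
    store.delete_document("doc-1")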