# src/vectorstores/chroma_vectorstore.py
from pathlib import Path
import chromadb
from typing import List, Callable, Any, Dict, Optional
import logging
import asyncio
from .base_vectorstore import BaseVectorStore
from .chroma_manager import ChromaManager
class ChromaVectorStore(BaseVectorStore):
def __init__(
self,
embedding_function: Callable[[List[str]], List[List[float]]],
persist_directory: str = './chroma_db',
collection_name: str = "documents",
client_settings: Optional[Dict[str, Any]] = None,
client=None # Allow passing an existing client
):
"""
Initialize Chroma Vector Store
Args:
embedding_function (Callable): Function to generate embeddings
persist_directory (str): Directory to persist the vector store
collection_name (str): Name of the collection to use
client_settings (Optional[Dict[str, Any]]): Additional settings for ChromaDB client
client: Optional existing ChromaDB client to use
"""
self.embedding_function = embedding_function
self.persist_directory = persist_directory
self.collection_name = collection_name
self.client = client # Store client for later initialization
# Will be populated during async initialization
self.collection = None
self.initialized = False
async def initialize(self):
"""Asynchronously initialize the vector store with enhanced error handling"""
if self.initialized:
return
try:
# Get client via manager if not provided
if self.client is None:
self.client = await ChromaManager.get_client(self.persist_directory)
# Validate client
if not self.client:
raise ValueError("Failed to obtain ChromaDB client")
# Get or create collection with more robust handling
try:
self.collection = await ChromaManager.get_or_create_collection(
client=self.client,
collection_name=self.collection_name,
                    embedding_dimension=1024  # assumed dimension; must match the embedding model in use
)
except Exception as collection_error:
logging.error(
f"Error creating collection: {str(collection_error)}")
# Try to reset and recreate
try:
# Attempt to delete existing collection
self.client.delete_collection(self.collection_name)
                except Exception:
                    # Collection may not exist yet; ignore and recreate below
                    pass
# Recreate collection
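                # Note: "hnsw:space": "cosine" makes Chroma report cosine distances,
                # which the similarity_search methods below convert to scores via 1.0 - distance.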
self.collection = self.client.create_collection(
name=self.collection_name,
metadata={"hnsw:space": "cosine"}
)
# Additional validation
if not self.collection:
raise ValueError(
"Failed to create or obtain ChromaDB collection")
self.initialized = True
logging.info(
f"ChromaVectorStore initialized with collection: {self.collection_name}")
except Exception as e:
logging.error(
f"Critical error initializing ChromaVectorStore: {str(e)}")
# Reset initialization state
self.initialized = False
self.collection = None
raise
async def _ensure_initialized(self):
"""Make sure the vector store is initialized before use"""
if not self.initialized:
await self.initialize()
async def add_documents_async(
self,
documents: List[str],
embeddings: Optional[List[List[float]]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None
) -> None:
"""
Add documents asynchronously with enhanced error handling
Args:
documents (List[str]): List of document texts
embeddings (Optional[List[List[float]]]): Pre-computed embeddings
metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document
ids (Optional[List[str]]): Custom IDs for the documents
"""
await self._ensure_initialized()
if not documents:
logging.warning("No documents provided to add_documents")
return
# Validate input lists
if embeddings and len(documents) != len(embeddings):
raise ValueError("Number of documents and embeddings must match")
if metadatas and len(documents) != len(metadatas):
raise ValueError("Number of documents and metadatas must match")
# Generate embeddings if not provided
if not embeddings:
try:
embeddings = self.embedding_function(documents)
except Exception as e:
logging.error(f"Error generating embeddings: {str(e)}")
raise
# Use provided IDs or generate them
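        # NOTE: these fallback IDs restart at "doc_0" on every call, so they can
        # collide with IDs from earlier batches; callers should pass explicit,
        # unique ids when adding documents more than once.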
if not ids:
ids = [f"doc_{i}" for i in range(len(documents))]
# Ensure collection exists and is usable
if not self.collection:
logging.error("ChromaDB collection is not initialized")
await self.initialize()
# Prepare add parameters
add_params = {
"documents": documents,
"embeddings": embeddings,
"ids": ids
}
# Add metadatas if provided
if metadatas is not None:
add_params["metadatas"] = metadatas
try:
# Add documents to collection with retry mechanism
max_retries = 3
for attempt in range(max_retries):
try:
                    # Encourage cleanup of stale objects before (re)trying the add
                    import gc
                    gc.collect()
# Attempt to add documents
self.collection.add(**add_params)
logging.info(
f"Successfully added {len(documents)} documents")
break
except (StopIteration, RuntimeError) as retry_error:
if attempt < max_retries - 1:
logging.warning(
f"Retry attempt {attempt + 1}: {str(retry_error)}")
# Optional: Add a small delay between retries
await asyncio.sleep(0.5)
else:
logging.error(
f"Failed to add documents after {max_retries} attempts")
raise
except Exception as e:
logging.error(
f"Unexpected error adding documents to ChromaDB: {str(e)}")
# Additional debugging information
try:
logging.info(f"Collection status: {self.collection}")
logging.info(f"Documents count: {len(documents)}")
logging.info(
f"Embeddings count: {len(add_params.get('embeddings', []))}")
logging.info(
f"Metadatas count: {len(add_params.get('metadatas', []))}")
logging.info(f"IDs count: {len(add_params.get('ids', []))}")
            except Exception as debug_error:
                logging.error(f"Error during debugging: {str(debug_error)}")
            # Re-raise so callers know the add failed
            raise
def add_documents(
self,
documents: List[str],
embeddings: Optional[List[List[float]]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None
) -> None:
"""
Synchronous wrapper for add_documents_async
"""
# Create and run a new event loop if needed
try:
loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a running loop: schedule as a background task
                # (fire-and-forget; completion is not awaited here)
                asyncio.create_task(self.add_documents_async(
                    documents, embeddings, metadatas, ids
                ))
            else:
                # Run to completion on the existing (idle) event loop
                loop.run_until_complete(self.add_documents_async(
                    documents, embeddings, metadatas, ids
                ))
except RuntimeError:
# No event loop, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.add_documents_async(
documents, embeddings, metadatas, ids
))
async def similarity_search_async(
self,
query_embedding: List[float],
top_k: int = 3,
**kwargs
) -> List[Dict[str, Any]]:
"""
Perform similarity search asynchronously
"""
await self._ensure_initialized()
try:
# Get more initial results to account for sequential chunks
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=max(top_k * 2, 10),
include=['documents', 'metadatas', 'distances']
)
if not results or 'documents' not in results:
return []
formatted_results = []
documents = results['documents'][0]
metadatas = results['metadatas'][0]
distances = results['distances'][0]
# Group chunks by document_id
doc_chunks = {}
for doc, meta, dist in zip(documents, metadatas, distances):
doc_id = meta.get('document_id')
chunk_index = meta.get('chunk_index', 0)
if doc_id not in doc_chunks:
doc_chunks[doc_id] = []
doc_chunks[doc_id].append({
'text': doc,
'metadata': meta,
'score': 1.0 - dist,
'chunk_index': chunk_index
})
# Process each document's chunks
for doc_id, chunks in doc_chunks.items():
# Sort chunks by index
chunks.sort(key=lambda x: x['chunk_index'])
# Find sequences of chunks with good scores
good_sequences = []
current_sequence = []
for chunk in chunks:
if chunk['score'] > 0.3: # Adjust threshold as needed
if not current_sequence or \
chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1:
current_sequence.append(chunk)
else:
if current_sequence:
good_sequences.append(current_sequence)
current_sequence = [chunk]
else:
if current_sequence:
good_sequences.append(current_sequence)
current_sequence = []
if current_sequence:
good_sequences.append(current_sequence)
# Add best sequences to results
for sequence in good_sequences:
avg_score = sum(c['score']
for c in sequence) / len(sequence)
combined_text = ' '.join(c['text'] for c in sequence)
formatted_results.append({
'text': combined_text,
'metadata': sequence[0]['metadata'],
'score': avg_score
})
# Sort by score and return top_k
formatted_results.sort(key=lambda x: x['score'], reverse=True)
return formatted_results[:top_k]
except Exception as e:
logging.error(f"Error in similarity search: {str(e)}")
raise
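    # Illustrative return shape (hypothetical values): each entry merges a run of
    # adjacent chunks from one document, e.g.
    #   {"text": "chunk 0 text chunk 1 text",
    #    "metadata": {"document_id": "doc-1", "chunk_index": 0, ...},
    #    "score": 0.83}
    # where "score" is the average of (1.0 - cosine distance) over the merged chunks.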
def similarity_search(
self,
query_embedding: List[float],
top_k: int = 3,
**kwargs
) -> List[Dict[str, Any]]:
"""
Synchronous wrapper for similarity_search_async
"""
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context, but need to process directly
try:
# Get more initial results to account for sequential chunks
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=max(top_k * 2, 10),
include=['documents', 'metadatas', 'distances']
)
if not results or 'documents' not in results:
return []
formatted_results = []
documents = results['documents'][0]
metadatas = results['metadatas'][0]
distances = results['distances'][0]
# Group chunks by document_id
doc_chunks = {}
for doc, meta, dist in zip(documents, metadatas, distances):
doc_id = meta.get('document_id')
chunk_index = meta.get('chunk_index', 0)
if doc_id not in doc_chunks:
doc_chunks[doc_id] = []
doc_chunks[doc_id].append({
'text': doc,
'metadata': meta,
'score': 1.0 - dist,
'chunk_index': chunk_index
})
# Process each document's chunks
for doc_id, chunks in doc_chunks.items():
# Sort chunks by index
chunks.sort(key=lambda x: x['chunk_index'])
# Find sequences of chunks with good scores
good_sequences = []
current_sequence = []
for chunk in chunks:
if chunk['score'] > 0.3: # Adjust threshold as needed
if not current_sequence or \
chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1:
current_sequence.append(chunk)
else:
if current_sequence:
good_sequences.append(current_sequence)
current_sequence = [chunk]
else:
if current_sequence:
good_sequences.append(current_sequence)
current_sequence = []
if current_sequence:
good_sequences.append(current_sequence)
# Add best sequences to results
for sequence in good_sequences:
avg_score = sum(c['score']
for c in sequence) / len(sequence)
combined_text = ' '.join(
c['text'] for c in sequence)
formatted_results.append({
'text': combined_text,
'metadata': sequence[0]['metadata'],
'score': avg_score
})
# Sort by score and return top_k
formatted_results.sort(
key=lambda x: x['score'], reverse=True)
return formatted_results[:top_k]
except Exception as e:
logging.error(
f"Error in direct similarity search: {str(e)}")
return []
else:
# Run in existing loop
return loop.run_until_complete(
self.similarity_search_async(
query_embedding, top_k, **kwargs)
)
except RuntimeError:
# No event loop, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(
self.similarity_search_async(query_embedding, top_k, **kwargs)
)
async def get_all_documents_async(
self,
include_embeddings: bool = False
) -> List[Dict[str, Any]]:
"""
Retrieve all documents asynchronously
"""
await self._ensure_initialized()
try:
include = ["documents", "metadatas"]
if include_embeddings:
include.append("embeddings")
results = self.collection.get(
include=include
)
if not results or 'documents' not in results:
return []
documents = []
for i in range(len(results['documents'])):
doc = {
'id': str(i), # Generate sequential IDs
'text': results['documents'][i],
}
if include_embeddings and 'embeddings' in results:
doc['embedding'] = results['embeddings'][i]
if 'metadatas' in results and results['metadatas'][i]:
doc['metadata'] = results['metadatas'][i]
# Use document_id from metadata if available
if 'document_id' in results['metadatas'][i]:
doc['id'] = results['metadatas'][i]['document_id']
documents.append(doc)
return documents
except Exception as e:
logging.error(
f"Error retrieving documents from ChromaDB: {str(e)}")
raise
def get_all_documents(
self,
include_embeddings: bool = False
) -> List[Dict[str, Any]]:
"""
Synchronous wrapper for get_all_documents_async
"""
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context, but need to return synchronously
# Process the results just like in the async version
try:
include = ["documents", "metadatas"]
if include_embeddings:
include.append("embeddings")
results = self.collection.get(
include=include
)
if not results or 'documents' not in results:
return []
documents = []
for i in range(len(results['documents'])):
doc = {
'id': str(i), # Generate sequential IDs
'text': results['documents'][i],
}
if include_embeddings and 'embeddings' in results:
doc['embedding'] = results['embeddings'][i]
if 'metadatas' in results and results['metadatas'][i]:
doc['metadata'] = results['metadatas'][i]
# Use document_id from metadata if available
if 'document_id' in results['metadatas'][i]:
doc['id'] = results['metadatas'][i]['document_id']
documents.append(doc)
return documents
            except Exception as e:
                logging.error(
                    f"Error retrieving documents in sync fallback: {str(e)}")
                return []
else:
return loop.run_until_complete(
self.get_all_documents_async(include_embeddings)
)
except RuntimeError:
# No event loop, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(
self.get_all_documents_async(include_embeddings)
)
async def get_document_chunks_async(self, document_id: str) -> List[Dict[str, Any]]:
"""
Retrieve all chunks for a specific document asynchronously
"""
await self._ensure_initialized()
try:
results = self.collection.get(
where={"document_id": document_id},
include=["documents", "metadatas"]
)
if not results or 'documents' not in results:
return []
chunks = []
for i in range(len(results['documents'])):
chunk = {
'text': results['documents'][i],
'metadata': results['metadatas'][i] if results.get('metadatas') else None
}
chunks.append(chunk)
            # Sort by chunk_index if available (metadata may be None)
            chunks.sort(
                key=lambda x: (x.get('metadata') or {}).get('chunk_index', 0))
return chunks
except Exception as e:
logging.error(f"Error retrieving document chunks: {str(e)}")
raise
def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
"""
Synchronous wrapper for get_document_chunks_async
"""
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# Fall back to direct query which may fail
try:
results = self.collection.get(
where={"document_id": document_id},
include=["documents", "metadatas"]
)
chunks = []
for i in range(len(results['documents'])):
chunk = {
'text': results['documents'][i],
'metadata': results['metadatas'][i] if results.get('metadatas') else None
}
chunks.append(chunk)
return chunks
                except Exception as e:
                    logging.error(
                        f"Error retrieving document chunks in sync fallback: {str(e)}")
                    return []
else:
return loop.run_until_complete(
self.get_document_chunks_async(document_id)
)
except RuntimeError:
# No event loop, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(
self.get_document_chunks_async(document_id)
)
async def delete_document_async(self, document_id: str) -> None:
"""
Delete all chunks associated with a document_id asynchronously
"""
await self._ensure_initialized()
        try:
            # Get all chunks with the given document_id. Chroma's get() always
            # returns ids, so "ids" must not be requested via include=[...].
            results = self.collection.get(
                where={"document_id": document_id}
            )
            if not results or not results.get('ids'):
                logging.warning(f"No document found with ID: {document_id}")
                return
# Delete all chunks associated with the document
self.collection.delete(ids=results['ids'])
except Exception as e:
logging.error(
f"Error deleting document {document_id} from ChromaDB: {str(e)}")
raise
def delete_document(self, document_id: str) -> None:
"""
Synchronous wrapper for delete_document_async
"""
try:
loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a running loop: schedule deletion as a background task
                # (fire-and-forget; completion is not awaited here)
                asyncio.create_task(self.delete_document_async(document_id))
            else:
                # Run to completion on the existing (idle) event loop
                loop.run_until_complete(
                    self.delete_document_async(document_id))
except RuntimeError:
# No event loop, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.delete_document_async(document_id))
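# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# It assumes the local ChromaManager singleton can create a client for
# './chroma_db' and uses a stand-in embedding function that returns fixed-size
# pseudo-random vectors; a real caller would plug in the project's actual
# embedding model instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import random

    def _dummy_embeddings(texts: List[str]) -> List[List[float]]:
        # Stand-in embedding function: one pseudo-random 1024-d vector per text.
        return [[random.random() for _ in range(1024)] for _ in texts]

    async def _demo() -> None:
        store = ChromaVectorStore(embedding_function=_dummy_embeddings)
        await store.initialize()
        await store.add_documents_async(
            documents=["first chunk of a document", "second chunk of a document"],
            metadatas=[
                {"document_id": "demo-doc", "chunk_index": 0},
                {"document_id": "demo-doc", "chunk_index": 1},
            ],
            ids=["demo-doc-chunk-0", "demo-doc-chunk-1"],
        )
        query_embedding = _dummy_embeddings(["a question about the document"])[0]
        for result in await store.similarity_search_async(query_embedding, top_k=3):
            print(result["score"], result["metadata"].get("document_id"))

    asyncio.run(_demo())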