# src/vectorstores/chroma_vectorstore.py from pathlib import Path import chromadb from typing import List, Callable, Any, Dict, Optional import logging import asyncio from .base_vectorstore import BaseVectorStore from .chroma_manager import ChromaManager class ChromaVectorStore(BaseVectorStore): def __init__( self, embedding_function: Callable[[List[str]], List[List[float]]], persist_directory: str = './chroma_db', collection_name: str = "documents", client_settings: Optional[Dict[str, Any]] = None, client=None # Allow passing an existing client ): """ Initialize Chroma Vector Store Args: embedding_function (Callable): Function to generate embeddings persist_directory (str): Directory to persist the vector store collection_name (str): Name of the collection to use client_settings (Optional[Dict[str, Any]]): Additional settings for ChromaDB client client: Optional existing ChromaDB client to use """ self.embedding_function = embedding_function self.persist_directory = persist_directory self.collection_name = collection_name self.client = client # Store client for later initialization # Will be populated during async initialization self.collection = None self.initialized = False async def initialize(self): """Asynchronously initialize the vector store with enhanced error handling""" if self.initialized: return try: # Get client via manager if not provided if self.client is None: self.client = await ChromaManager.get_client(self.persist_directory) # Validate client if not self.client: raise ValueError("Failed to obtain ChromaDB client") # Get or create collection with more robust handling try: self.collection = await ChromaManager.get_or_create_collection( client=self.client, collection_name=self.collection_name, embedding_dimension=1024 # Default for most models ) except Exception as collection_error: logging.error( f"Error creating collection: {str(collection_error)}") # Try to reset and recreate try: # Attempt to delete existing collection self.client.delete_collection(self.collection_name) except: pass # Recreate collection self.collection = self.client.create_collection( name=self.collection_name, metadata={"hnsw:space": "cosine"} ) # Additional validation if not self.collection: raise ValueError( "Failed to create or obtain ChromaDB collection") self.initialized = True logging.info( f"ChromaVectorStore initialized with collection: {self.collection_name}") except Exception as e: logging.error( f"Critical error initializing ChromaVectorStore: {str(e)}") # Reset initialization state self.initialized = False self.collection = None raise async def _ensure_initialized(self): """Make sure the vector store is initialized before use""" if not self.initialized: await self.initialize() async def add_documents_async( self, documents: List[str], embeddings: Optional[List[List[float]]] = None, metadatas: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None ) -> None: """ Add documents asynchronously with enhanced error handling Args: documents (List[str]): List of document texts embeddings (Optional[List[List[float]]]): Pre-computed embeddings metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document ids (Optional[List[str]]): Custom IDs for the documents """ await self._ensure_initialized() if not documents: logging.warning("No documents provided to add_documents") return # Validate input lists if embeddings and len(documents) != len(embeddings): raise ValueError("Number of documents and embeddings must match") if metadatas and len(documents) != len(metadatas): raise ValueError("Number of documents and metadatas must match") # Generate embeddings if not provided if not embeddings: try: embeddings = self.embedding_function(documents) except Exception as e: logging.error(f"Error generating embeddings: {str(e)}") raise # Use provided IDs or generate them if not ids: ids = [f"doc_{i}" for i in range(len(documents))] # Ensure collection exists and is usable if not self.collection: logging.error("ChromaDB collection is not initialized") await self.initialize() # Prepare add parameters add_params = { "documents": documents, "embeddings": embeddings, "ids": ids } # Add metadatas if provided if metadatas is not None: add_params["metadatas"] = metadatas try: # Add documents to collection with retry mechanism max_retries = 3 for attempt in range(max_retries): try: # Clear any cached state import gc gc.collect() # Attempt to add documents self.collection.add(**add_params) logging.info( f"Successfully added {len(documents)} documents") break except (StopIteration, RuntimeError) as retry_error: if attempt < max_retries - 1: logging.warning( f"Retry attempt {attempt + 1}: {str(retry_error)}") # Optional: Add a small delay between retries await asyncio.sleep(0.5) else: logging.error( f"Failed to add documents after {max_retries} attempts") raise except (StopIteration, RuntimeError) as retry_error: if attempt < max_retries - 1: logging.warning( f"Retry attempt {attempt + 1}: {str(retry_error)}") # Optional: Add a small delay between retries await asyncio.sleep(0.5) else: logging.error( f"Failed to add documents after {max_retries} attempts") raise except Exception as e: logging.error( f"Unexpected error adding documents to ChromaDB: {str(e)}") # Additional debugging information try: logging.info(f"Collection status: {self.collection}") logging.info(f"Documents count: {len(documents)}") logging.info( f"Embeddings count: {len(add_params.get('embeddings', []))}") logging.info( f"Metadatas count: {len(add_params.get('metadatas', []))}") logging.info(f"IDs count: {len(add_params.get('ids', []))}") except Exception as debug_error: logging.error(f"Error during debugging: {str(debug_error)}") def add_documents( self, documents: List[str], embeddings: Optional[List[List[float]]] = None, metadatas: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None ) -> None: """ Synchronous wrapper for add_documents_async """ # Create and run a new event loop if needed try: loop = asyncio.get_event_loop() if loop.is_running(): # Create a future that can be run in the existing loop asyncio.create_task(self.add_documents_async( documents, embeddings, metadatas, ids )) else: # Run in a new event loop loop.run_until_complete(self.add_documents_async( documents, embeddings, metadatas, ids )) except RuntimeError: # No event loop, create a new one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(self.add_documents_async( documents, embeddings, metadatas, ids )) async def similarity_search_async( self, query_embedding: List[float], top_k: int = 3, **kwargs ) -> List[Dict[str, Any]]: """ Perform similarity search asynchronously """ await self._ensure_initialized() try: # Get more initial results to account for sequential chunks results = self.collection.query( query_embeddings=[query_embedding], n_results=max(top_k * 2, 10), include=['documents', 'metadatas', 'distances'] ) if not results or 'documents' not in results: return [] formatted_results = [] documents = results['documents'][0] metadatas = results['metadatas'][0] distances = results['distances'][0] # Group chunks by document_id doc_chunks = {} for doc, meta, dist in zip(documents, metadatas, distances): doc_id = meta.get('document_id') chunk_index = meta.get('chunk_index', 0) if doc_id not in doc_chunks: doc_chunks[doc_id] = [] doc_chunks[doc_id].append({ 'text': doc, 'metadata': meta, 'score': 1.0 - dist, 'chunk_index': chunk_index }) # Process each document's chunks for doc_id, chunks in doc_chunks.items(): # Sort chunks by index chunks.sort(key=lambda x: x['chunk_index']) # Find sequences of chunks with good scores good_sequences = [] current_sequence = [] for chunk in chunks: if chunk['score'] > 0.3: # Adjust threshold as needed if not current_sequence or \ chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1: current_sequence.append(chunk) else: if current_sequence: good_sequences.append(current_sequence) current_sequence = [chunk] else: if current_sequence: good_sequences.append(current_sequence) current_sequence = [] if current_sequence: good_sequences.append(current_sequence) # Add best sequences to results for sequence in good_sequences: avg_score = sum(c['score'] for c in sequence) / len(sequence) combined_text = ' '.join(c['text'] for c in sequence) formatted_results.append({ 'text': combined_text, 'metadata': sequence[0]['metadata'], 'score': avg_score }) # Sort by score and return top_k formatted_results.sort(key=lambda x: x['score'], reverse=True) return formatted_results[:top_k] except Exception as e: logging.error(f"Error in similarity search: {str(e)}") raise def similarity_search( self, query_embedding: List[float], top_k: int = 3, **kwargs ) -> List[Dict[str, Any]]: """ Synchronous wrapper for similarity_search_async """ try: loop = asyncio.get_event_loop() if loop.is_running(): # We're in an async context, but need to process directly try: # Get more initial results to account for sequential chunks results = self.collection.query( query_embeddings=[query_embedding], n_results=max(top_k * 2, 10), include=['documents', 'metadatas', 'distances'] ) if not results or 'documents' not in results: return [] formatted_results = [] documents = results['documents'][0] metadatas = results['metadatas'][0] distances = results['distances'][0] # Group chunks by document_id doc_chunks = {} for doc, meta, dist in zip(documents, metadatas, distances): doc_id = meta.get('document_id') chunk_index = meta.get('chunk_index', 0) if doc_id not in doc_chunks: doc_chunks[doc_id] = [] doc_chunks[doc_id].append({ 'text': doc, 'metadata': meta, 'score': 1.0 - dist, 'chunk_index': chunk_index }) # Process each document's chunks for doc_id, chunks in doc_chunks.items(): # Sort chunks by index chunks.sort(key=lambda x: x['chunk_index']) # Find sequences of chunks with good scores good_sequences = [] current_sequence = [] for chunk in chunks: if chunk['score'] > 0.3: # Adjust threshold as needed if not current_sequence or \ chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1: current_sequence.append(chunk) else: if current_sequence: good_sequences.append(current_sequence) current_sequence = [chunk] else: if current_sequence: good_sequences.append(current_sequence) current_sequence = [] if current_sequence: good_sequences.append(current_sequence) # Add best sequences to results for sequence in good_sequences: avg_score = sum(c['score'] for c in sequence) / len(sequence) combined_text = ' '.join( c['text'] for c in sequence) formatted_results.append({ 'text': combined_text, 'metadata': sequence[0]['metadata'], 'score': avg_score }) # Sort by score and return top_k formatted_results.sort( key=lambda x: x['score'], reverse=True) return formatted_results[:top_k] except Exception as e: logging.error( f"Error in direct similarity search: {str(e)}") return [] else: # Run in existing loop return loop.run_until_complete( self.similarity_search_async( query_embedding, top_k, **kwargs) ) except RuntimeError: # No event loop, create a new one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop.run_until_complete( self.similarity_search_async(query_embedding, top_k, **kwargs) ) async def get_all_documents_async( self, include_embeddings: bool = False ) -> List[Dict[str, Any]]: """ Retrieve all documents asynchronously """ await self._ensure_initialized() try: include = ["documents", "metadatas"] if include_embeddings: include.append("embeddings") results = self.collection.get( include=include ) if not results or 'documents' not in results: return [] documents = [] for i in range(len(results['documents'])): doc = { 'id': str(i), # Generate sequential IDs 'text': results['documents'][i], } if include_embeddings and 'embeddings' in results: doc['embedding'] = results['embeddings'][i] if 'metadatas' in results and results['metadatas'][i]: doc['metadata'] = results['metadatas'][i] # Use document_id from metadata if available if 'document_id' in results['metadatas'][i]: doc['id'] = results['metadatas'][i]['document_id'] documents.append(doc) return documents except Exception as e: logging.error( f"Error retrieving documents from ChromaDB: {str(e)}") raise def get_all_documents( self, include_embeddings: bool = False ) -> List[Dict[str, Any]]: """ Synchronous wrapper for get_all_documents_async """ try: loop = asyncio.get_event_loop() if loop.is_running(): # We're in an async context, but need to return synchronously # Process the results just like in the async version try: include = ["documents", "metadatas"] if include_embeddings: include.append("embeddings") results = self.collection.get( include=include ) if not results or 'documents' not in results: return [] documents = [] for i in range(len(results['documents'])): doc = { 'id': str(i), # Generate sequential IDs 'text': results['documents'][i], } if include_embeddings and 'embeddings' in results: doc['embedding'] = results['embeddings'][i] if 'metadatas' in results and results['metadatas'][i]: doc['metadata'] = results['metadatas'][i] # Use document_id from metadata if available if 'document_id' in results['metadatas'][i]: doc['id'] = results['metadatas'][i]['document_id'] documents.append(doc) return documents except: return [] else: return loop.run_until_complete( self.get_all_documents_async(include_embeddings) ) except RuntimeError: # No event loop, create a new one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop.run_until_complete( self.get_all_documents_async(include_embeddings) ) async def get_document_chunks_async(self, document_id: str) -> List[Dict[str, Any]]: """ Retrieve all chunks for a specific document asynchronously """ await self._ensure_initialized() try: results = self.collection.get( where={"document_id": document_id}, include=["documents", "metadatas"] ) if not results or 'documents' not in results: return [] chunks = [] for i in range(len(results['documents'])): chunk = { 'text': results['documents'][i], 'metadata': results['metadatas'][i] if results.get('metadatas') else None } chunks.append(chunk) # Sort by chunk_index if available chunks.sort(key=lambda x: x.get( 'metadata', {}).get('chunk_index', 0)) return chunks except Exception as e: logging.error(f"Error retrieving document chunks: {str(e)}") raise def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]: """ Synchronous wrapper for get_document_chunks_async """ try: loop = asyncio.get_event_loop() if loop.is_running(): # Fall back to direct query which may fail try: results = self.collection.get( where={"document_id": document_id}, include=["documents", "metadatas"] ) chunks = [] for i in range(len(results['documents'])): chunk = { 'text': results['documents'][i], 'metadata': results['metadatas'][i] if results.get('metadatas') else None } chunks.append(chunk) return chunks except: return [] else: return loop.run_until_complete( self.get_document_chunks_async(document_id) ) except RuntimeError: # No event loop, create a new one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop.run_until_complete( self.get_document_chunks_async(document_id) ) async def delete_document_async(self, document_id: str) -> None: """ Delete all chunks associated with a document_id asynchronously """ await self._ensure_initialized() try: # Get all chunks with the given document_id results = self.collection.get( where={"document_id": document_id}, include=["ids"] ) if not results or 'ids' not in results: logging.warning(f"No document found with ID: {document_id}") return # Delete all chunks associated with the document self.collection.delete(ids=results['ids']) except Exception as e: logging.error( f"Error deleting document {document_id} from ChromaDB: {str(e)}") raise def delete_document(self, document_id: str) -> None: """ Synchronous wrapper for delete_document_async """ try: loop = asyncio.get_event_loop() if loop.is_running(): # Create a future that can be run in the existing loop asyncio.create_task(self.delete_document_async(document_id)) else: # Run in a new event loop loop.run_until_complete( self.delete_document_async(document_id)) except RuntimeError: # No event loop, create a new one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(self.delete_document_async(document_id))