# src/vectorstores/chroma_vectorstore.py
import logging
from typing import Any, Callable, Dict, List, Optional

import chromadb
from chromadb.config import Settings

from .base_vectorstore import BaseVectorStore


class ChromaVectorStore(BaseVectorStore):
    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = "./chroma_db",
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Chroma Vector Store

        Args:
            embedding_function (Callable): Function to generate embeddings
            persist_directory (str): Directory to persist the vector store
            collection_name (str): Name of the collection to use
            client_settings (Optional[Dict[str, Any]]): Additional settings
                for the ChromaDB client
        """
        try:
            settings = Settings(**(client_settings or {}))
            # PersistentClient takes the storage path directly; a
            # persist_directory passed through Settings would be overridden
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=settings
            )
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                # Use cosine similarity by default
                metadata={"hnsw:space": "cosine"}
            )
            self.embedding_function = embedding_function
        except Exception as e:
            logging.error(f"Error initializing ChromaDB: {str(e)}")
            raise

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Add documents to the vector store

        Args:
            documents (List[str]): List of document texts
            embeddings (Optional[List[List[float]]]): Pre-computed embeddings
            metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document
            ids (Optional[List[str]]): Custom IDs for the documents
        """
        try:
            if not documents:
                logging.warning("No documents provided to add_documents")
                return

            # Only compute embeddings when none were supplied
            if embeddings is None:
                embeddings = self.embedding_function(documents)

            if len(documents) != len(embeddings):
                raise ValueError(
                    "Number of documents and embeddings must match")

            # Use provided IDs or generate sequential ones
            doc_ids = ids if ids is not None else [
                f"doc_{i}" for i in range(len(documents))]

            # Prepare add parameters
            add_params = {
                "documents": documents,
                "embeddings": embeddings,
                "ids": doc_ids
            }

            # Only include metadatas if provided
            if metadatas is not None:
                if len(metadatas) != len(documents):
                    raise ValueError(
                        "Number of documents and metadatas must match")
                add_params["metadatas"] = metadatas

            self.collection.add(**add_params)
        except Exception as e:
            logging.error(f"Error adding documents to ChromaDB: {str(e)}")
            raise

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Perform similarity search with improved chunk handling"""
        try:
            # Over-fetch so sequential chunks of the same document can be
            # regrouped before trimming back down to top_k
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=max(top_k * 2, 10),
                include=['documents', 'metadatas', 'distances']
            )

            if not results or not results.get('documents'):
                return []

            formatted_results = []
            documents = results['documents'][0]
            metadatas = results['metadatas'][0]
            distances = results['distances'][0]

            # Group chunks by document_id
            doc_chunks = {}
            for doc, meta, dist in zip(documents, metadatas, distances):
                meta = meta or {}
                doc_id = meta.get('document_id')
                chunk_index = meta.get('chunk_index', 0)

                if doc_id not in doc_chunks:
                    doc_chunks[doc_id] = []

                doc_chunks[doc_id].append({
                    'text': doc,
                    'metadata': meta,
                    # Convert cosine distance to a similarity score
                    'score': 1.0 - dist,
                    'chunk_index': chunk_index
                })

            # Process each document's chunks
            for doc_id, chunks in doc_chunks.items():
                # Sort chunks by index
                chunks.sort(key=lambda x: x['chunk_index'])
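                # A relevant passage often spans several consecutive chunks,
                # so rather than returning isolated fragments, merge runs of
                # adjacent chunk_index values whose scores clear the threshold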
                # Find sequences of chunks with good scores
                good_sequences = []
                current_sequence = []

                for chunk in chunks:
                    if chunk['score'] > 0.3:  # Adjust threshold as needed
                        if not current_sequence or \
                                chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1:
                            current_sequence.append(chunk)
                        else:
                            if current_sequence:
                                good_sequences.append(current_sequence)
                            current_sequence = [chunk]
                    else:
                        if current_sequence:
                            good_sequences.append(current_sequence)
                        current_sequence = []

                if current_sequence:
                    good_sequences.append(current_sequence)

                # Add best sequences to results
                for sequence in good_sequences:
                    avg_score = sum(c['score'] for c in sequence) / len(sequence)
                    combined_text = ' '.join(c['text'] for c in sequence)

                    formatted_results.append({
                        'text': combined_text,
                        'metadata': sequence[0]['metadata'],
                        'score': avg_score
                    })

            # Sort by score and return top_k
            formatted_results.sort(key=lambda x: x['score'], reverse=True)
            return formatted_results[:top_k]

        except Exception as e:
            logging.error(f"Error in similarity search: {str(e)}")
            raise

    def get_all_documents(
        self,
        include_embeddings: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Retrieve all documents from the vector store

        Args:
            include_embeddings (bool): Whether to include embeddings in the response

        Returns:
            List[Dict[str, Any]]: List of documents with their IDs and
                optionally embeddings
        """
        try:
            include = ["documents", "metadatas"]
            if include_embeddings:
                include.append("embeddings")

            results = self.collection.get(include=include)

            if not results or not results.get('documents'):
                return []

            documents = []
            for i in range(len(results['documents'])):
                doc = {
                    'id': str(i),  # Fallback sequential ID
                    'text': results['documents'][i],
                }

                if include_embeddings and results.get('embeddings') is not None:
                    doc['embedding'] = results['embeddings'][i]

                if results.get('metadatas') and results['metadatas'][i]:
                    doc['metadata'] = results['metadatas'][i]
                    # Prefer the document_id stored in metadata if available
                    if 'document_id' in results['metadatas'][i]:
                        doc['id'] = results['metadatas'][i]['document_id']

                documents.append(doc)

            return documents
        except Exception as e:
            logging.error(
                f"Error retrieving documents from ChromaDB: {str(e)}")
            raise

    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all chunks for a specific document

        Args:
            document_id (str): ID of the document to retrieve chunks for

        Returns:
            List[Dict[str, Any]]: List of document chunks with their metadata
        """
        try:
            results = self.collection.get(
                where={"document_id": document_id},
                include=["documents", "metadatas"]
            )

            if not results or not results.get('documents'):
                return []

            chunks = []
            for i in range(len(results['documents'])):
                chunk = {
                    'text': results['documents'][i],
                    'metadata': results['metadatas'][i] if results.get('metadatas') else None
                }
                chunks.append(chunk)

            # Sort by chunk_index if available; metadata may be None
            chunks.sort(key=lambda x: (x['metadata'] or {}).get('chunk_index', 0))

            return chunks
        except Exception as e:
            logging.error(f"Error retrieving document chunks: {str(e)}")
            raise

    def delete_document(self, document_id: str) -> None:
        """
        Delete all chunks associated with a document_id

        Args:
            document_id (str): ID of the document to delete
        """
        try:
            # Fetch the stored IDs of every chunk with this document_id;
            # Chroma always returns 'ids', so nothing else needs including
            results = self.collection.get(
                where={"document_id": document_id},
                include=[]
            )

            if not results or not results.get('ids'):
                logging.warning(f"No document found with ID: {document_id}")
                return
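            # Deleting by the IDs Chroma actually stored works however the
            # chunks were ingested, without assuming any ID naming scheme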
            self.collection.delete(ids=results['ids'])
        except Exception as e:
            logging.error(
                f"Error deleting document {document_id} from ChromaDB: {str(e)}")
            raise
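

# Minimal usage sketch, not part of the class. The embed() helper below is a
# hypothetical stand-in for a real embedding model; swap in your own function.
# Because of the relative import above, run it as a module, e.g.
# `python -m src.vectorstores.chroma_vectorstore` (assuming the package
# layout implied by the header path).
if __name__ == "__main__":
    def embed(texts: List[str]) -> List[List[float]]:
        # Toy embedding: real code would call an embedding model here
        return [[float(len(t)), 1.0, 0.0] for t in texts]

    store = ChromaVectorStore(embedding_function=embed)
    store.add_documents(
        documents=["first chunk of text", "second chunk of text"],
        metadatas=[
            {"document_id": "doc-1", "chunk_index": 0},
            {"document_id": "doc-1", "chunk_index": 1},
        ],
        ids=["doc-1-chunk-0", "doc-1-chunk-1"],
    )

    # Consecutive chunks of doc-1 should come back merged into one passage
    for hit in store.similarity_search(embed(["first chunk"])[0], top_k=2):
        print(f"{hit['score']:.3f}  {hit['text']}")

    store.delete_document("doc-1")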