# src/vectorstores/chroma_vectorstore.py
import logging
from typing import Any, Callable, Dict, List, Optional

import chromadb
from chromadb.config import Settings

from .base_vectorstore import BaseVectorStore


class ChromaVectorStore(BaseVectorStore):
    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = "./chroma_db",
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize Chroma Vector Store

        Args:
            embedding_function (Callable): Function to generate embeddings
            persist_directory (str): Directory to persist the vector store
            collection_name (str): Name of the collection to use
            client_settings (Optional[Dict[str, Any]]): Additional settings for the ChromaDB client
        """
        try:
            # PersistentClient takes the storage path directly and overrides any
            # persist_directory set on Settings, so pass the path explicitly;
            # extra client options still go through Settings.
            settings = Settings(**(client_settings or {}))
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=settings,
            )
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"},  # Use cosine similarity by default
            )
            self.embedding_function = embedding_function
        except Exception as e:
            logging.error(f"Error initializing ChromaDB: {str(e)}")
            raise

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
    ) -> None:
        """
        Add documents to the vector store

        Args:
            documents (List[str]): List of document texts
            embeddings (Optional[List[List[float]]]): Pre-computed embeddings
            metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document
            ids (Optional[List[str]]): Custom IDs for the documents
        """
        try:
            if not documents:
                logging.warning("No documents provided to add_documents")
                return

            # Only compute embeddings when none were supplied; an explicit
            # empty list should still fail the length check below
            if embeddings is None:
                embeddings = self.embedding_function(documents)

            if len(documents) != len(embeddings):
                raise ValueError("Number of documents and embeddings must match")

            # Use provided IDs or generate sequential ones
            doc_ids = ids if ids is not None else [f"doc_{i}" for i in range(len(documents))]

            # Prepare add parameters
            add_params = {
                "documents": documents,
                "embeddings": embeddings,
                "ids": doc_ids,
            }

            # Only include metadatas if provided
            if metadatas is not None:
                if len(metadatas) != len(documents):
                    raise ValueError("Number of documents and metadatas must match")
                add_params["metadatas"] = metadatas

            self.collection.add(**add_params)
        except Exception as e:
            logging.error(f"Error adding documents to ChromaDB: {str(e)}")
            raise

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Perform similarity search

        Args:
            query_embedding (List[float]): Embedding of the query
            top_k (int): Number of top similar documents to retrieve
            **kwargs: Additional search parameters

        Returns:
            List[Dict[str, Any]]: List of documents with their text, metadata, and scores
        """
        try:
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k,
                include=["documents", "metadatas", "distances"],
            )

            # Handle the case where no results are found
            if not results or not results.get("documents"):
                return []

            # Format results to include text, metadata, and scores
            documents = results["documents"][0]  # First (and only) query's results
            metadatas = results["metadatas"][0] if results.get("metadatas") else [None] * len(documents)
            distances = results["distances"][0] if results.get("distances") else [None] * len(documents)

            formatted_results = []
            for doc, meta, dist in zip(documents, metadatas, distances):
                formatted_results.append({
                    "text": doc,
                    "metadata": meta or {},
                    # Convert cosine distance to a similarity score
                    "score": 1.0 - dist if dist is not None else None,
                })

            return formatted_results
        except Exception as e:
            logging.error(f"Error performing similarity search in ChromaDB: {str(e)}")
            raise

    def get_all_documents(
        self,
        include_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Retrieve all documents from the vector store

        Args:
            include_embeddings (bool): Whether to include embeddings in the response

        Returns:
            List[Dict[str, Any]]: List of documents with their IDs and optionally embeddings
        """
        try:
            include = ["documents", "metadatas"]
            if include_embeddings:
                include.append("embeddings")

            results = self.collection.get(include=include)

            if not results or not results.get("documents"):
                return []

            documents = []
            for i in range(len(results["documents"])):
                # Chroma always returns the stored IDs from get(), so use them
                # instead of fabricating sequential ones
                doc = {
                    "id": results["ids"][i],
                    "text": results["documents"][i],
                }
                if include_embeddings and results.get("embeddings") is not None:
                    doc["embedding"] = results["embeddings"][i]
                if results.get("metadatas") and results["metadatas"][i]:
                    doc["metadata"] = results["metadatas"][i]
                    # Prefer the logical document_id from metadata if available
                    if "document_id" in results["metadatas"][i]:
                        doc["id"] = results["metadatas"][i]["document_id"]
                documents.append(doc)

            return documents
        except Exception as e:
            logging.error(f"Error retrieving documents from ChromaDB: {str(e)}")
            raise

    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all chunks for a specific document

        Args:
            document_id (str): ID of the document to retrieve chunks for

        Returns:
            List[Dict[str, Any]]: List of document chunks with their metadata
        """
        try:
            results = self.collection.get(
                where={"document_id": document_id},
                include=["documents", "metadatas"],
            )

            if not results or not results.get("documents"):
                return []

            chunks = []
            for i in range(len(results["documents"])):
                chunks.append({
                    "text": results["documents"][i],
                    "metadata": results["metadatas"][i] if results.get("metadatas") else None,
                })

            # Sort by chunk_index if available; metadata may be None, so fall
            # back to an empty dict before looking the key up
            chunks.sort(key=lambda x: (x.get("metadata") or {}).get("chunk_index", 0))
            return chunks
        except Exception as e:
            logging.error(f"Error retrieving document chunks: {str(e)}")
            raise

    def delete_document(self, document_id: str) -> None:
        """
        Delete all chunks associated with a document_id

        Args:
            document_id (str): ID of the document to delete
        """
        try:
            # Look up the stored chunk IDs for the given document_id
            results = self.collection.get(
                where={"document_id": document_id},
                include=["metadatas"],
            )

            if not results or not results.get("ids"):
                logging.warning(f"No document found with ID: {document_id}")
                return

            # Delete the chunks by the IDs Chroma actually stored, rather than
            # reconstructing them from a naming convention that callers may not
            # have used
            self.collection.delete(ids=results["ids"])
        except Exception as e:
            logging.error(f"Error deleting document {document_id} from ChromaDB: {str(e)}")
            raise
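

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this store might be exercised end to end, assuming
# a toy embedding function. `toy_embed` is a hypothetical stand-in, not a real
# model; in practice you would pass a function wrapping an actual embedding
# model. Because this module uses a relative import, run the sketch as a
# module from the repo root (e.g. `python -m src.vectorstores.chroma_vectorstore`).
if __name__ == "__main__":

    def toy_embed(texts: List[str]) -> List[List[float]]:
        # Hypothetical stand-in: a fixed-size 3-dimensional vector derived
        # from simple character statistics, enough to drive the store
        return [[float(len(t)), float(sum(map(ord, t)) % 97), 1.0] for t in texts]

    store = ChromaVectorStore(embedding_function=toy_embed)

    # Metadata carries document_id and chunk_index, which is what
    # get_document_chunks() and delete_document() filter and sort on
    store.add_documents(
        documents=["chunk one of doc A", "chunk two of doc A"],
        metadatas=[
            {"document_id": "doc-A", "chunk_index": 0},
            {"document_id": "doc-A", "chunk_index": 1},
        ],
        ids=["doc-A-chunk-0", "doc-A-chunk-1"],
    )

    hits = store.similarity_search(toy_embed(["doc A"])[0], top_k=2)
    print(hits)
    print(store.get_document_chunks("doc-A"))
    store.delete_document("doc-A")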