# src/vectorstores/chroma_vectorstore.py
import logging
import uuid
from typing import Any, Callable, Dict, List, Optional

import chromadb
from chromadb.config import Settings

from .base_vectorstore import BaseVectorStore
class ChromaVectorStore(BaseVectorStore):
    """Persistent ChromaDB-backed vector store.

    Wraps a ChromaDB collection (cosine similarity) and exposes document
    add / similarity-search / retrieve / delete operations.
    """

    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = './chroma_db',
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Chroma Vector Store.

        Args:
            embedding_function (Callable): Maps a list of texts to a list
                of embedding vectors.
            persist_directory (str): Directory to persist the vector store.
            collection_name (str): Name of the collection to use.
            client_settings (Optional[Dict[str, Any]]): Extra keyword
                arguments forwarded to ChromaDB ``Settings``.

        Raises:
            Exception: Re-raises any ChromaDB initialization failure after
                logging it.
        """
        try:
            settings = Settings(
                persist_directory=persist_directory,
                **(client_settings or {})
            )
            self.client = chromadb.PersistentClient(settings=settings)
            # Cosine space so similarity maps naturally to 1 - distance
            # (see similarity_search below).
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}
            )
            self.embedding_function = embedding_function
        except Exception as e:
            logging.error("Error initializing ChromaDB: %s", e)
            raise

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Add documents to the vector store.

        Args:
            documents (List[str]): List of document texts.
            embeddings (Optional[List[List[float]]]): Pre-computed
                embeddings; computed via ``embedding_function`` if omitted.
            metadatas (Optional[List[Dict[str, Any]]]): Metadata for each
                document.
            ids (Optional[List[str]]): Custom IDs. When omitted, unique
                UUID-based IDs are generated so repeated calls never
                collide (the old sequential "doc_0", "doc_1", ... scheme
                silently clashed across calls).

        Raises:
            ValueError: If embeddings or metadatas lengths do not match
                documents.
        """
        try:
            if not documents:
                logging.warning("No documents provided to add_documents")
                return
            # Compare to None (not falsiness) so an explicitly passed empty
            # list hits the length check below instead of being recomputed.
            if embeddings is None:
                embeddings = self.embedding_function(documents)
            if len(documents) != len(embeddings):
                raise ValueError(
                    "Number of documents and embeddings must match")
            # Use provided IDs or generate collision-free ones.
            doc_ids = ids if ids is not None else [
                f"doc_{uuid.uuid4().hex}" for _ in documents]
            # Prepare add parameters
            add_params: Dict[str, Any] = {
                "documents": documents,
                "embeddings": embeddings,
                "ids": doc_ids
            }
            # Only include metadatas if provided
            if metadatas is not None:
                if len(metadatas) != len(documents):
                    raise ValueError(
                        "Number of documents and metadatas must match")
                add_params["metadatas"] = metadatas
            self.collection.add(**add_params)
        except Exception as e:
            logging.error("Error adding documents to ChromaDB: %s", e)
            raise

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Perform similarity search with improved matching.

        Args:
            query_embedding (List[float]): Embedding of the query.
            top_k (int): Maximum number of results to return.

        Returns:
            List[Dict[str, Any]]: Dicts with 'text', 'metadata' and
            'score' (1 - cosine distance). When the best hit's document
            contributed several retrieved chunks, those chunks are
            returned in chunk order to preserve document flow.
        """
        try:
            # Over-fetch so the score filter below still leaves enough
            # candidates, but never fetch fewer than the caller asked for
            # (the old hardcoded 10 silently capped top_k > 10).
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=max(10, top_k),
                include=['documents', 'metadatas', 'distances']
            )
            # Chroma returns the keys with empty/None values when nothing
            # matches, so test the values rather than key presence.
            if not results or not results.get('documents') \
                    or not results['documents'][0]:
                logging.warning("No results found in similarity search")
                return []
            documents = results['documents'][0]  # first (only) query
            metadatas = results['metadatas'][0] if results.get(
                'metadatas') else [None] * len(documents)
            distances = results['distances'][0] if results.get(
                'distances') else [None] * len(documents)
            formatted_results = []
            for doc, meta, dist in zip(documents, metadatas, distances):
                # Cosine distance -> similarity (1 = identical).
                similarity_score = 1.0 - dist if dist is not None else None
                # Permissive threshold; the real ranking happens below.
                if similarity_score is not None and similarity_score > 0.2:
                    formatted_results.append({
                        'text': doc,
                        'metadata': meta or {},
                        'score': similarity_score
                    })
            # Best-first ranking.
            formatted_results.sort(
                key=lambda r: r['score'] or 0, reverse=True)
            if formatted_results:
                # Gather every retrieved chunk that belongs to the same
                # document as the best hit and return those in document
                # order (chunk_index) to keep the text coherent.
                first_doc_id = formatted_results[0]['metadata'].get(
                    'document_id')
                same_doc_chunks = [
                    r for r in formatted_results
                    if r['metadata'].get('document_id') == first_doc_id
                ]
                same_doc_chunks.sort(
                    key=lambda r: r['metadata'].get('chunk_index', 0))
                if same_doc_chunks:
                    return same_doc_chunks[:top_k]
            return formatted_results[:top_k]
        except Exception as e:
            logging.error(
                "Error performing similarity search in ChromaDB: %s", e)
            raise

    def get_all_documents(
        self,
        include_embeddings: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Retrieve all documents from the vector store.

        Args:
            include_embeddings (bool): Whether to include embeddings in
                the response.

        Returns:
            List[Dict[str, Any]]: Dicts with 'id', 'text' and optionally
            'embedding' / 'metadata'. 'id' prefers
            metadata['document_id'] and falls back to the sequential
            position.
        """
        try:
            include = ["documents", "metadatas"]
            if include_embeddings:
                include.append("embeddings")
            results = self.collection.get(
                include=include
            )
            docs_raw = (results or {}).get('documents')
            if not docs_raw:
                return []
            # Embeddings may come back as a numpy array; compare against
            # None instead of relying on array truthiness.
            embs = results.get('embeddings') if include_embeddings else None
            metas = results.get('metadatas')
            documents = []
            for i, text in enumerate(docs_raw):
                doc: Dict[str, Any] = {
                    'id': str(i),  # positional fallback ID
                    'text': text,
                }
                if embs is not None:
                    doc['embedding'] = embs[i]
                meta = metas[i] if metas else None
                if meta:
                    doc['metadata'] = meta
                    # Use document_id from metadata if available
                    if 'document_id' in meta:
                        doc['id'] = meta['document_id']
                documents.append(doc)
            return documents
        except Exception as e:
            logging.error(
                "Error retrieving documents from ChromaDB: %s", e)
            raise

    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all chunks for a specific document, in chunk order.

        Args:
            document_id (str): ID of the document to retrieve chunks for.

        Returns:
            List[Dict[str, Any]]: Dicts with 'text' and 'metadata',
            sorted by metadata['chunk_index'] when available.
        """
        try:
            results = self.collection.get(
                where={"document_id": document_id},
                include=["documents", "metadatas"]
            )
            docs_raw = (results or {}).get('documents')
            if not docs_raw:
                return []
            metas = results.get('metadatas')
            chunks = [
                {
                    'text': text,
                    'metadata': metas[i] if metas else None
                }
                for i, text in enumerate(docs_raw)
            ]
            # `or {}` guards against a chunk whose metadata is None — the
            # old `x.get('metadata', {})` returned None (the key exists)
            # and crashed the sort with AttributeError.
            chunks.sort(
                key=lambda c: (c['metadata'] or {}).get('chunk_index', 0))
            return chunks
        except Exception as e:
            logging.error("Error retrieving document chunks: %s", e)
            raise

    def delete_document(self, document_id: str) -> None:
        """
        Delete all chunks associated with a document_id.

        Args:
            document_id (str): ID of the document to delete.
        """
        try:
            # Fetch the chunks belonging to this document; `get` always
            # returns the matching IDs.
            results = self.collection.get(
                where={"document_id": document_id},
                include=["metadatas"]
            )
            chunk_ids = (results or {}).get('ids')
            if not chunk_ids:
                logging.warning("No document found with ID: %s", document_id)
                return
            # Delete by the IDs Chroma actually returned. The old code
            # reconstructed "{document_id}-chunk-{i}" names, which missed
            # every chunk stored under any other ID scheme (including the
            # IDs this class's own add_documents generates).
            self.collection.delete(ids=chunk_ids)
        except Exception as e:
            logging.error(
                "Error deleting document %s from ChromaDB: %s",
                document_id, e)
            raise