# src/vectorstores/chroma_vectorstore.py
"""ChromaDB-backed vector store with async methods and synchronous wrappers."""
import asyncio
import gc
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import chromadb

from .base_vectorstore import BaseVectorStore
from .chroma_manager import ChromaManager


class ChromaVectorStore(BaseVectorStore):
    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = './chroma_db',
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None,
        client=None  # Allow passing an existing client
    ):
        """
        Initialize the Chroma vector store.

        Args:
            embedding_function (Callable): Function used to generate embeddings.
            persist_directory (str): Directory where the vector store is persisted.
            collection_name (str): Name of the collection to use.
            client_settings (Optional[Dict[str, Any]]): Additional settings for the ChromaDB client.
            client: Optional existing ChromaDB client to reuse.
        """
        self.embedding_function = embedding_function
        self.persist_directory = persist_directory
        self.collection_name = collection_name
        self.client_settings = client_settings
        self.client = client  # Store the client for later initialization

        # Populated during async initialization
        self.collection = None
        self.initialized = False

    async def initialize(self):
        """Asynchronously initialize the vector store with enhanced error handling."""
        if self.initialized:
            return

        try:
            # Get a client via the manager if one was not provided
            if self.client is None:
                self.client = await ChromaManager.get_client(self.persist_directory)

            # Validate the client
            if not self.client:
                raise ValueError("Failed to obtain ChromaDB client")

            # Get or create the collection with more robust handling
            try:
                self.collection = await ChromaManager.get_or_create_collection(
                    client=self.client,
                    collection_name=self.collection_name,
                    embedding_dimension=1024  # Default for most models
                )
            except Exception as collection_error:
                logging.error(f"Error creating collection: {collection_error}")
                # Try to reset and recreate the collection
                try:
                    # Attempt to delete any existing collection first
                    self.client.delete_collection(self.collection_name)
                except Exception:
                    pass
                # Recreate the collection
                self.collection = self.client.create_collection(
                    name=self.collection_name,
                    metadata={"hnsw:space": "cosine"}
                )

            # Final validation
            if not self.collection:
                raise ValueError("Failed to create or obtain ChromaDB collection")

            self.initialized = True
            logging.info(
                f"ChromaVectorStore initialized with collection: {self.collection_name}")
        except Exception as e:
            logging.error(f"Critical error initializing ChromaVectorStore: {e}")
            # Reset initialization state
            self.initialized = False
            self.collection = None
            raise

    async def _ensure_initialized(self):
        """Make sure the vector store is initialized before use."""
        if not self.initialized:
            await self.initialize()

    async def add_documents_async(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Add documents asynchronously with enhanced error handling.

        Args:
            documents (List[str]): List of document texts.
            embeddings (Optional[List[List[float]]]): Pre-computed embeddings.
            metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document.
            ids (Optional[List[str]]): Custom IDs for the documents.
        """
        await self._ensure_initialized()

        if not documents:
            logging.warning("No documents provided to add_documents")
            return

        # Validate input lists
        if embeddings and len(documents) != len(embeddings):
            raise ValueError("Number of documents and embeddings must match")
        if metadatas and len(documents) != len(metadatas):
            raise ValueError("Number of documents and metadatas must match")

        # Generate embeddings if they were not provided
        if not embeddings:
            try:
                embeddings = self.embedding_function(documents)
            except Exception as e:
                logging.error(f"Error generating embeddings: {e}")
                raise

        # Use the provided IDs or generate sequential ones
        if not ids:
            ids = [f"doc_{i}" for i in range(len(documents))]

        # Ensure the collection exists and is usable
        if not self.collection:
            logging.error("ChromaDB collection is not initialized")
            await self.initialize()

        # Prepare parameters for the add call
        add_params = {
            "documents": documents,
            "embeddings": embeddings,
            "ids": ids
        }
        # Attach metadatas if provided
        if metadatas is not None:
            add_params["metadatas"] = metadatas

        try:
            # Add documents to the collection with a simple retry mechanism
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    # Encourage cleanup of any stale cached state between attempts
                    gc.collect()
                    # Attempt to add the documents
                    self.collection.add(**add_params)
                    logging.info(f"Successfully added {len(documents)} documents")
                    break
                except (StopIteration, RuntimeError) as retry_error:
                    if attempt < max_retries - 1:
                        logging.warning(f"Retry attempt {attempt + 1}: {retry_error}")
                        # Small delay between retries
                        await asyncio.sleep(0.5)
                    else:
                        logging.error(f"Failed to add documents after {max_retries} attempts")
                        raise
        except Exception as e:
            logging.error(f"Unexpected error adding documents to ChromaDB: {e}")
            # Log additional debugging information
            try:
                logging.info(f"Collection status: {self.collection}")
                logging.info(f"Documents count: {len(documents)}")
                logging.info(f"Embeddings count: {len(add_params.get('embeddings', []))}")
                logging.info(f"Metadatas count: {len(add_params.get('metadatas', []))}")
                logging.info(f"IDs count: {len(add_params.get('ids', []))}")
            except Exception as debug_error:
                logging.error(f"Error during debugging: {debug_error}")

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
        """
        Synchronous wrapper for add_documents_async.
        """
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a running loop: schedule the coroutine as a task
                # (fire-and-forget; the caller cannot await the result here)
                asyncio.create_task(self.add_documents_async(
                    documents, embeddings, metadatas, ids
                ))
            else:
                # Run on the existing, but not running, event loop
                loop.run_until_complete(self.add_documents_async(
                    documents, embeddings, metadatas, ids
                ))
        except RuntimeError:
            # No usable event loop: create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self.add_documents_async(
                documents, embeddings, metadatas, ids
            ))

    async def similarity_search_async(
        self,
        query_embedding: List[float],
        top_k: int = 3,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Perform a similarity search asynchronously.
        """
        await self._ensure_initialized()

        try:
            # Request extra initial results to account for sequential chunks
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=max(top_k * 2, 10),
                include=['documents', 'metadatas', 'distances']
            )

            if not results or 'documents' not in results:
                return []

            formatted_results = []
            documents = results['documents'][0]
            metadatas = results['metadatas'][0]
            distances = results['distances'][0]

            # Group chunks by document_id
            doc_chunks = {}
            for doc, meta, dist in zip(documents, metadatas, distances):
                doc_id = meta.get('document_id')
                chunk_index = meta.get('chunk_index', 0)
                if doc_id not in doc_chunks:
                    doc_chunks[doc_id] = []
                doc_chunks[doc_id].append({
                    'text': doc,
                    'metadata': meta,
                    'score': 1.0 - dist,
                    'chunk_index': chunk_index
                })

            # Process each document's chunks
            for doc_id, chunks in doc_chunks.items():
                # Sort chunks by index
                chunks.sort(key=lambda x: x['chunk_index'])

                # Find runs of consecutive chunks with good scores
                good_sequences = []
                current_sequence = []
                for chunk in chunks:
                    if chunk['score'] > 0.3:  # Adjust threshold as needed
                        if not current_sequence or \
                                chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1:
                            current_sequence.append(chunk)
                        else:
                            if current_sequence:
                                good_sequences.append(current_sequence)
                            current_sequence = [chunk]
                    else:
                        if current_sequence:
                            good_sequences.append(current_sequence)
                        current_sequence = []
                if current_sequence:
                    good_sequences.append(current_sequence)

                # Add the best sequences to the results
                for sequence in good_sequences:
                    avg_score = sum(c['score'] for c in sequence) / len(sequence)
                    combined_text = ' '.join(c['text'] for c in sequence)
                    formatted_results.append({
                        'text': combined_text,
                        'metadata': sequence[0]['metadata'],
                        'score': avg_score
                    })

            # Sort by score and return the top_k results
            formatted_results.sort(key=lambda x: x['score'], reverse=True)
            return formatted_results[:top_k]
        except Exception as e:
            logging.error(f"Error in similarity search: {e}")
            raise

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Synchronous wrapper for similarity_search_async.
        """
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a running loop, so query the collection directly
                try:
                    # Request extra initial results to account for sequential chunks
                    results = self.collection.query(
                        query_embeddings=[query_embedding],
                        n_results=max(top_k * 2, 10),
                        include=['documents', 'metadatas', 'distances']
                    )

                    if not results or 'documents' not in results:
                        return []

                    formatted_results = []
                    documents = results['documents'][0]
                    metadatas = results['metadatas'][0]
                    distances = results['distances'][0]

                    # Group chunks by document_id
                    doc_chunks = {}
                    for doc, meta, dist in zip(documents, metadatas, distances):
                        doc_id = meta.get('document_id')
                        chunk_index = meta.get('chunk_index', 0)
                        if doc_id not in doc_chunks:
                            doc_chunks[doc_id] = []
                        doc_chunks[doc_id].append({
                            'text': doc,
                            'metadata': meta,
                            'score': 1.0 - dist,
                            'chunk_index': chunk_index
                        })

                    # Process each document's chunks
                    for doc_id, chunks in doc_chunks.items():
                        # Sort chunks by index
                        chunks.sort(key=lambda x: x['chunk_index'])

                        # Find runs of consecutive chunks with good scores
                        good_sequences = []
                        current_sequence = []
                        for chunk in chunks:
                            if chunk['score'] > 0.3:  # Adjust threshold as needed
                                if not current_sequence or \
                                        chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1:
                                    current_sequence.append(chunk)
                                else:
                                    if current_sequence:
                                        good_sequences.append(current_sequence)
                                    current_sequence = [chunk]
                            else:
                                if current_sequence:
                                    good_sequences.append(current_sequence)
                                current_sequence = []
                        if current_sequence:
                            good_sequences.append(current_sequence)

                        # Add the best sequences to the results
                        for sequence in good_sequences:
                            avg_score = sum(c['score'] for c in sequence) / len(sequence)
                            combined_text = ' '.join(c['text'] for c in sequence)
                            formatted_results.append({
                                'text': combined_text,
                                'metadata': sequence[0]['metadata'],
                                'score': avg_score
                            })

                    # Sort by score and return the top_k results
                    formatted_results.sort(key=lambda x: x['score'], reverse=True)
                    return formatted_results[:top_k]
                except Exception as e:
                    logging.error(f"Error in direct similarity search: {e}")
                    return []
            else:
                # Run on the existing, but not running, event loop
                return loop.run_until_complete(
                    self.similarity_search_async(query_embedding, top_k, **kwargs)
                )
        except RuntimeError:
            # No usable event loop: create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(
                self.similarity_search_async(query_embedding, top_k, **kwargs)
            )

    async def get_all_documents_async(
        self,
        include_embeddings: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Retrieve all documents asynchronously.
        """
        await self._ensure_initialized()

        try:
            include = ["documents", "metadatas"]
            if include_embeddings:
                include.append("embeddings")

            results = self.collection.get(include=include)

            if not results or 'documents' not in results:
                return []

            documents = []
            for i in range(len(results['documents'])):
                doc = {
                    'id': str(i),  # Fallback sequential ID
                    'text': results['documents'][i],
                }
                if include_embeddings and 'embeddings' in results:
                    doc['embedding'] = results['embeddings'][i]
                if 'metadatas' in results and results['metadatas'][i]:
                    doc['metadata'] = results['metadatas'][i]
                    # Prefer the document_id from metadata when available
                    if 'document_id' in results['metadatas'][i]:
                        doc['id'] = results['metadatas'][i]['document_id']
                documents.append(doc)
            return documents
        except Exception as e:
            logging.error(f"Error retrieving documents from ChromaDB: {e}")
            raise

    def get_all_documents(
        self,
        include_embeddings: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Synchronous wrapper for get_all_documents_async.
        """
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a running loop, but a synchronous result is
                # needed, so process the results the same way as the async version
                try:
                    include = ["documents", "metadatas"]
                    if include_embeddings:
                        include.append("embeddings")

                    results = self.collection.get(include=include)

                    if not results or 'documents' not in results:
                        return []

                    documents = []
                    for i in range(len(results['documents'])):
                        doc = {
                            'id': str(i),  # Fallback sequential ID
                            'text': results['documents'][i],
                        }
                        if include_embeddings and 'embeddings' in results:
                            doc['embedding'] = results['embeddings'][i]
                        if 'metadatas' in results and results['metadatas'][i]:
                            doc['metadata'] = results['metadatas'][i]
                            # Prefer the document_id from metadata when available
                            if 'document_id' in results['metadatas'][i]:
                                doc['id'] = results['metadatas'][i]['document_id']
                        documents.append(doc)
                    return documents
                except Exception as e:
                    logging.error(f"Error retrieving documents directly: {e}")
                    return []
            else:
                return loop.run_until_complete(
                    self.get_all_documents_async(include_embeddings)
                )
        except RuntimeError:
            # No usable event loop: create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(
                self.get_all_documents_async(include_embeddings)
            )

    async def get_document_chunks_async(self, document_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all chunks for a specific document asynchronously.
        """
        await self._ensure_initialized()

        try:
            results = self.collection.get(
                where={"document_id": document_id},
                include=["documents", "metadatas"]
            )

            if not results or 'documents' not in results:
                return []

            chunks = []
            for i in range(len(results['documents'])):
                chunk = {
                    'text': results['documents'][i],
                    'metadata': results['metadatas'][i] if results.get('metadatas') else None
                }
                chunks.append(chunk)

            # Sort by chunk_index if available (treat missing metadata as index 0)
            chunks.sort(key=lambda x: (x.get('metadata') or {}).get('chunk_index', 0))
            return chunks
        except Exception as e:
            logging.error(f"Error retrieving document chunks: {e}")
            raise

    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """
        Synchronous wrapper for get_document_chunks_async.
        """
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a running loop: fall back to a direct query,
                # which may fail if the collection is not yet initialized
                try:
                    results = self.collection.get(
                        where={"document_id": document_id},
                        include=["documents", "metadatas"]
                    )
                    chunks = []
                    for i in range(len(results['documents'])):
                        chunk = {
                            'text': results['documents'][i],
                            'metadata': results['metadatas'][i] if results.get('metadatas') else None
                        }
                        chunks.append(chunk)
                    return chunks
                except Exception as e:
                    logging.error(f"Error retrieving document chunks directly: {e}")
                    return []
            else:
                return loop.run_until_complete(
                    self.get_document_chunks_async(document_id)
                )
        except RuntimeError:
            # No usable event loop: create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(
                self.get_document_chunks_async(document_id)
            )

    async def delete_document_async(self, document_id: str) -> None:
        """
        Delete all chunks associated with a document_id asynchronously.
        """
        await self._ensure_initialized()

        try:
            # Get all chunks with the given document_id.
            # IDs are always returned by get(), so they must not be listed in
            # `include` (ChromaDB rejects "ids" as an include value).
            results = self.collection.get(where={"document_id": document_id})

            if not results or not results.get('ids'):
                logging.warning(f"No document found with ID: {document_id}")
                return

            # Delete all chunks associated with the document
            self.collection.delete(ids=results['ids'])
        except Exception as e:
            logging.error(f"Error deleting document {document_id} from ChromaDB: {e}")
            raise

    def delete_document(self, document_id: str) -> None:
        """
        Synchronous wrapper for delete_document_async.
        """
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already inside a running loop: schedule the coroutine as a task
                asyncio.create_task(self.delete_document_async(document_id))
            else:
                # Run on the existing, but not running, event loop
                loop.run_until_complete(self.delete_document_async(document_id))
        except RuntimeError:
            # No usable event loop: create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self.delete_document_async(document_id))
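

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Run with
# `python -m src.vectorstores.chroma_vectorstore` from the project root,
# assuming that package layout and that ChromaManager can create a local
# persistent client. `fake_embed` is a stand-in embedding function returning
# fixed 1024-dimensional dummy vectors; swap in a real model in practice.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        def fake_embed(texts: List[str]) -> List[List[float]]:
            # Deterministic dummy vectors, only meant to exercise the API
            return [[1.0] * 1024 for _ in texts]

        store = ChromaVectorStore(embedding_function=fake_embed)
        await store.initialize()

        # Index one document consisting of a single chunk
        await store.add_documents_async(
            documents=["Chroma stores embeddings for similarity search."],
            metadatas=[{"document_id": "demo-doc", "chunk_index": 0}],
            ids=["demo-doc_0"],
        )

        # Search with the embedding of a query string
        hits = await store.similarity_search_async(
            query_embedding=fake_embed(["similarity search"])[0],
            top_k=1,
        )
        print(hits)

    asyncio.run(_demo())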