# src/vectorstores/chroma_vectorstore.py
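"""ChromaDB-backed vector store.

Persists document chunks with their embeddings and metadata in a local
ChromaDB collection, supports chunk-aware similarity search, and exposes
retrieval and deletion by document_id.
"""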
import logging
from typing import Any, Callable, Dict, List, Optional

import chromadb
from chromadb.config import Settings

from .base_vectorstore import BaseVectorStore


class ChromaVectorStore(BaseVectorStore):
    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = './chroma_db',
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None
    ):
""" | |
Initialize Chroma Vector Store | |
Args: | |
embedding_function (Callable): Function to generate embeddings | |
persist_directory (str): Directory to persist the vector store | |
collection_name (str): Name of the collection to use | |
client_settings (Optional[Dict[str, Any]]): Additional settings for ChromaDB client | |
""" | |
        try:
            settings = Settings(**(client_settings or {}))
            # Pass the storage path directly: PersistentClient overrides any
            # persist_directory set on Settings with its path argument, so
            # routing the path through Settings would silently be ignored.
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=settings
            )
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                # Using cosine similarity by default
                metadata={"hnsw:space": "cosine"}
            )
            self.embedding_function = embedding_function
        except Exception as e:
            logging.error(f"Error initializing ChromaDB: {str(e)}")
            raise

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None
    ) -> None:
""" | |
Add documents to the vector store | |
Args: | |
documents (List[str]): List of document texts | |
embeddings (Optional[List[List[float]]]): Pre-computed embeddings | |
metadatas (Optional[List[Dict[str, Any]]]): Metadata for each document | |
ids (Optional[List[str]]): Custom IDs for the documents | |
""" | |
        try:
            if not documents:
                logging.warning("No documents provided to add_documents")
                return
            if not embeddings:
                embeddings = self.embedding_function(documents)
            if len(documents) != len(embeddings):
                raise ValueError(
                    "Number of documents and embeddings must match")
            # Use provided IDs or generate positional ones. Note that
            # generated IDs repeat across calls, so callers adding in
            # batches should pass explicit ids.
            doc_ids = ids if ids is not None else [
                f"doc_{i}" for i in range(len(documents))]
            # Prepare add parameters
            add_params = {
                "documents": documents,
                "embeddings": embeddings,
                "ids": doc_ids
            }
            # Only include metadatas if provided
            if metadatas is not None:
                if len(metadatas) != len(documents):
                    raise ValueError(
                        "Number of documents and metadatas must match")
                add_params["metadatas"] = metadatas
            self.collection.add(**add_params)
        except Exception as e:
            logging.error(f"Error adding documents to ChromaDB: {str(e)}")
            raise

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Perform similarity search with chunk-aware result handling.

        Runs of consecutive chunks from the same document that score above
        a threshold are merged into a single result, so callers get coherent
        passages rather than isolated fragments.

        Args:
            query_embedding (List[float]): Embedding of the query text
            top_k (int): Maximum number of results to return

        Returns:
            List[Dict[str, Any]]: Results with 'text', 'metadata' and 'score'
            keys, sorted by descending score
        """
        try:
            # Fetch extra results so that sequential chunks of the same
            # document can be regrouped below
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=max(top_k * 2, 10),
                include=['documents', 'metadatas', 'distances']
            )
            if not results or 'documents' not in results:
                return []
            formatted_results = []
            documents = results['documents'][0]
            metadatas = results['metadatas'][0]
            distances = results['distances'][0]
            # Group chunks by document_id
            doc_chunks = {}
            for doc, meta, dist in zip(documents, metadatas, distances):
                doc_id = meta.get('document_id')
                chunk_index = meta.get('chunk_index', 0)
                if doc_id not in doc_chunks:
                    doc_chunks[doc_id] = []
                doc_chunks[doc_id].append({
                    'text': doc,
                    'metadata': meta,
                    # Cosine distance -> similarity (collection uses cosine space)
                    'score': 1.0 - dist,
                    'chunk_index': chunk_index
                })
            # Process each document's chunks
            for doc_id, chunks in doc_chunks.items():
                # Sort chunks by index
                chunks.sort(key=lambda x: x['chunk_index'])
                # Find runs of consecutive chunks with good scores
                good_sequences = []
                current_sequence = []
                for chunk in chunks:
                    if chunk['score'] > 0.3:  # Adjust threshold as needed
                        if not current_sequence or \
                                chunk['chunk_index'] == current_sequence[-1]['chunk_index'] + 1:
                            current_sequence.append(chunk)
                        else:
                            if current_sequence:
                                good_sequences.append(current_sequence)
                            current_sequence = [chunk]
                    else:
                        if current_sequence:
                            good_sequences.append(current_sequence)
                        current_sequence = []
                if current_sequence:
                    good_sequences.append(current_sequence)
                # Merge each sequence into a single result scored by the
                # average of its chunk scores
                for sequence in good_sequences:
                    avg_score = sum(c['score'] for c in sequence) / len(sequence)
                    combined_text = ' '.join(c['text'] for c in sequence)
                    formatted_results.append({
                        'text': combined_text,
                        'metadata': sequence[0]['metadata'],
                        'score': avg_score
                    })
            # Sort by score and return top_k
            formatted_results.sort(key=lambda x: x['score'], reverse=True)
            return formatted_results[:top_k]
        except Exception as e:
            logging.error(f"Error in similarity search: {str(e)}")
            raise

    def get_all_documents(
        self,
        include_embeddings: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Retrieve all documents from the vector store.

        Args:
            include_embeddings (bool): Whether to include embeddings in the response

        Returns:
            List[Dict[str, Any]]: List of documents with their IDs and optionally embeddings
        """
        try:
            include = ["documents", "metadatas"]
            if include_embeddings:
                include.append("embeddings")
            results = self.collection.get(
                include=include
            )
            if not results or 'documents' not in results:
                return []
            documents = []
            for i in range(len(results['documents'])):
                doc = {
                    'id': str(i),  # Fallback: positional ID
                    'text': results['documents'][i],
                }
                if include_embeddings and results.get('embeddings') is not None:
                    doc['embedding'] = results['embeddings'][i]
                if results.get('metadatas') is not None and results['metadatas'][i]:
                    doc['metadata'] = results['metadatas'][i]
                    # Prefer the document_id stored in metadata if available
                    if 'document_id' in results['metadatas'][i]:
                        doc['id'] = results['metadatas'][i]['document_id']
                documents.append(doc)
            return documents
        except Exception as e:
            logging.error(
                f"Error retrieving documents from ChromaDB: {str(e)}")
            raise

    def get_document_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all chunks for a specific document.

        Args:
            document_id (str): ID of the document to retrieve chunks for

        Returns:
            List[Dict[str, Any]]: List of document chunks with their metadata,
            sorted by chunk_index
        """
        try:
            results = self.collection.get(
                where={"document_id": document_id},
                include=["documents", "metadatas"]
            )
            if not results or 'documents' not in results:
                return []
            chunks = []
            for i in range(len(results['documents'])):
                chunk = {
                    'text': results['documents'][i],
                    'metadata': results['metadatas'][i] if results.get('metadatas') else None
                }
                chunks.append(chunk)
            # Sort by chunk_index; guard against chunks whose metadata is None
            chunks.sort(key=lambda x: (x['metadata'] or {}).get('chunk_index', 0))
            return chunks
        except Exception as e:
            logging.error(f"Error retrieving document chunks: {str(e)}")
            raise

    def delete_document(self, document_id: str) -> None:
        """
        Delete all chunks associated with a document_id.

        Args:
            document_id (str): ID of the document to delete
        """
        try:
            # Fetch all chunks stored under the given document_id
            results = self.collection.get(
                where={"document_id": document_id},
                include=["metadatas"]
            )
            if not results or not results.get('ids'):
                logging.warning(f"No document found with ID: {document_id}")
                return
            # Delete by the IDs Chroma actually returned rather than
            # reconstructing them from a naming convention
            self.collection.delete(ids=results['ids'])
        except Exception as e:
            logging.error(
                f"Error deleting document {document_id} from ChromaDB: {str(e)}")
            raise
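

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; writes to ./chroma_db). The
# embedding function below is a placeholder stand-in; a real deployment
# would wrap an embedding model or API instead. The metadata keys mirror
# the document_id/chunk_index/source fields used by the methods above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def fake_embed(texts: List[str]) -> List[List[float]]:
        # Placeholder: deterministic 8-dim pseudo-embeddings derived from hash()
        return [[((hash(t) >> i) % 10) / 10.0 for i in range(8)] for t in texts]

    store = ChromaVectorStore(embedding_function=fake_embed)
    store.add_documents(
        documents=["First chunk of the doc.", "Second chunk of the doc."],
        metadatas=[
            {"document_id": "doc-1", "chunk_index": 0, "source": "google_drive"},
            {"document_id": "doc-1", "chunk_index": 1, "source": "google_drive"},
        ],
        ids=["doc-1-chunk-0", "doc-1-chunk-1"],
    )
    # Consecutive chunks of doc-1 that clear the score threshold are merged
    # into a single result by similarity_search
    for hit in store.similarity_search(fake_embed(["first chunk"])[0], top_k=2):
        print(f"{hit['score']:.3f}  {hit['text'][:60]}")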