# advisor / test_embeddings.py
# veerukhannan's picture
# Create test_embeddings.py
# 6a1ad16 verified
# raw
# history blame
# 2.64 kB
import os
import chromadb
from sentence_transformers import SentenceTransformer
from loguru import logger
class SentenceTransformerEmbeddings:
    """Callable wrapper that embeds texts with a SentenceTransformer model.

    An instance is handed to ChromaDB as the ``embedding_function`` of a
    collection; calling it turns a batch of strings into plain Python lists.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the sentence-transformers model identified by *model_name*."""
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Encode *input* and return the result converted via ``tolist()``.

        NOTE(review): the parameter is named ``input`` (shadowing the builtin)
        presumably because the consuming chromadb interface expects that
        name -- confirm before renaming.
        """
        return self.model.encode(input).tolist()
def test_chromadb_content():
    """Test if ChromaDB has the required content.

    Looks for a ``chroma_db`` directory next to this file, opens it as a
    persistent ChromaDB store, and checks that the ``legal_documents``
    collection exists, is non-empty, and returns at least one hit for a
    sample query. A preview of each hit is logged.

    Returns:
        bool: True when every check passes, False otherwise. Problems are
        logged through ``logger`` instead of being raised.
    """
    collection_name = "legal_documents"
    try:
        # Set up ChromaDB path
        base_path = os.path.dirname(os.path.abspath(__file__))
        chroma_path = os.path.join(base_path, 'chroma_db')
        # isdir (not exists): a regular file with this name is not a store.
        if not os.path.isdir(chroma_path):
            logger.error(f"ChromaDB directory not found at {chroma_path}")
            return False

        # Initialize ChromaDB
        chroma_client = chromadb.PersistentClient(path=chroma_path)

        # Check if collection exists. Depending on the chromadb release,
        # list_collections() yields Collection objects or bare name strings,
        # so fall back to the item itself when it has no ``name`` attribute.
        collections = chroma_client.list_collections()
        if not any(getattr(col, "name", col) == collection_name for col in collections):
            logger.error("Legal documents collection not found in ChromaDB")
            return False

        # Get collection
        collection = chroma_client.get_collection(
            name=collection_name,
            embedding_function=SentenceTransformerEmbeddings()
        )

        # Check collection size
        count = collection.count()
        if count == 0:
            logger.error("Collection is empty")
            return False
        logger.info(f"Found {count} documents in ChromaDB")

        # Test query to verify content
        test_results = collection.query(
            query_texts=["What are the general provisions?"],
            n_results=1
        )

        # The result holds one inner list per query text, so the outer list
        # is truthy even when nothing matched; look at the first query's hits.
        doc_hits = (test_results.get('documents') or [[]])[0] or []
        meta_hits = (test_results.get('metadatas') or [[]])[0] or []
        if not doc_hits:
            logger.error("Test query returned no results")
            return False

        # Print sample content
        logger.info("Sample content from ChromaDB:")
        for position, doc in enumerate(doc_hits, start=1):
            # Metadata may be absent or lack a title; that must not turn a
            # successful verification into a failure while merely logging.
            metadata = meta_hits[position - 1] if position <= len(meta_hits) else None
            title = (metadata or {}).get('title', '<untitled>')
            logger.info(f"\nDocument {position}:")
            logger.info(f"Title: {title}")
            logger.info(f"Content preview: {(doc or '')[:200]}...")
        return True
    except Exception as e:
        # Top-level boundary of the check: report any failure as False.
        logger.error(f"Error testing ChromaDB: {e}")
        return False
if __name__ == "__main__":
    # Run the verification when executed as a script and report the outcome.
    success = test_chromadb_content()
    if success:
        print("ChromaDB content verification successful")
    else:
        print("ChromaDB content verification failed")
    # Propagate the result as the process exit status so shell scripts and CI
    # can detect a failed verification (the script used to exit 0 regardless).
    raise SystemExit(0 if success else 1)