advisor / test_embeddings.py
veerukhannan's picture
Create test_embeddings.py
3120cc0 verified
raw
history blame
1.98 kB
import logging
from app import LegalTextSearchBot
import numpy as np
# Configure root logging at INFO level so the progress messages below are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the test and its error path.
logger = logging.getLogger(__name__)
def test_embeddings(expected_dim: int = 1024) -> bool:
    """Smoke-test embedding generation and vector search on a few sample queries.

    For each built-in sample query this creates an embedding through
    ``LegalTextSearchBot.get_embedding``, checks its length, logs the mean and
    standard deviation of its values, and then calls
    ``LegalTextSearchBot._search_astra`` with the same query. An empty search
    result is only logged as a warning; it does not fail the test.

    Args:
        expected_dim: Required length of every embedding. Defaults to 1024,
            the value this test has always checked for.

    Returns:
        True when every query passes the length check without any exception
        being raised; False otherwise. Failures are logged with a traceback
        rather than propagated.
    """
    try:
        logger.info("Initializing LegalTextSearchBot...")
        bot = LegalTextSearchBot()

        # Test queries
        test_queries = [
            "What are the penalties for corruption?",
            "Explain criminal conspiracy",
            "What constitutes culpable homicide?",
        ]

        for query in test_queries:
            logger.info(f"\nTesting query: {query}")

            # Generate embedding
            logger.info("Generating embedding...")
            embedding = bot.get_embedding(query)

            # Verify embedding length. Raised explicitly instead of using an
            # `assert` statement so the check still runs under `python -O`.
            dimension = len(embedding)
            logger.info(f"Embedding dimension: {dimension}")
            if dimension != expected_dim:
                raise AssertionError(
                    f"Embedding dimension should be {expected_dim}, got {dimension}"
                )

            # Log summary statistics of the embedding values
            embedding_array = np.array(embedding)
            logger.info(
                f"Embedding stats - Mean: {embedding_array.mean():.4f}, "
                f"Std: {embedding_array.std():.4f}"
            )

            # Test search
            logger.info("Testing vector search...")
            results = bot._search_astra(query)

            if results:
                logger.info(f"Successfully retrieved {len(results)} results")
                # Log first result title
                # NOTE(review): assumes each result is dict-like with an
                # optional 'title' key -- confirm against _search_astra.
                logger.info(f"First result: {results[0].get('title', 'No title')}")
            else:
                logger.warning("No results found")

        return True

    except Exception as e:
        # Top-level boundary of the test: report the failure with its
        # traceback and signal it through the return value.
        logger.exception("Test failed: %s", e)
        return False
if __name__ == "__main__":
    # Script entry point: run the embedding test, print a summary, and report
    # the outcome through the process exit status.
    print("\n=== Starting Embedding Tests ===\n")
    success = test_embeddings()

    if success:
        print("\n✅ All embedding tests passed!")
    else:
        print("\n❌ Embedding tests failed!")

    # Exit non-zero on failure so shells and CI jobs can detect it.
    raise SystemExit(0 if success else 1)