# advisor / test_embeddings.py
# veerukhannan's picture
# Update test_embeddings.py
# 86b8124 verified
# raw
# history blame
# 6.13 kB
import logging
import os
from app import LegalTextSearchBot
import numpy as np
from dotenv import load_dotenv
import time
import gradio as gr
# Set up root logging once at import time: INFO and above, with a
# "timestamp - level - message" layout, then grab this module's logger.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class TestResults:
    """Collect the outcome of individual checks and render them as Markdown.

    Each recorded entry is a dict with the keys ``'test_name'``,
    ``'status'``, ``'message'`` and ``'timestamp'``.
    """

    # The class name starts with "Test" and this module is named test_*.py,
    # so pytest would otherwise try to collect it as a test class.
    __test__ = False

    def __init__(self):
        # Entries are kept in insertion order; the report lists them in
        # the order the checks were recorded.
        self.results = []

    def add_result(self, test_name, status, message):
        """Record one check outcome, stamped with the current local time.

        Args:
            test_name: Heading used for this check in the report.
            status: Truthy when the check passed, falsy when it failed.
            message: Free-form details shown under the heading.
        """
        self.results.append({
            'test_name': test_name,
            'status': status,
            'message': message,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        })

    def get_markdown_report(self):
        """Return all recorded outcomes as a single Markdown string.

        The report starts with a ``# Test Results`` title followed by one
        section per recorded check (heading, status, time, details).
        """
        report = ["# Test Results\n"]
        for result in self.results:
            # The pass marker was previously stored as mojibake ("βœ…",
            # the UTF-8 bytes of the check mark decoded with the wrong
            # codec); use the intended character.
            status_emoji = "✅" if result['status'] else "❌"
            report.append(f"## {status_emoji} {result['test_name']}")
            report.append(f"Status: {status_emoji} {'Passed' if result['status'] else 'Failed'}")
            report.append(f"Time: {result['timestamp']}")
            report.append(f"Details: {result['message']}\n")
        return "\n".join(report)
def _check_environment(test_results):
    """Record an "Environment Check" result for the required variables.

    Calls ``load_dotenv()`` and then looks each required name up with
    ``os.getenv``; a variable that is unset or empty counts as missing.
    """
    try:
        load_dotenv()
        required_vars = [
            "ASTRA_DB_APPLICATION_TOKEN",
            "ASTRA_DB_API_ENDPOINT",
            "ASTRA_DB_COLLECTION",
            "HUGGINGFACE_API_TOKEN"
        ]
        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            test_results.add_result(
                "Environment Check",
                False,
                f"Missing environment variables: {missing_vars}"
            )
        else:
            test_results.add_result(
                "Environment Check",
                True,
                "All environment variables present"
            )
    except Exception as e:
        # Keep the traceback in the log; the report only carries str(e).
        logger.exception("Environment check raised an error")
        test_results.add_result(
            "Environment Check",
            False,
            f"Error checking environment: {str(e)}"
        )


def _run_bot_checks(test_results, progress):
    """Run bot initialization, embedding and search checks in sequence.

    The three stages share one ``try`` because the later ones need the
    bot instance created by the first. Any exception stops the remaining
    stages and is recorded as a single failed "Bot Tests" result.
    """
    # Tracks the stage in progress so the log names where a failure happened.
    stage = "bot initialization"
    try:
        bot = LegalTextSearchBot()
        test_results.add_result(
            "Bot Initialization",
            True,
            "Successfully initialized LegalTextSearchBot"
        )

        # Test 3: Embedding Generation
        stage = "embedding generation"
        progress(0.5, desc="Testing embedding generation...")
        test_queries = [
            "What are the penalties for corruption?",
            "Explain criminal conspiracy",
            "What constitutes culpable homicide?"
        ]
        embedding_results = []
        for query in test_queries:
            # get_embedding is assumed to return a numeric sequence
            # (it is passed to len() and np.array) — confirm against app.py.
            embedding = bot.get_embedding(query)
            embedding_array = np.array(embedding)
            embedding_results.append({
                'query': query,
                'dimension': len(embedding),
                'mean': embedding_array.mean(),
                'std': embedding_array.std()
            })
        test_results.add_result(
            "Embedding Generation",
            True,
            f"Generated embeddings for {len(test_queries)} queries\n" +
            "\n".join([f"Query: {r['query'][:50]}...\n"
                       f"Dimension: {r['dimension']}\n"
                       f"Mean: {r['mean']:.4f}, Std: {r['std']:.4f}\n"
                       for r in embedding_results])
        )

        # Test 4: Search Functionality
        stage = "search"
        progress(0.7, desc="Testing search functionality...")
        search_results = []
        for query in test_queries:
            # Wall-clock timing of a single search call.
            start_time = time.time()
            results = bot._search_astra(query)
            elapsed_time = time.time() - start_time
            search_results.append({
                'query': query,
                'num_results': len(results),
                'time': elapsed_time
            })
        test_results.add_result(
            "Search Functionality",
            True,
            f"Completed searches for {len(test_queries)} queries\n" +
            "\n".join([f"Query: {r['query'][:50]}...\n"
                       f"Results found: {r['num_results']}\n"
                       f"Search time: {r['time']:.2f}s\n"
                       for r in search_results])
        )
    except Exception as e:
        # Previously the traceback was discarded; log it with the stage
        # so failures can be diagnosed from the server log.
        logger.exception("Bot tests failed during %s", stage)
        test_results.add_result(
            "Bot Tests",
            False,
            f"Error during bot tests: {str(e)}"
        )


def run_tests(progress=gr.Progress()):
    """Run every system check and return the results as Markdown.

    Args:
        progress: Callable invoked as ``progress(fraction, desc=...)`` to
            report progress. The ``gr.Progress()`` default is how Gradio
            event handlers ask for a progress tracker.

    Returns:
        The Markdown report from ``TestResults.get_markdown_report()``, or
        a short "Test Suite Failed" Markdown message if something outside
        the individual checks raises.
    """
    test_results = TestResults()
    try:
        progress(0, desc="Starting tests...")

        # Test 1: Environment Variables
        progress(0.1, desc="Checking environment variables...")
        _check_environment(test_results)

        # Tests 2-4: Bot Initialization, Embedding Generation, Search.
        # These run even when the environment check failed, as before.
        progress(0.3, desc="Testing bot initialization...")
        _run_bot_checks(test_results, progress)

        progress(1.0, desc="Tests completed!")
        return test_results.get_markdown_report()
    except Exception as e:
        logger.exception("Test suite aborted unexpectedly")
        return f"# ❌ Test Suite Failed\n\nError: {str(e)}"
def create_test_interface():
    """Build the Gradio Blocks UI for the test suite.

    The layout is an introductory Markdown panel, a "Run Tests" button,
    and a Markdown output area. Clicking the button calls ``run_tests``
    and shows its return value in the output area.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for
        launching it.
    """
    # The emoji in the heading and button label were previously stored as
    # mojibake ("πŸ§ͺ", "πŸš€"); the intended characters are restored here.
    # The text is assembled line by line so its content does not depend
    # on source indentation.
    intro_markdown = (
        "\n"
        "# 🧪 Legal Search System Test Suite\n"
        "This interface runs comprehensive tests on the legal search system components:\n"
        "1. Environment Configuration\n"
        "2. Bot Initialization\n"
        "3. Embedding Generation\n"
        "4. Search Functionality\n"
    )
    with gr.Blocks(title="Legal Search System Tests") as iface:
        gr.Markdown(intro_markdown)
        with gr.Row():
            run_button = gr.Button("🚀 Run Tests", variant="primary")
        with gr.Row():
            output = gr.Markdown("Click 'Run Tests' to start testing...")
        # Event listeners are registered while the Blocks context is open.
        run_button.click(
            fn=run_tests,
            outputs=output
        )
    return iface
# Script entry point: build the test UI and start serving it.
if __name__ == "__main__":
    demo = create_test_interface()
    demo.launch()