veerukhannan commited on
Commit
3120cc0
·
verified ·
1 Parent(s): 693be4f

Create test_embeddings.py

Browse files
Files changed (1) hide show
  1. test_embeddings.py +58 -0
test_embeddings.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from app import LegalTextSearchBot
3
+ import numpy as np
4
+
5
+ logging.basicConfig(level=logging.INFO)
6
+ logger = logging.getLogger(__name__)
7
+
8
+ def test_embeddings():
9
+ try:
10
+ logger.info("Initializing LegalTextSearchBot...")
11
+ bot = LegalTextSearchBot()
12
+
13
+ # Test queries
14
+ test_queries = [
15
+ "What are the penalties for corruption?",
16
+ "Explain criminal conspiracy",
17
+ "What constitutes culpable homicide?"
18
+ ]
19
+
20
+ for query in test_queries:
21
+ logger.info(f"\nTesting query: {query}")
22
+
23
+ # Generate embedding
24
+ logger.info("Generating embedding...")
25
+ embedding = bot.get_embedding(query)
26
+
27
+ # Verify embedding
28
+ logger.info(f"Embedding dimension: {len(embedding)}")
29
+ assert len(embedding) == 1024, f"Embedding dimension should be 1024, got {len(embedding)}"
30
+
31
+ # Verify embedding values
32
+ embedding_array = np.array(embedding)
33
+ logger.info(f"Embedding stats - Mean: {embedding_array.mean():.4f}, Std: {embedding_array.std():.4f}")
34
+
35
+ # Test search
36
+ logger.info("Testing vector search...")
37
+ results = bot._search_astra(query)
38
+
39
+ if results:
40
+ logger.info(f"Successfully retrieved {len(results)} results")
41
+ # Print first result title
42
+ logger.info(f"First result: {results[0].get('title', 'No title')}")
43
+ else:
44
+ logger.warning("No results found")
45
+
46
+ return True
47
+
48
+ except Exception as e:
49
+ logger.error(f"Test failed: {str(e)}")
50
+ return False
51
+
52
+ if __name__ == "__main__":
53
+ print("\n=== Starting Embedding Tests ===\n")
54
+ success = test_embeddings()
55
+ if success:
56
+ print("\n✅ All embedding tests passed!")
57
+ else:
58
+ print("\n❌ Embedding tests failed!")