veerukhannan committed on
Commit
19f6421
·
verified ·
1 Parent(s): 065dd0b

Update test_embeddings.py

Browse files
Files changed (1) hide show
  1. test_embeddings.py +131 -28
test_embeddings.py CHANGED
@@ -1,58 +1,161 @@
1
  import logging
 
2
  from app import LegalTextSearchBot
3
  import numpy as np
 
 
 
4
 
5
- logging.basicConfig(level=logging.INFO)
 
 
 
 
 
 
 
 
6
  logger = logging.getLogger(__name__)
7
 
8
- def test_embeddings():
 
9
  try:
10
- logger.info("Initializing LegalTextSearchBot...")
11
- bot = LegalTextSearchBot()
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Test queries
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  test_queries = [
15
  "What are the penalties for corruption?",
16
  "Explain criminal conspiracy",
17
- "What constitutes culpable homicide?"
 
 
 
 
 
18
  ]
19
 
20
- for query in test_queries:
21
- logger.info(f"\nTesting query: {query}")
22
-
23
- # Generate embedding
24
- logger.info("Generating embedding...")
25
  embedding = bot.get_embedding(query)
26
 
27
- # Verify embedding
28
- logger.info(f"Embedding dimension: {len(embedding)}")
29
- assert len(embedding) == 1024, f"Embedding dimension should be 1024, got {len(embedding)}"
30
 
31
  # Verify embedding values
32
  embedding_array = np.array(embedding)
33
- logger.info(f"Embedding stats - Mean: {embedding_array.mean():.4f}, Std: {embedding_array.std():.4f}")
 
 
 
 
 
34
 
35
- # Test search
36
- logger.info("Testing vector search...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  results = bot._search_astra(query)
38
 
 
 
 
 
 
 
39
  if results:
40
- logger.info(f"Successfully retrieved {len(results)} results")
41
- # Print first result title
42
- logger.info(f"First result: {results[0].get('title', 'No title')}")
43
- else:
44
- logger.warning("No results found")
45
-
 
 
46
  return True
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  except Exception as e:
49
- logger.error(f"Test failed: {str(e)}")
50
  return False
51
 
52
  if __name__ == "__main__":
53
- print("\n=== Starting Embedding Tests ===\n")
54
- success = test_embeddings()
55
  if success:
56
- print("\n✅ All embedding tests passed!")
57
  else:
58
- print("\n❌ Embedding tests failed!")
 
1
  import logging
2
+ import os
3
  from app import LegalTextSearchBot
4
  import numpy as np
5
+ from dotenv import load_dotenv
6
+ import time
7
+ from tqdm import tqdm
8
 
9
+ # Configure logging
10
+ logging.basicConfig(
11
+ level=logging.INFO,
12
+ format='%(asctime)s - %(levelname)s - %(message)s',
13
+ handlers=[
14
+ logging.StreamHandler(),
15
+ logging.FileHandler('embedding_tests.log')
16
+ ]
17
+ )
18
  logger = logging.getLogger(__name__)
19
 
20
+ def test_environment():
21
+ """Test environment variables and connections"""
22
  try:
23
+ load_dotenv()
24
+ required_vars = [
25
+ "ASTRA_DB_APPLICATION_TOKEN",
26
+ "ASTRA_DB_API_ENDPOINT",
27
+ "ASTRA_DB_COLLECTION",
28
+ "HUGGINGFACE_API_TOKEN"
29
+ ]
30
+
31
+ missing_vars = [var for var in required_vars if not os.getenv(var)]
32
+ if missing_vars:
33
+ logger.error(f"Missing environment variables: {missing_vars}")
34
+ return False
35
+
36
+ logger.info("✅ Environment variables verified")
37
+ return True
38
 
39
+ except Exception as e:
40
+ logger.error(f"Environment test failed: {str(e)}")
41
+ return False
42
+
43
+ def test_bot_initialization():
44
+ """Test LegalTextSearchBot initialization"""
45
+ try:
46
+ bot = LegalTextSearchBot()
47
+ logger.info("✅ Bot initialization successful")
48
+ return bot
49
+ except Exception as e:
50
+ logger.error(f"Bot initialization failed: {str(e)}")
51
+ return None
52
+
53
+ def test_embedding_generation(bot):
54
+ """Test embedding generation"""
55
+ try:
56
  test_queries = [
57
  "What are the penalties for corruption?",
58
  "Explain criminal conspiracy",
59
+ "What constitutes culpable homicide?",
60
+ "", # Test empty string
61
+ " ", # Test whitespace
62
+ "a" * 1000, # Test long string
63
+ "Section 123 of IPC", # Test with numbers
64
+ "धारा 123", # Test with non-English
65
  ]
66
 
67
+ logger.info("Testing embedding generation...")
68
+ for query in tqdm(test_queries, desc="Testing queries"):
 
 
 
69
  embedding = bot.get_embedding(query)
70
 
71
+ # Verify embedding dimension
72
+ assert len(embedding) == 1024, f"Wrong embedding dimension: {len(embedding)}"
 
73
 
74
  # Verify embedding values
75
  embedding_array = np.array(embedding)
76
+ assert not np.isnan(embedding_array).any(), "Embedding contains NaN values"
77
+ assert not np.isinf(embedding_array).any(), "Embedding contains infinite values"
78
+
79
+ # Log embedding statistics
80
+ logger.debug(f"Query: {query[:50]}...")
81
+ logger.debug(f"Embedding stats - Mean: {embedding_array.mean():.4f}, Std: {embedding_array.std():.4f}")
82
 
83
+ logger.info("✅ Embedding generation tests passed")
84
+ return True
85
+
86
+ except Exception as e:
87
+ logger.error(f"Embedding generation test failed: {str(e)}")
88
+ return False
89
+
90
+ def test_search_functionality(bot):
91
+ """Test search functionality"""
92
+ try:
93
+ test_queries = [
94
+ "What are the penalties for corruption?",
95
+ "Explain criminal conspiracy",
96
+ "What constitutes culpable homicide?"
97
+ ]
98
+
99
+ logger.info("Testing search functionality...")
100
+ for query in tqdm(test_queries, desc="Testing searches"):
101
+ start_time = time.time()
102
+
103
+ # Test vector search
104
  results = bot._search_astra(query)
105
 
106
+ # Log search performance
107
+ elapsed_time = time.time() - start_time
108
+ logger.info(f"Search time for '{query[:50]}...': {elapsed_time:.2f}s")
109
+
110
+ # Verify results
111
+ assert isinstance(results, list), "Search results should be a list"
112
  if results:
113
+ logger.info(f"Found {len(results)} results for '{query[:50]}...'")
114
+ # Verify result structure
115
+ first_result = results[0]
116
+ required_fields = ["section_number", "title", "content"]
117
+ for field in required_fields:
118
+ assert field in first_result, f"Missing required field: {field}"
119
+
120
+ logger.info("✅ Search functionality tests passed")
121
  return True
122
+
123
+ except Exception as e:
124
+ logger.error(f"Search functionality test failed: {str(e)}")
125
+ return False
126
+
127
+ def run_all_tests():
128
+ """Run all tests"""
129
+ try:
130
+ logger.info("\n=== Starting Comprehensive Tests ===\n")
131
+
132
+ # Test 1: Environment
133
+ if not test_environment():
134
+ return False
135
 
136
+ # Test 2: Bot Initialization
137
+ bot = test_bot_initialization()
138
+ if not bot:
139
+ return False
140
+
141
+ # Test 3: Embedding Generation
142
+ if not test_embedding_generation(bot):
143
+ return False
144
+
145
+ # Test 4: Search Functionality
146
+ if not test_search_functionality(bot):
147
+ return False
148
+
149
+ logger.info("\n=== All Tests Completed Successfully ===\n")
150
+ return True
151
+
152
  except Exception as e:
153
+ logger.error(f"Test suite failed: {str(e)}")
154
  return False
155
 
156
  if __name__ == "__main__":
157
+ success = run_all_tests()
 
158
  if success:
159
+ print("\n✅ All tests passed successfully!")
160
  else:
161
+ print("\n❌ Some tests failed. Check the logs for details.")