veerukhannan commited on
Commit
86b8124
·
verified ·
1 Parent(s): 19f6421

Update test_embeddings.py

Browse files
Files changed (1) hide show
  1. test_embeddings.py +147 -128
test_embeddings.py CHANGED
@@ -4,158 +4,177 @@ from app import LegalTextSearchBot
4
  import numpy as np
5
  from dotenv import load_dotenv
6
  import time
7
- from tqdm import tqdm
8
 
9
  # Configure logging
10
  logging.basicConfig(
11
  level=logging.INFO,
12
- format='%(asctime)s - %(levelname)s - %(message)s',
13
- handlers=[
14
- logging.StreamHandler(),
15
- logging.FileHandler('embedding_tests.log')
16
- ]
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
- def test_environment():
21
- """Test environment variables and connections"""
22
- try:
23
- load_dotenv()
24
- required_vars = [
25
- "ASTRA_DB_APPLICATION_TOKEN",
26
- "ASTRA_DB_API_ENDPOINT",
27
- "ASTRA_DB_COLLECTION",
28
- "HUGGINGFACE_API_TOKEN"
29
- ]
30
 
31
- missing_vars = [var for var in required_vars if not os.getenv(var)]
32
- if missing_vars:
33
- logger.error(f"Missing environment variables: {missing_vars}")
34
- return False
35
-
36
- logger.info("✅ Environment variables verified")
37
- return True
38
 
39
- except Exception as e:
40
- logger.error(f"Environment test failed: {str(e)}")
41
- return False
42
-
43
- def test_bot_initialization():
44
- """Test LegalTextSearchBot initialization"""
45
- try:
46
- bot = LegalTextSearchBot()
47
- logger.info("✅ Bot initialization successful")
48
- return bot
49
- except Exception as e:
50
- logger.error(f"Bot initialization failed: {str(e)}")
51
- return None
52
 
53
- def test_embedding_generation(bot):
54
- """Test embedding generation"""
 
55
  try:
56
- test_queries = [
57
- "What are the penalties for corruption?",
58
- "Explain criminal conspiracy",
59
- "What constitutes culpable homicide?",
60
- "", # Test empty string
61
- " ", # Test whitespace
62
- "a" * 1000, # Test long string
63
- "Section 123 of IPC", # Test with numbers
64
- "धारा 123", # Test with non-English
65
- ]
66
 
67
- logger.info("Testing embedding generation...")
68
- for query in tqdm(test_queries, desc="Testing queries"):
69
- embedding = bot.get_embedding(query)
 
 
 
 
 
 
 
70
 
71
- # Verify embedding dimension
72
- assert len(embedding) == 1024, f"Wrong embedding dimension: {len(embedding)}"
73
-
74
- # Verify embedding values
75
- embedding_array = np.array(embedding)
76
- assert not np.isnan(embedding_array).any(), "Embedding contains NaN values"
77
- assert not np.isinf(embedding_array).any(), "Embedding contains infinite values"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # Log embedding statistics
80
- logger.debug(f"Query: {query[:50]}...")
81
- logger.debug(f"Embedding stats - Mean: {embedding_array.mean():.4f}, Std: {embedding_array.std():.4f}")
 
 
 
 
82
 
83
- logger.info("✅ Embedding generation tests passed")
84
- return True
85
-
86
- except Exception as e:
87
- logger.error(f"Embedding generation test failed: {str(e)}")
88
- return False
89
-
90
- def test_search_functionality(bot):
91
- """Test search functionality"""
92
- try:
93
- test_queries = [
94
- "What are the penalties for corruption?",
95
- "Explain criminal conspiracy",
96
- "What constitutes culpable homicide?"
97
- ]
98
-
99
- logger.info("Testing search functionality...")
100
- for query in tqdm(test_queries, desc="Testing searches"):
101
- start_time = time.time()
102
 
103
- # Test vector search
104
- results = bot._search_astra(query)
 
 
 
 
 
 
 
105
 
106
- # Log search performance
107
- elapsed_time = time.time() - start_time
108
- logger.info(f"Search time for '{query[:50]}...': {elapsed_time:.2f}s")
 
 
 
 
 
 
 
 
 
 
109
 
110
- # Verify results
111
- assert isinstance(results, list), "Search results should be a list"
112
- if results:
113
- logger.info(f"Found {len(results)} results for '{query[:50]}...'")
114
- # Verify result structure
115
- first_result = results[0]
116
- required_fields = ["section_number", "title", "content"]
117
- for field in required_fields:
118
- assert field in first_result, f"Missing required field: {field}"
119
 
120
- logger.info("✅ Search functionality tests passed")
121
- return True
 
 
 
 
 
 
 
122
 
123
  except Exception as e:
124
- logger.error(f"Search functionality test failed: {str(e)}")
125
- return False
126
 
127
- def run_all_tests():
128
- """Run all tests"""
129
- try:
130
- logger.info("\n=== Starting Comprehensive Tests ===\n")
131
 
132
- # Test 1: Environment
133
- if not test_environment():
134
- return False
135
-
136
- # Test 2: Bot Initialization
137
- bot = test_bot_initialization()
138
- if not bot:
139
- return False
140
-
141
- # Test 3: Embedding Generation
142
- if not test_embedding_generation(bot):
143
- return False
144
 
145
- # Test 4: Search Functionality
146
- if not test_search_functionality(bot):
147
- return False
148
 
149
- logger.info("\n=== All Tests Completed Successfully ===\n")
150
- return True
151
-
152
- except Exception as e:
153
- logger.error(f"Test suite failed: {str(e)}")
154
- return False
155
 
156
  if __name__ == "__main__":
157
- success = run_all_tests()
158
- if success:
159
- print("\n✅ All tests passed successfully!")
160
- else:
161
- print("\n❌ Some tests failed. Check the logs for details.")
 
4
  import numpy as np
5
  from dotenv import load_dotenv
6
  import time
7
+ import gradio as gr
8
 
9
  # Configure logging
10
  logging.basicConfig(
11
  level=logging.INFO,
12
+ format='%(asctime)s - %(levelname)s - %(message)s'
 
 
 
 
13
  )
14
  logger = logging.getLogger(__name__)
15
 
16
+ class TestResults:
17
+ def __init__(self):
18
+ self.results = []
 
 
 
 
 
 
 
19
 
20
+ def add_result(self, test_name, status, message):
21
+ self.results.append({
22
+ 'test_name': test_name,
23
+ 'status': status,
24
+ 'message': message,
25
+ 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
26
+ })
27
 
28
+ def get_markdown_report(self):
29
+ report = ["# Test Results\n"]
30
+ for result in self.results:
31
+ status_emoji = "✅" if result['status'] else "❌"
32
+ report.append(f"## {status_emoji} {result['test_name']}")
33
+ report.append(f"Status: {status_emoji} {'Passed' if result['status'] else 'Failed'}")
34
+ report.append(f"Time: {result['timestamp']}")
35
+ report.append(f"Details: {result['message']}\n")
36
+ return "\n".join(report)
 
 
 
 
37
 
38
+ def run_tests(progress=gr.Progress()):
39
+ test_results = TestResults()
40
+
41
  try:
42
+ progress(0, desc="Starting tests...")
 
 
 
 
 
 
 
 
 
43
 
44
+ # Test 1: Environment Variables
45
+ progress(0.1, desc="Checking environment variables...")
46
+ try:
47
+ load_dotenv()
48
+ required_vars = [
49
+ "ASTRA_DB_APPLICATION_TOKEN",
50
+ "ASTRA_DB_API_ENDPOINT",
51
+ "ASTRA_DB_COLLECTION",
52
+ "HUGGINGFACE_API_TOKEN"
53
+ ]
54
 
55
+ missing_vars = [var for var in required_vars if not os.getenv(var)]
56
+ if missing_vars:
57
+ test_results.add_result(
58
+ "Environment Check",
59
+ False,
60
+ f"Missing environment variables: {missing_vars}"
61
+ )
62
+ else:
63
+ test_results.add_result(
64
+ "Environment Check",
65
+ True,
66
+ "All environment variables present"
67
+ )
68
+
69
+ except Exception as e:
70
+ test_results.add_result(
71
+ "Environment Check",
72
+ False,
73
+ f"Error checking environment: {str(e)}"
74
+ )
75
+
76
+ # Test 2: Bot Initialization
77
+ progress(0.3, desc="Testing bot initialization...")
78
+ try:
79
+ bot = LegalTextSearchBot()
80
+ test_results.add_result(
81
+ "Bot Initialization",
82
+ True,
83
+ "Successfully initialized LegalTextSearchBot"
84
+ )
85
 
86
+ # Test 3: Embedding Generation
87
+ progress(0.5, desc="Testing embedding generation...")
88
+ test_queries = [
89
+ "What are the penalties for corruption?",
90
+ "Explain criminal conspiracy",
91
+ "What constitutes culpable homicide?"
92
+ ]
93
 
94
+ embedding_results = []
95
+ for query in test_queries:
96
+ embedding = bot.get_embedding(query)
97
+ embedding_array = np.array(embedding)
98
+
99
+ embedding_results.append({
100
+ 'query': query,
101
+ 'dimension': len(embedding),
102
+ 'mean': embedding_array.mean(),
103
+ 'std': embedding_array.std()
104
+ })
 
 
 
 
 
 
 
 
105
 
106
+ test_results.add_result(
107
+ "Embedding Generation",
108
+ True,
109
+ f"Generated embeddings for {len(test_queries)} queries\n" +
110
+ "\n".join([f"Query: {r['query'][:50]}...\n"
111
+ f"Dimension: {r['dimension']}\n"
112
+ f"Mean: {r['mean']:.4f}, Std: {r['std']:.4f}\n"
113
+ for r in embedding_results])
114
+ )
115
 
116
+ # Test 4: Search Functionality
117
+ progress(0.7, desc="Testing search functionality...")
118
+ search_results = []
119
+ for query in test_queries:
120
+ start_time = time.time()
121
+ results = bot._search_astra(query)
122
+ elapsed_time = time.time() - start_time
123
+
124
+ search_results.append({
125
+ 'query': query,
126
+ 'num_results': len(results),
127
+ 'time': elapsed_time
128
+ })
129
 
130
+ test_results.add_result(
131
+ "Search Functionality",
132
+ True,
133
+ f"Completed searches for {len(test_queries)} queries\n" +
134
+ "\n".join([f"Query: {r['query'][:50]}...\n"
135
+ f"Results found: {r['num_results']}\n"
136
+ f"Search time: {r['time']:.2f}s\n"
137
+ for r in search_results])
138
+ )
139
 
140
+ except Exception as e:
141
+ test_results.add_result(
142
+ "Bot Tests",
143
+ False,
144
+ f"Error during bot tests: {str(e)}"
145
+ )
146
+
147
+ progress(1.0, desc="Tests completed!")
148
+ return test_results.get_markdown_report()
149
 
150
  except Exception as e:
151
+ return f"# Test Suite Failed\n\nError: {str(e)}"
 
152
 
153
+ def create_test_interface():
154
+ with gr.Blocks(title="Legal Search System Tests") as iface:
155
+ gr.Markdown("""
156
+ # 🧪 Legal Search System Test Suite
157
 
158
+ This interface runs comprehensive tests on the legal search system components:
159
+ 1. Environment Configuration
160
+ 2. Bot Initialization
161
+ 3. Embedding Generation
162
+ 4. Search Functionality
163
+ """)
164
+
165
+ with gr.Row():
166
+ run_button = gr.Button("🚀 Run Tests", variant="primary")
 
 
 
167
 
168
+ with gr.Row():
169
+ output = gr.Markdown("Click 'Run Tests' to start testing...")
 
170
 
171
+ run_button.click(
172
+ fn=run_tests,
173
+ outputs=output
174
+ )
175
+
176
+ return iface
177
 
178
  if __name__ == "__main__":
179
+ demo = create_test_interface()
180
+ demo.launch()