File size: 6,133 Bytes
3120cc0
19f6421
3120cc0
 
19f6421
 
86b8124
3120cc0
19f6421
 
 
86b8124
19f6421
3120cc0
 
86b8124
 
 
19f6421
86b8124
 
 
 
 
 
 
3120cc0
86b8124
 
 
 
 
 
 
 
 
19f6421
86b8124
 
 
19f6421
86b8124
3120cc0
86b8124
 
 
 
 
 
 
 
 
 
3120cc0
86b8124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19f6421
86b8124
 
 
 
 
 
 
3120cc0
86b8124
 
 
 
 
 
 
 
 
 
 
19f6421
86b8124
 
 
 
 
 
 
 
 
3120cc0
86b8124
 
 
 
 
 
 
 
 
 
 
 
 
19f6421
86b8124
 
 
 
 
 
 
 
 
19f6421
86b8124
 
 
 
 
 
 
 
 
19f6421
 
86b8124
19f6421
86b8124
 
 
 
19f6421
86b8124
 
 
 
 
 
 
 
 
19f6421
86b8124
 
19f6421
86b8124
 
 
 
 
 
3120cc0
 
86b8124
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import logging
import os
from app import LegalTextSearchBot
import numpy as np
from dotenv import load_dotenv
import time
import gradio as gr

# Module-wide logging setup: INFO and above, each record prefixed with
# its timestamp and level name. `logger` is this module's named logger.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class TestResults:
    """Collect pass/fail outcomes of the diagnostic checks and render them as Markdown."""

    # Despite the ``Test`` prefix this is not a pytest test class; tell pytest
    # not to try collecting it (it would otherwise warn about ``__init__``).
    __test__ = False

    def __init__(self):
        # Each entry is a dict with keys 'test_name', 'status' (truthy = passed),
        # 'message' and 'timestamp', appended in the order results are added.
        self.results = []

    def add_result(self, test_name, status, message):
        """Record one check outcome, stamped with the current local time.

        Args:
            test_name: Heading used for this check in the report.
            status: Truthy if the check passed, falsy if it failed.
            message: Free-form details shown under the heading.
        """
        self.results.append({
            'test_name': test_name,
            'status': status,
            'message': message,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        })

    def get_markdown_report(self):
        """Return every recorded outcome as a Markdown document, in insertion order.

        Returns:
            A string starting with a ``# Test Results`` title followed by one
            ``##`` section per recorded result.
        """
        report = ["# Test Results\n"]
        for result in self.results:
            passed = bool(result['status'])
            status_emoji = "✅" if passed else "❌"
            report.append(f"## {status_emoji} {result['test_name']}")
            report.append(f"Status: {status_emoji} {'Passed' if passed else 'Failed'}")
            report.append(f"Time: {result['timestamp']}")
            report.append(f"Details: {result['message']}\n")
        return "\n".join(report)

def _shorten(text, limit=50):
    """Return *text* cut to *limit* characters, appending "..." only if it was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def _check_environment(test_results):
    """Load ``.env`` and record whether every required variable is set and non-empty."""
    required_vars = [
        "ASTRA_DB_APPLICATION_TOKEN",
        "ASTRA_DB_API_ENDPOINT",
        "ASTRA_DB_COLLECTION",
        "HUGGINGFACE_API_TOKEN",
    ]
    try:
        load_dotenv()
        missing_vars = [var for var in required_vars if not os.getenv(var)]
    except Exception as e:
        logger.exception("Environment check failed")
        test_results.add_result(
            "Environment Check",
            False,
            f"Error checking environment: {str(e)}"
        )
        return

    if missing_vars:
        test_results.add_result(
            "Environment Check",
            False,
            f"Missing environment variables: {missing_vars}"
        )
    else:
        test_results.add_result(
            "Environment Check",
            True,
            "All environment variables present"
        )


def _check_embeddings(bot, test_results, queries):
    """Embed each query via ``bot.get_embedding`` and record its dimension, mean and std."""
    try:
        details = []
        for query in queries:
            embedding = bot.get_embedding(query)
            embedding_array = np.array(embedding)
            details.append(
                f"Query: {_shorten(query)}\n"
                f"Dimension: {len(embedding)}\n"
                f"Mean: {embedding_array.mean():.4f}, Std: {embedding_array.std():.4f}\n"
            )
    except Exception as e:
        logger.exception("Embedding generation check failed")
        test_results.add_result(
            "Embedding Generation",
            False,
            f"Error generating embeddings: {str(e)}"
        )
        return

    test_results.add_result(
        "Embedding Generation",
        True,
        f"Generated embeddings for {len(queries)} queries\n" + "\n".join(details)
    )


def _check_search(bot, test_results, queries):
    """Run ``bot._search_astra`` for each query and record the hit count and latency."""
    try:
        details = []
        for query in queries:
            # perf_counter is monotonic and high-resolution; time.time() can jump
            # with wall-clock adjustments and is unsuitable for measuring intervals.
            start_time = time.perf_counter()
            results = bot._search_astra(query)
            elapsed_time = time.perf_counter() - start_time
            details.append(
                f"Query: {_shorten(query)}\n"
                f"Results found: {len(results)}\n"
                f"Search time: {elapsed_time:.2f}s\n"
            )
    except Exception as e:
        logger.exception("Search functionality check failed")
        test_results.add_result(
            "Search Functionality",
            False,
            f"Error during search: {str(e)}"
        )
        return

    test_results.add_result(
        "Search Functionality",
        True,
        f"Completed searches for {len(queries)} queries\n" + "\n".join(details)
    )


def run_tests(progress=gr.Progress()):
    """Run the diagnostic checks in order and return a Markdown report.

    Checks: environment variables, bot initialization, embedding generation,
    and search. Each check records its own pass/fail entry, so a failure in
    one check (e.g. embeddings) does not hide the result of an independent
    one (e.g. search). Embedding and search checks only run if the bot
    initialized successfully.

    Args:
        progress: Gradio progress tracker; the ``gr.Progress()`` default is
            what lets Gradio inject a live progress bar when this is used as
            an event handler.

    Returns:
        Markdown text: the full report, or a short failure banner if the
        suite itself crashes outside the individual checks.
    """
    test_results = TestResults()
    test_queries = [
        "What are the penalties for corruption?",
        "Explain criminal conspiracy",
        "What constitutes culpable homicide?"
    ]

    try:
        progress(0, desc="Starting tests...")

        # Test 1: Environment Variables
        progress(0.1, desc="Checking environment variables...")
        _check_environment(test_results)

        # Test 2: Bot Initialization
        progress(0.3, desc="Testing bot initialization...")
        try:
            bot = LegalTextSearchBot()
        except Exception as e:
            logger.exception("Bot initialization failed")
            test_results.add_result(
                "Bot Initialization",
                False,
                f"Error initializing bot: {str(e)}"
            )
        else:
            test_results.add_result(
                "Bot Initialization",
                True,
                "Successfully initialized LegalTextSearchBot"
            )

            # Test 3: Embedding Generation
            progress(0.5, desc="Testing embedding generation...")
            _check_embeddings(bot, test_results, test_queries)

            # Test 4: Search Functionality (independent of Test 3's outcome)
            progress(0.7, desc="Testing search functionality...")
            _check_search(bot, test_results, test_queries)

        progress(1.0, desc="Tests completed!")
        return test_results.get_markdown_report()

    except Exception as e:
        logger.exception("Test suite failed")
        return f"# ❌ Test Suite Failed\n\nError: {str(e)}"

def create_test_interface():
    """Build the Gradio UI for the test suite.

    The page shows an introduction, a "Run Tests" button and a Markdown output
    area. Clicking the button calls ``run_tests`` and renders the Markdown
    report it returns into the output area.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Legal Search System Tests") as iface:
        gr.Markdown("""
        # 🧪 Legal Search System Test Suite
        
        This interface runs comprehensive tests on the legal search system components:
        1. Environment Configuration
        2. Bot Initialization
        3. Embedding Generation
        4. Search Functionality
        """)

        with gr.Row():
            run_button = gr.Button("🚀 Run Tests", variant="primary")

        with gr.Row():
            output = gr.Markdown("Click 'Run Tests' to start testing...")

        # run_tests returns Markdown text, which replaces the output's content.
        run_button.click(
            fn=run_tests,
            outputs=output
        )

    return iface

if __name__ == "__main__":
    # Script entry point: build the test-suite UI and start the Gradio server.
    # The app is bound to the module-level name `demo` before launching.
    demo = create_test_interface()
    demo.launch()