VitaliyPolovyyEN committed on
Commit 4e72327 · verified · 1 Parent(s): 7c2596c

Update app.py

Files changed (1):
1. app.py +316 -4

app.py CHANGED
@@ -1,7 +1,319 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import time
+import datetime
+from sentence_transformers import SentenceTransformer
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+import traceback
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+import os
+import tempfile
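+
+# Note: newer LangChain releases move RecursiveCharacterTextSplitter into the
+# separate `langchain_text_splitters` package; the `langchain.text_splitter`
+# import above assumes a version where that path still works.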
+
+# Configuration
+EMBEDDING_MODELS = {
+    "sentence-transformers/all-MiniLM-L6-v2": "MiniLM (English)",
+    "ai-forever/FRIDA": "FRIDA (RU-EN)",
+    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "Multilingual MiniLM",
+    "cointegrated/rubert-tiny2": "RuBERT Tiny",
+    "ai-forever/sbert_large_nlu_ru": "Russian SBERT Large"
+}
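+# Keys are Hugging Face Hub model IDs; SentenceTransformer downloads each on
+# first use. Caveat (per the FRIDA model card, worth verifying): FRIDA expects
+# task prefixes such as "search_query: ", which this raw-text comparison does
+# not apply, so its scores may understate its real quality.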
+
+CHUNK_SIZE = 1024
+CHUNK_OVERLAP = 200
+TOP_K_RESULTS = 4
+OUTPUT_FILENAME = "rag_embedding_test_results.txt"
+
+# Global storage
+embeddings_cache = {}
+document_chunks = []
+current_document = ""
+
+def chunk_document(text):
+    """Split document into chunks using RecursiveCharacterTextSplitter"""
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=CHUNK_OVERLAP,
+        length_function=len,
+    )
+    chunks = text_splitter.split_text(text)
+    return [chunk for chunk in chunks if len(chunk.strip()) > 50]
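+
+# Rough expectation for plain prose input: a ~3,000-character document yields
+# about four chunks, with adjacent chunks sharing up to 200 characters.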
+
+def test_single_model(model_name, chunks, question):
+    """Test embedding with a single model"""
+    try:
+        start_time = time.time()
+
+        # Load model
+        model = SentenceTransformer(model_name)
+        load_time = time.time() - start_time
+
+        # Create embeddings
+        embed_start = time.time()
+        chunk_embeddings = model.encode(chunks, show_progress_bar=False)
+        question_embedding = model.encode([question], show_progress_bar=False)
+        embed_time = time.time() - embed_start
+
+        # Calculate similarities
+        similarities = cosine_similarity(question_embedding, chunk_embeddings)[0]
+
+        # Get top-K results (argsort sorts ascending, so take the tail and reverse)
+        top_indices = np.argsort(similarities)[-TOP_K_RESULTS:][::-1]
+
+        total_time = time.time() - start_time
+
+        results = {
+            'status': 'success',
+            'total_time': total_time,
+            'load_time': load_time,
+            'embed_time': embed_time,
+            'top_chunks': [
+                {
+                    'index': idx,
+                    'score': similarities[idx],
+                    'text': chunks[idx]
+                }
+                for idx in top_indices
+            ]
+        }
+
+        return results
+
+    except Exception as e:
+        return {
+            'status': 'failed',
+            'error': str(e),
+            'traceback': traceback.format_exc()
+        }
+
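+# Example result shape from test_single_model (hypothetical values):
+#   {'status': 'success', 'total_time': 3.2, 'load_time': 2.1,
+#    'embed_time': 1.1, 'top_chunks': [{'index': 7, 'score': 0.83,
+#    'text': '...'}, ...]}
+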
+def process_embeddings(document_text, progress=gr.Progress()):
+    """Process document with all embedding models"""
+    global embeddings_cache, document_chunks, current_document
+
+    if not document_text.strip():
+        return "❌ Please provide document text first!"
+
+    current_document = document_text
+
+    # Chunk document
+    progress(0.1, desc="Chunking document...")
+    document_chunks = chunk_document(document_text)
+
+    if not document_chunks:
+        return "❌ No valid chunks created. Please provide longer text."
+
+    embeddings_cache = {}
+    total_models = len(EMBEDDING_MODELS)
+
+    progress(0.2, desc=f"Processing {len(document_chunks)} chunks with {total_models} models...")
+
+    # Register each model; actual embedding happens at query time
+    for i, (model_name, display_name) in enumerate(EMBEDDING_MODELS.items()):
+        progress(0.2 + (0.7 * i / total_models), desc=f"Preparing {display_name}...")
+
+        embeddings_cache[model_name] = {
+            'processed': False,
+            'display_name': display_name
+        }
+
+    progress(1.0, desc="Ready for testing!")
+
+    return f"✅ Document processed successfully!\n\n📊 **Stats:**\n- Total chunks: {len(document_chunks)}\n- Chunk size: {CHUNK_SIZE}\n- Chunk overlap: {CHUNK_OVERLAP}\n- Models ready: {len(EMBEDDING_MODELS)}\n\n🔍 **Now ask a question to compare embedding models!**"
+
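+# Note: process_embeddings only caches the chunk list; per-model embeddings
+# are recomputed on every comparison below.
+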
+def compare_embeddings(question, progress=gr.Progress()):
+    """Compare all models for a given question"""
+    global embeddings_cache, document_chunks
+
+    if not question.strip():
+        return "❌ Please enter a question!", ""
+
+    if not document_chunks:
+        return "❌ Please process a document first using the 'Start Embedding' button!", ""
+
+    results = {}
+    total_models = len(EMBEDDING_MODELS)
+
+    # Test each model
+    for i, (model_name, display_name) in enumerate(EMBEDDING_MODELS.items()):
+        progress(i / total_models, desc=f"Testing {display_name}...")
+
+        result = test_single_model(model_name, document_chunks, question)
+        results[model_name] = result
+        results[model_name]['display_name'] = display_name
+
+    progress(1.0, desc="Comparison complete!")
+
+    # Format results for display
+    display_results = format_comparison_results(results, question)
+
+    # Generate downloadable report
+    report_content = generate_report(results, question)
+
+    return display_results, report_content
+
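+# Each comparison reloads every model via test_single_model, so the first run
+# may be slow while model weights download from the Hub.
+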
+def format_comparison_results(results, question):
+    """Format results for Gradio display"""
+    output = f"# 🔍 Embedding Model Comparison\n\n"
+    output += f"**Question:** {question}\n\n"
+    output += f"**Document chunks:** {len(document_chunks)}\n\n"
+    output += "---\n\n"
+
+    for model_name, result in results.items():
+        display_name = result['display_name']
+        output += f"## 🤖 {display_name}\n\n"
+
+        if result['status'] == 'success':
+            output += f"✅ **Success** ({result['total_time']:.2f}s)\n\n"
+            output += "**Top Results:**\n\n"
+
+            for i, chunk in enumerate(result['top_chunks'], 1):
+                score = chunk['score']
+                text_preview = chunk['text'][:200] + "..." if len(chunk['text']) > 200 else chunk['text']
+                output += f"**{i}. [{score:.3f}]** Chunk #{chunk['index']}\n"
+                output += f"```\n{text_preview}\n```\n\n"
+        else:
+            output += f"❌ **Failed:** {result['error']}\n\n"
+
+        output += "---\n\n"
+
+    return output
+
+def generate_report(results, question):
+    """Generate downloadable text report"""
+    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    report = "==========================================\n"
+    report += "RAG EMBEDDING MODEL TEST RESULTS\n"
+    report += "==========================================\n"
+    report += f"Date: {timestamp}\n"
+    report += f"Question: {question}\n"
+    report += f"Document chunks: {len(document_chunks)}\n\n"
+
+    report += "Settings:\n"
+    report += f"- Chunk Size: {CHUNK_SIZE}\n"
+    report += f"- Chunk Overlap: {CHUNK_OVERLAP}\n"
+    report += f"- Splitter: RecursiveCharacterTextSplitter\n"
+    report += f"- Top-K Results: {TOP_K_RESULTS}\n\n"
+
+    report += "==========================================\n"
+
+    for model_name, result in results.items():
+        display_name = result['display_name']
+        report += f"MODEL: {display_name}\n"
+
+        if result['status'] == 'success':
+            report += f"Status: ✅ Success ({result['total_time']:.2f}s)\n"
+            report += "Top Results:\n"
+
+            for chunk in result['top_chunks']:
+                score = chunk['score']
+                text = chunk['text'].replace('\n', ' ')
+                text_preview = text[:100] + "..." if len(text) > 100 else text
+                report += f"[{score:.3f}] Chunk #{chunk['index']}: \"{text_preview}\"\n"
+        else:
+            report += f"Status: ❌ Failed - {result['error']}\n"
+
+        report += "\n" + "="*40 + "\n"
+
+    return report
+
+def load_file(file):
+    """Load content from an uploaded file"""
+    if file is None:
+        return ""
+
+    try:
+        # Newer Gradio versions pass a file path string; older ones pass a
+        # temp-file object whose .name holds the path
+        path = file if isinstance(file, str) else file.name
+        with open(path, "r", encoding="utf-8") as f:
+            return f.read()
+    except Exception as e:
+        return f"Error loading file: {str(e)}"
+
+# Create Gradio interface
+with gr.Blocks(title="RAG Embedding Model Tester", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🧪 RAG Embedding Model Tester")
+    gr.Markdown("Test and compare different embedding models for RAG pipelines, with a focus on relevance quality.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("## 📄 Document Input")
+
+            document_input = gr.Textbox(
+                lines=15,
+                placeholder="Paste your document text here (Russian or English)...",
+                label="Document Text",
+                max_lines=20
+            )
+
+            file_input = gr.File(
+                file_types=[".txt", ".md"],
+                label="Or Upload Text File"
+            )
+
+            # Load file content into the text box
+            file_input.change(
+                fn=load_file,
+                inputs=file_input,
+                outputs=document_input
+            )
+
+            embed_btn = gr.Button("🚀 Start Embedding Process", variant="primary", size="lg")
+            embed_status = gr.Textbox(label="Processing Status", lines=8)
+
+        with gr.Column(scale=2):
+            gr.Markdown("## ❓ Question & Comparison")
+
+            question_input = gr.Textbox(
+                placeholder="What question do you want to ask about the document?",
+                label="Your Question",
+                lines=2
+            )
+
+            compare_btn = gr.Button("🔍 Compare All Models", variant="secondary", size="lg")
+
+            results_display = gr.Markdown(label="Comparison Results")
+
+            gr.Markdown("## 📥 Download Results")
+            report_download = gr.File(label="Download Test Report")
+
+    # Model info
+    with gr.Row():
+        gr.Markdown(f"""
+        ## 🤖 Models to Test:
+        {', '.join([f"**{name}**" for name in EMBEDDING_MODELS.values()])}
+
+        ## ⚙️ Settings:
+        - **Chunk Size:** {CHUNK_SIZE} characters
+        - **Chunk Overlap:** {CHUNK_OVERLAP} characters
+        - **Top Results:** {TOP_K_RESULTS} chunks per model
+        - **Splitter:** RecursiveCharacterTextSplitter
+        """)
+
+    # Event handlers
+    embed_btn.click(
+        fn=process_embeddings,
+        inputs=document_input,
+        outputs=embed_status
+    )
+
+    def compare_and_download(question):
+        results_text, report_content = compare_embeddings(question)
+
+        # Create downloadable file
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"rag_test_{timestamp}.txt"
+
+        # Write the report to a temp file and hand gr.File a path; this is
+        # more portable across Gradio versions than an in-memory StringIO
+        filepath = os.path.join(tempfile.gettempdir(), filename)
+        with open(filepath, "w", encoding="utf-8") as f:
+            f.write(report_content)
+
+        return results_text, gr.update(value=filepath, visible=True)
+
+    compare_btn.click(
+        fn=compare_and_download,
+        inputs=question_input,
+        outputs=[results_display, report_download]
+    )
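+    # compare_and_download returns (markdown_text, file_update), matching
+    # outputs=[results_display, report_download] above.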
+
+if __name__ == "__main__":
+    demo.launch()