Update app.py
app.py
CHANGED
@@ -1,7 +1,319 @@
 import gradio as gr
-
-
-
-
+import time
+import datetime
+from sentence_transformers import SentenceTransformer
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+import traceback
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+import os
+import tempfile
+
+# Configuration
+EMBEDDING_MODELS = {
+    "sentence-transformers/all-MiniLM-L6-v2": "MiniLM L6 (English)",
+    "ai-forever/FRIDA": "FRIDA (RU-EN)",
+    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "Multilingual MiniLM",
+    "cointegrated/rubert-tiny2": "RuBERT Tiny",
+    "ai-forever/sbert_large_nlu_ru": "Russian SBERT Large"
+}
+
+CHUNK_SIZE = 1024
+CHUNK_OVERLAP = 200
+TOP_K_RESULTS = 4
+OUTPUT_FILENAME = "rag_embedding_test_results.txt"
+
+# Global storage
+embeddings_cache = {}
+document_chunks = []
+current_document = ""
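+# NOTE: module-level state is shared across Gradio callbacks; fine for a
+# single-user demo, but concurrent users would overwrite each other's data.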
+
+def chunk_document(text):
+    """Split document into chunks using RecursiveCharacterTextSplitter"""
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=CHUNK_OVERLAP,
+        length_function=len,
+    )
+    chunks = text_splitter.split_text(text)
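+    # Drop near-empty fragments; chunks under 50 characters rarely carry
+    # enough signal to be useful retrieval candidates.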
+    return [chunk for chunk in chunks if len(chunk.strip()) > 50]
+
+def test_single_model(model_name, chunks, question):
+    """Test embedding with a single model"""
+    try:
+        start_time = time.time()
+
+        # Load model
+        model = SentenceTransformer(model_name)
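+        # First use downloads the weights from the Hugging Face Hub; later
+        # runs are served from the local cache, so load_time drops sharply.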
+        load_time = time.time() - start_time
+
+        # Create embeddings
+        embed_start = time.time()
+        chunk_embeddings = model.encode(chunks, show_progress_bar=False)
+        question_embedding = model.encode([question], show_progress_bar=False)
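+        # NOTE: some models (e.g. ai-forever/FRIDA) are trained with task
+        # prefixes such as "search_query: " / "search_document: "; encoding raw
+        # text as done here may understate their retrieval quality.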
+        embed_time = time.time() - embed_start
+
+        # Calculate similarities
+        similarities = cosine_similarity(question_embedding, chunk_embeddings)[0]
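+        # cosine_similarity returns a (1, n_chunks) matrix; [0] flattens it to
+        # one score per chunk.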
+
+        # Get top K results
+        top_indices = np.argsort(similarities)[-TOP_K_RESULTS:][::-1]
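+        # argsort sorts ascending, so take the last K indices and reverse them
+        # to get the best-scoring chunks first.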
+
+        total_time = time.time() - start_time
+
+        results = {
+            'status': 'success',
+            'total_time': total_time,
+            'load_time': load_time,
+            'embed_time': embed_time,
+            'top_chunks': [
+                {
+                    'index': idx,
+                    'score': similarities[idx],
+                    'text': chunks[idx]
+                }
+                for idx in top_indices
+            ]
+        }
+
+        return results
+
+    except Exception as e:
+        return {
+            'status': 'failed',
+            'error': str(e),
+            'traceback': traceback.format_exc()
+        }
+
+def process_embeddings(document_text, progress=gr.Progress()):
+    """Process document with all embedding models"""
+    global embeddings_cache, document_chunks, current_document
+
+    if not document_text.strip():
+        return "❌ Please provide document text first!"
+
+    current_document = document_text
+
+    # Chunk document
+    progress(0.1, desc="Chunking document...")
+    document_chunks = chunk_document(document_text)
+
+    if not document_chunks:
+        return "❌ No valid chunks created. Please provide longer text."
+
+    embeddings_cache = {}
+    total_models = len(EMBEDDING_MODELS)
+
+    progress(0.2, desc=f"Processing {len(document_chunks)} chunks with {total_models} models...")
+
+    # Register each model; embeddings are computed lazily at query time
+    for i, (model_name, display_name) in enumerate(EMBEDDING_MODELS.items()):
+        progress(0.2 + (0.7 * i / total_models), desc=f"Testing {display_name}...")
+
+        embeddings_cache[model_name] = {
+            'processed': False,
+            'display_name': display_name
+        }
+
+    progress(1.0, desc="Ready for testing!")
+
+    return f"✅ Document processed successfully!\n\n📊 **Stats:**\n- Total chunks: {len(document_chunks)}\n- Chunk size: {CHUNK_SIZE}\n- Chunk overlap: {CHUNK_OVERLAP}\n- Models ready: {len(EMBEDDING_MODELS)}\n\n❓ **Now ask a question to compare embedding models!**"
+
+def compare_embeddings(question, progress=gr.Progress()):
+    """Compare all models for a given question"""
+    global embeddings_cache, document_chunks
+
+    if not question.strip():
+        return "❌ Please enter a question!", ""
+
+    if not document_chunks:
+        return "❌ Please process a document first using the 'Start Embedding' button!", ""
+
+    results = {}
+    total_models = len(EMBEDDING_MODELS)
+
+    # Test each model
+    for i, (model_name, display_name) in enumerate(EMBEDDING_MODELS.items()):
+        progress(i / total_models, desc=f"Testing {display_name}...")
+
+        result = test_single_model(model_name, document_chunks, question)
+        results[model_name] = result
+        results[model_name]['display_name'] = display_name
+
+    progress(1.0, desc="Comparison complete!")
+
+    # Format results for display
+    display_results = format_comparison_results(results, question)
+
+    # Generate downloadable report
+    report_content = generate_report(results, question)
+
+    return display_results, report_content
+
+def format_comparison_results(results, question):
+    """Format results for Gradio display"""
+    output = "# 🔍 Embedding Model Comparison\n\n"
+    output += f"**Question:** {question}\n\n"
+    output += f"**Document chunks:** {len(document_chunks)}\n\n"
+    output += "---\n\n"
+
+    for model_name, result in results.items():
+        display_name = result['display_name']
+        output += f"## 🤖 {display_name}\n\n"
+
+        if result['status'] == 'success':
+            output += f"✅ **Success** ({result['total_time']:.2f}s)\n\n"
+            output += "**Top Results:**\n\n"
+
+            for i, chunk in enumerate(result['top_chunks'], 1):
+                score = chunk['score']
+                text_preview = chunk['text'][:200] + "..." if len(chunk['text']) > 200 else chunk['text']
+                output += f"**{i}. [{score:.3f}]** Chunk #{chunk['index']}\n"
+                output += f"```\n{text_preview}\n```\n\n"
+        else:
+            output += f"❌ **Failed:** {result['error']}\n\n"
+
+        output += "---\n\n"
+
+    return output
+
+def generate_report(results, question):
+    """Generate downloadable text report"""
+    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    report = "==========================================\n"
+    report += "RAG EMBEDDING MODEL TEST RESULTS\n"
+    report += "==========================================\n"
+    report += f"Date: {timestamp}\n"
+    report += f"Question: {question}\n"
+    report += f"Document chunks: {len(document_chunks)}\n\n"
+
+    report += "Settings:\n"
+    report += f"- Chunk Size: {CHUNK_SIZE}\n"
+    report += f"- Chunk Overlap: {CHUNK_OVERLAP}\n"
+    report += "- Splitter: RecursiveCharacterTextSplitter\n"
+    report += f"- Top-K Results: {TOP_K_RESULTS}\n\n"
+
+    report += "==========================================\n"
+
+    for model_name, result in results.items():
+        display_name = result['display_name']
+        report += f"MODEL: {display_name}\n"
+
+        if result['status'] == 'success':
+            report += f"Status: ✅ Success ({result['total_time']:.2f}s)\n"
+            report += "Top Results:\n"
+
+            for chunk in result['top_chunks']:
+                score = chunk['score']
+                text = chunk['text'].replace('\n', ' ')
+                text_preview = text[:100] + "..." if len(text) > 100 else text
+                report += f"[{score:.3f}] Chunk #{chunk['index']}: \"{text_preview}\"\n"
+        else:
+            report += f"Status: ❌ Failed - {result['error']}\n"
+
+        report += "\n" + "="*40 + "\n"
+
+    return report
+
+def load_file(file):
+    """Load content from an uploaded file"""
+    if file is None:
+        return ""
+
+    try:
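+        # Depending on the Gradio version, gr.File hands the callback either a
+        # tempfile wrapper exposing .name or a plain filepath string, so resolve
+        # the path before reading instead of calling file.read() directly.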
+        path = file.name if hasattr(file, "name") else str(file)
+        with open(path, "r", encoding="utf-8") as f:
+            return f.read()
+    except Exception as e:
+        return f"Error loading file: {str(e)}"
+
+# Create Gradio interface
+with gr.Blocks(title="RAG Embedding Model Tester", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🧪 RAG Embedding Model Tester")
+    gr.Markdown("Test and compare different embedding models for RAG pipelines. Focus on relevance quality assessment.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("## 📄 Document Input")
+
+            document_input = gr.Textbox(
+                lines=15,
+                placeholder="Paste your document text here (Russian or English)...",
+                label="Document Text",
+                max_lines=20
+            )
+
+            file_input = gr.File(
+                file_types=[".txt", ".md"],
+                label="Or Upload Text File"
+            )
+
+            # Load file content into the text box
+            file_input.change(
+                fn=load_file,
+                inputs=file_input,
+                outputs=document_input
+            )
+
+            embed_btn = gr.Button("🚀 Start Embedding Process", variant="primary", size="lg")
+            embed_status = gr.Textbox(label="Processing Status", lines=8)
+
+        with gr.Column(scale=2):
+            gr.Markdown("## ❓ Question & Comparison")
+
+            question_input = gr.Textbox(
+                placeholder="What question do you want to ask about the document?",
+                label="Your Question",
+                lines=2
+            )
+
+            compare_btn = gr.Button("🔍 Compare All Models", variant="secondary", size="lg")
+
+            results_display = gr.Markdown(label="Comparison Results")
+
+            gr.Markdown("## 📥 Download Results")
+            report_download = gr.File(label="Download Test Report")
+
+    # Model info
+    with gr.Row():
+        gr.Markdown(f"""
+        ## 🤖 Models to Test:
+        {', '.join([f"**{name}**" for name in EMBEDDING_MODELS.values()])}
+
+        ## ⚙️ Settings:
+        - **Chunk Size:** {CHUNK_SIZE} characters
+        - **Chunk Overlap:** {CHUNK_OVERLAP} characters
+        - **Top Results:** {TOP_K_RESULTS} chunks per model
+        - **Splitter:** RecursiveCharacterTextSplitter
+        """)
+
+    # Event handlers
+    embed_btn.click(
+        fn=process_embeddings,
+        inputs=document_input,
+        outputs=embed_status
+    )
+
+    def compare_and_download(question):
+        results_text, report_content = compare_embeddings(question)
+
+        # Create a downloadable file
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"rag_test_{timestamp}.txt"
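+
+        # gr.File serves files from a path on disk; an in-memory StringIO
+        # buffer has no such path, so write the report to a temp file and
+        # return its location to the File component.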
+        filepath = os.path.join(tempfile.gettempdir(), filename)
+        with open(filepath, "w", encoding="utf-8") as f:
+            f.write(report_content)
+
+        return results_text, filepath
+
+    compare_btn.click(
+        fn=compare_and_download,
+        inputs=question_input,
+        outputs=[results_display, report_download]
+    )
+
+if __name__ == "__main__":
+    demo.launch()