MrSimple01 committed
Commit b29c0f7 · verified · 1 Parent(s): e3f6f16

Create app.py

Files changed (1)
  1. app.py +418 -0
app.py ADDED
@@ -0,0 +1,418 @@
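# app.py: Gradio app that splits an input document or transcript into semantic chunks,
# then asks Gemini (via LangChain) for a topic name, key concepts, a summary,
# and quiz questions for each chunk.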
import re
import numpy as np
import json
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from langchain_google_genai import ChatGoogleGenerativeAI
import os
import gradio as gr


tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
max_tokens = 4000
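
# Strip transcript speaker tags like "[speaker_1]" and collapse repeated whitespace.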
def clean_text(text):
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
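
# Split cleaned text into sentence-based segments of roughly 100 ModernBERT tokens,
# then split overly long segments (50+ tokens) at a punctuation token near the middle.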
def split_text_with_modernbert_tokenizer(text):
    text = clean_text(text)
    rough_splits = re.split(r'(?<=[.!?])\s+', text)

    segments = []
    current_segment = ""
    current_token_count = 0

    for sentence in rough_splits:
        if not sentence.strip():
            continue

        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        if (current_token_count + sentence_tokens > 100 or
                re.search(r'[.!?]$', current_segment.strip())):
            if current_segment:
                segments.append(current_segment.strip())
            current_segment = sentence
            current_token_count = sentence_tokens
        else:
            current_segment += " " + sentence if current_segment else sentence
            current_token_count += sentence_tokens

    if current_segment:
        segments.append(current_segment.strip())

    refined_segments = []

    for segment in segments:
        if len(segment.split()) < 3:
            if refined_segments:
                refined_segments[-1] += ' ' + segment
            else:
                refined_segments.append(segment)
            continue

        tokens = tokenizer.tokenize(segment)

        if len(tokens) < 50:
            refined_segments.append(segment)
            continue

        break_indices = [i for i, token in enumerate(tokens)
                         if ('.' in token or ',' in token or '?' in token or '!' in token)
                         and i < len(tokens) - 1]

        if not break_indices or break_indices[-1] < len(tokens) * 0.7:
            refined_segments.append(segment)
            continue

        mid_idx = break_indices[len(break_indices) // 2]
        first_half = tokenizer.convert_tokens_to_string(tokens[:mid_idx + 1])
        second_half = tokenizer.convert_tokens_to_string(tokens[mid_idx + 1:])

        refined_segments.append(first_half.strip())
        refined_segments.append(second_half.strip())

    return refined_segments
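
# Cluster segment embeddings (agglomerative, cosine distance), pack each cluster into
# chunks of at most `max_tokens` tokens, then merge highly similar chunks that still fit.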
def semantic_chunking(text):
    segments = split_text_with_modernbert_tokenizer(text)
    segment_embeddings = sentence_model.encode(segments)

    distances = cosine_distances(segment_embeddings)

    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=1,
        metric='precomputed',
        linkage='average'
    )
    clusters = agg_clustering.fit_predict(distances)

    # Group segments by cluster
    cluster_groups = {}
    for i, cluster_id in enumerate(clusters):
        if cluster_id not in cluster_groups:
            cluster_groups[cluster_id] = []
        cluster_groups[cluster_id].append(segments[i])

    chunks = []
    for cluster_id in sorted(cluster_groups.keys()):
        cluster_segments = cluster_groups[cluster_id]

        current_chunk = []
        current_token_count = 0

        for segment in cluster_segments:
            segment_tokens = len(tokenizer.encode(segment, truncation=True, add_special_tokens=True))
            if segment_tokens > max_tokens:
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                    current_chunk = []
                    current_token_count = 0
                chunks.append(segment)
                continue

            if current_token_count + segment_tokens > max_tokens and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [segment]
                current_token_count = segment_tokens
            else:
                current_chunk.append(segment)
                current_token_count += segment_tokens

        if current_chunk:
            chunks.append(" ".join(current_chunk))

    if len(chunks) > 1:
        chunk_embeddings = sentence_model.encode(chunks)
        chunk_similarities = 1 - cosine_distances(chunk_embeddings)

        i = 0
        while i < len(chunks) - 1:
            j = i + 1
            if chunk_similarities[i, j] > 0.75:
                combined = chunks[i] + " " + chunks[j]
                combined_tokens = len(tokenizer.encode(combined, truncation=True, add_special_tokens=True))

                if combined_tokens <= max_tokens:
                    # Merge chunks
                    chunks[i] = combined
                    chunks.pop(j)
                    chunk_embeddings = sentence_model.encode(chunks)
                    chunk_similarities = 1 - cosine_distances(chunk_embeddings)
                else:
                    i += 1
            else:
                i += 1

    return chunks
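
# Ask Gemini for topic name(s), key concepts, a summary, and quiz questions as JSON;
# fall back to an error payload if the response cannot be parsed.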
def analyze_segment_with_gemini(cluster_text, is_full_text=False):
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.7,
        max_tokens=None,
        timeout=None,
        max_retries=3
    )

    if is_full_text:
        prompt = f"""
        Analyze the following text (likely a transcript or document) and:

        1. First, identify distinct segments or topics within the text
        2. For each segment/topic you identify:
           - Provide a concise topic name (3-5 words)
           - List 3-5 key concepts discussed in that segment
           - Write a brief summary of that segment (3-5 sentences)
           - Create 5 quiz questions based DIRECTLY on the content in that segment

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated
        - The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.

        Text:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "segments": [
            {{
              "topic_name": "Name of segment 1",
              "key_concepts": ["concept1", "concept2", "concept3"],
              "summary": "Brief summary of this segment.",
              "quiz_questions": [
                {{
                  "question": "Question text?",
                  "options": [
                    {{
                      "text": "Option A",
                      "correct": false
                    }},
                    {{
                      "text": "Option B",
                      "correct": true
                    }},
                    {{
                      "text": "Option C",
                      "correct": false
                    }}
                  ]
                }},
                // More questions...
              ]
            }},
            // More segments...
          ]
        }}
        """
    else:
        prompt = f"""
        Analyze the following text segment and provide:
        1. A concise topic name (3-5 words)
        2. 3-5 key concepts discussed
        3. A brief summary (6-7 sentences)
        4. Create 5 quiz questions based DIRECTLY on the text content (not from your summary)

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated
        - The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.

        Text segment:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "topic_name": "Name of the topic",
          "key_concepts": ["concept1", "concept2", "concept3"],
          "summary": "Brief summary of the text segment.",
          "quiz_questions": [
            {{
              "question": "Question text?",
              "options": [
                {{
                  "text": "Option A",
                  "correct": false
                }},
                {{
                  "text": "Option B",
                  "correct": true
                }},
                {{
                  "text": "Option C",
                  "correct": false
                }}
              ]
            }},
            // More questions...
          ]
        }}
        """

    response = llm.invoke(prompt)

    response_text = response.content

    try:
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if json_match:
            response_json = json.loads(json_match.group(0))
        else:
            response_json = json.loads(response_text)

        return response_json
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        print(f"Raw response: {response_text}")

        if is_full_text:
            return {
                "segments": [
                    {
                        "topic_name": "JSON Parsing Error",
                        "key_concepts": ["Error in response format"],
                        "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                        "quiz_questions": []
                    }
                ]
            }
        else:
            return {
                "topic_name": "JSON Parsing Error",
                "key_concepts": ["Error in response format"],
                "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                "quiz_questions": []
            }
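
# Analyze short texts (under 12k tokens) in a single Gemini call; otherwise chunk the
# text semantically and analyze each chunk separately.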
def process_document_with_quiz(text):
    token_count = len(tokenizer.encode(text))
    print(f"Text contains {token_count} tokens")

    if token_count < 12000:
        print("Text is short enough to analyze directly without text segmentation")
        full_analysis = analyze_segment_with_gemini(text, is_full_text=True)

        results = []

        if "segments" in full_analysis:
            for i, segment in enumerate(full_analysis["segments"]):
                segment["segment_number"] = i + 1
                segment["segment_text"] = "Segment identified by Gemini"
                results.append(segment)

            print(f"Gemini identified {len(results)} segments in the text")
        else:
            print("Unexpected response format from Gemini")
            results = [full_analysis]

        return results

    chunks = semantic_chunking(text)
    print(f"{len(chunks)} semantic chunks were found\n")

    results = []

    for i, chunk in enumerate(chunks):
        print(f"Analyzing segment {i+1}/{len(chunks)}...")
        analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
        analysis["segment_number"] = i + 1
        analysis["segment_text"] = chunk

        results.append(analysis)

        print(f"Completed analysis of segment {i+1}: {analysis['topic_name']}")

    return results
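
# Helper: persist the analysis results to a JSON file (not wired into the Gradio UI).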
def save_results_to_file(results, output_file="analysis_results.json"):
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"Results saved to {output_file}")
def format_quiz_for_display(results):
    output = []

    for segment_result in results:
        segment_num = segment_result["segment_number"]
        topic = segment_result["topic_name"]

        output.append(f"\n\n{'='*40}")
        output.append(f"SEGMENT {segment_num}: {topic}")
        output.append(f"{'='*40}\n")

        output.append("KEY CONCEPTS:")
        for concept in segment_result["key_concepts"]:
            output.append(f"• {concept}")

        output.append("\nSUMMARY:")
        output.append(segment_result["summary"])

        output.append("\nQUIZ QUESTIONS:")
        for i, q in enumerate(segment_result["quiz_questions"]):
            output.append(f"\n{i+1}. {q['question']}")

            for j, option in enumerate(q['options']):
                letter = chr(97 + j).upper()
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"   {letter}. {option['text']}{correct_marker}")

    return "\n".join(output)
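
# Gradio callback: set the user's Gemini API key, run the analysis, and return the
# formatted output (or an error message).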
def analyze_document(document_text: str, api_key: str) -> str:
    os.environ["GOOGLE_API_KEY"] = api_key
    try:
        results = process_document_with_quiz(document_text)
        formatted_output = format_quiz_for_display(results)
        return formatted_output
    except Exception as e:
        return f"Error processing document: {str(e)}"
with gr.Blocks(title="Quiz Generator") as app:
    gr.Markdown("Quiz Generator")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste your document text here...",
                lines=10
            )

            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )

            analyze_btn = gr.Button("Analyze Document")

        with gr.Column():
            output_results = gr.Textbox(
                label="Analysis Results",
                lines=20
            )

    analyze_btn.click(
        fn=analyze_document,
        inputs=[input_text, api_key],
        outputs=output_results
    )

if __name__ == "__main__":
    app.launch()