import re
import os
import json

import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from langchain_google_genai import ChatGoogleGenerativeAI

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
max_tokens = 4000


def clean_text(text):
    """Remove speaker tags (e.g. [speaker_1]) and collapse whitespace."""
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def split_text_with_modernbert_tokenizer(text):
    """Split text into sentence-level segments, using the ModernBERT tokenizer for length checks."""
    text = clean_text(text)
    rough_splits = re.split(r'(?<=[.!?])\s+', text)

    segments = []
    current_segment = ""
    current_token_count = 0

    for sentence in rough_splits:
        if not sentence.strip():
            continue
        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        # Start a new segment if the running segment would grow past ~100 tokens
        # or already ends at a sentence boundary.
        if (current_token_count + sentence_tokens > 100
                or re.search(r'[.!?]$', current_segment.strip())):
            if current_segment:
                segments.append(current_segment.strip())
            current_segment = sentence
            current_token_count = sentence_tokens
        else:
            current_segment += " " + sentence if current_segment else sentence
            current_token_count += sentence_tokens

    if current_segment:
        segments.append(current_segment.strip())

    # Second pass: fold very short segments into their predecessor and split
    # overly long segments roughly in half at a punctuation token.
    refined_segments = []
    for segment in segments:
        if len(segment.split()) < 3:
            if refined_segments:
                refined_segments[-1] += ' ' + segment
            else:
                refined_segments.append(segment)
            continue

        tokens = tokenizer.tokenize(segment)
        if len(tokens) < 50:
            refined_segments.append(segment)
            continue

        break_indices = [
            i for i, token in enumerate(tokens)
            if ('.' in token or ',' in token or '?' in token or '!' in token)
            and i < len(tokens) - 1
        ]
        if not break_indices or break_indices[-1] < len(tokens) * 0.7:
            refined_segments.append(segment)
            continue

        mid_idx = break_indices[len(break_indices) // 2]
        first_half = tokenizer.convert_tokens_to_string(tokens[:mid_idx + 1])
        second_half = tokenizer.convert_tokens_to_string(tokens[mid_idx + 1:])
        refined_segments.append(first_half.strip())
        refined_segments.append(second_half.strip())

    return refined_segments


def semantic_chunking(text):
    """Cluster segments by embedding similarity and pack them into chunks of at most max_tokens tokens."""
    segments = split_text_with_modernbert_tokenizer(text)

    # Embed each segment and cluster on pairwise cosine distances.
    segment_embeddings = sentence_model.encode(segments)
    distances = cosine_distances(segment_embeddings)
    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=1,
        metric='precomputed',
        linkage='average'
    )
    clusters = agg_clustering.fit_predict(distances)

    # Group segments by cluster
    cluster_groups = {}
    for i, cluster_id in enumerate(clusters):
        if cluster_id not in cluster_groups:
            cluster_groups[cluster_id] = []
        cluster_groups[cluster_id].append(segments[i])

    # Pack each cluster's segments into chunks that respect the token budget.
    chunks = []
    for cluster_id in sorted(cluster_groups.keys()):
        cluster_segments = cluster_groups[cluster_id]
        current_chunk = []
        current_token_count = 0
        for segment in cluster_segments:
            segment_tokens = len(tokenizer.encode(segment, truncation=True, add_special_tokens=True))
            if segment_tokens > max_tokens:
                # An oversized segment becomes its own chunk.
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                    current_chunk = []
                    current_token_count = 0
                chunks.append(segment)
                continue
            if current_token_count + segment_tokens > max_tokens and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [segment]
                current_token_count = segment_tokens
            else:
                current_chunk.append(segment)
                current_token_count += segment_tokens
        if current_chunk:
            chunks.append(" ".join(current_chunk))

    # Greedily merge adjacent chunks that are highly similar, as long as the
    # merged chunk still fits within the token budget.
    if len(chunks) > 1:
        chunk_embeddings = sentence_model.encode(chunks)
        chunk_similarities = 1 - cosine_distances(chunk_embeddings)
        i = 0
        while i < len(chunks) - 1:
            j = i + 1
            if chunk_similarities[i, j] > 0.75:
                combined = chunks[i] + " " + chunks[j]
                combined_tokens = len(tokenizer.encode(combined, truncation=True, add_special_tokens=True))
                if combined_tokens <= max_tokens:
                    # Merge chunks and recompute similarities for the shortened list.
                    chunks[i] = combined
                    chunks.pop(j)
                    chunk_embeddings = sentence_model.encode(chunks)
                    chunk_similarities = 1 - cosine_distances(chunk_embeddings)
                else:
                    i += 1
            else:
                i += 1

    return chunks


def analyze_segment_with_gemini(cluster_text, is_full_text=False):
    """Ask Gemini for topic names, key concepts, summaries, and quiz questions, returned as JSON."""
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.7,
        max_tokens=None,
        timeout=None,
        max_retries=3
    )

    if is_full_text:
        prompt = f"""
        Analyze the following text (likely a transcript or document) and:

        1. First, identify distinct segments or topics within the text
        2. For each segment/topic you identify:
           - Provide a concise topic name (3-5 words)
           - List 3-5 key concepts discussed in that segment
           - Write a brief summary of that segment (3-5 sentences)
           - Create 5 quiz questions based DIRECTLY on the content in that segment

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated
        - The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.

        Text:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "segments": [
            {{
              "topic_name": "Name of segment 1",
              "key_concepts": ["concept1", "concept2", "concept3"],
              "summary": "Brief summary of this segment.",
              "quiz_questions": [
                {{
                  "question": "Question text?",
                  "options": [
                    {{
                      "text": "Option A",
                      "correct": false
                    }},
                    {{
                      "text": "Option B",
                      "correct": true
                    }},
                    {{
                      "text": "Option C",
                      "correct": false
                    }}
                  ]
                }},
                // More questions...
              ]
            }},
            // More segments...
          ]
        }}
        """
    else:
        prompt = f"""
        Analyze the following text segment and provide:
        1. A concise topic name (3-5 words)
        2. 3-5 key concepts discussed
        3. A brief summary (6-7 sentences)
        4. Create 5 quiz questions based DIRECTLY on the text content (not from your summary)

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated
        - The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.

        Text segment:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "topic_name": "Name of the topic",
          "key_concepts": ["concept1", "concept2", "concept3"],
          "summary": "Brief summary of the text segment.",
          "quiz_questions": [
            {{
              "question": "Question text?",
              "options": [
                {{
                  "text": "Option A",
                  "correct": false
                }},
                {{
                  "text": "Option B",
                  "correct": true
                }},
                {{
                  "text": "Option C",
                  "correct": false
                }}
              ]
            }},
            // More questions...
          ]
        }}
        """

    response = llm.invoke(prompt)
    response_text = response.content

    try:
        # The model may wrap the JSON in extra text, so extract the outermost braces first.
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if json_match:
            response_json = json.loads(json_match.group(0))
        else:
            response_json = json.loads(response_text)
        return response_json
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        print(f"Raw response: {response_text}")
        if is_full_text:
            return {
                "segments": [
                    {
                        "topic_name": "JSON Parsing Error",
                        "key_concepts": ["Error in response format"],
                        "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                        "quiz_questions": []
                    }
                ]
            }
        else:
            return {
                "topic_name": "JSON Parsing Error",
                "key_concepts": ["Error in response format"],
                "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                "quiz_questions": []
            }


def process_document_with_quiz(text):
    """Analyze the whole document directly when it is short; otherwise chunk it semantically and analyze each chunk."""
    token_count = len(tokenizer.encode(text))
    print(f"Text contains {token_count} tokens")

    if token_count < 12000:
        print("Text is short enough to analyze directly without text segmentation")
        full_analysis = analyze_segment_with_gemini(text, is_full_text=True)
        results = []
        if "segments" in full_analysis:
            for i, segment in enumerate(full_analysis["segments"]):
                segment["segment_number"] = i + 1
                segment["segment_text"] = "Segment identified by Gemini"
                results.append(segment)
            print(f"Gemini identified {len(results)} segments in the text")
        else:
            print("Unexpected response format from Gemini")
            results = [full_analysis]
        return results

    chunks = semantic_chunking(text)
    print(f"{len(chunks)} semantic chunks were found\n")
    results = []
    for i, chunk in enumerate(chunks):
        print(f"Analyzing segment {i+1}/{len(chunks)}...")
        analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
        analysis["segment_number"] = i + 1
        analysis["segment_text"] = chunk
        results.append(analysis)
        print(f"Completed analysis of segment {i+1}: {analysis['topic_name']}")
    return results


def save_results_to_file(results, output_file="analysis_results.json"):
    """Write the analysis results to a JSON file."""
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {output_file}")


def format_quiz_for_display(results):
    """Render the per-segment analysis and quiz questions as plain text."""
    output = []
    for segment_result in results:
        segment_num = segment_result["segment_number"]
        topic = segment_result["topic_name"]

        output.append(f"\n\n{'='*40}")
        output.append(f"SEGMENT {segment_num}: {topic}")
        output.append(f"{'='*40}\n")

        output.append("KEY CONCEPTS:")
        for concept in segment_result["key_concepts"]:
            output.append(f"• {concept}")

        output.append("\nSUMMARY:")
        output.append(segment_result["summary"])

        output.append("\nQUIZ QUESTIONS:")
        for i, q in enumerate(segment_result["quiz_questions"]):
            output.append(f"\n{i+1}. {q['question']}")
            for j, option in enumerate(q['options']):
                letter = chr(97 + j).upper()
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"   {letter}. {option['text']}{correct_marker}")
    return "\n".join(output)


def analyze_document(document_text: str, api_key: str) -> str:
    """Gradio callback: set the API key, run the pipeline, and return the formatted results."""
    os.environ["GOOGLE_API_KEY"] = api_key
    try:
        results = process_document_with_quiz(document_text)
        formatted_output = format_quiz_for_display(results)
        return formatted_output
    except Exception as e:
        return f"Error processing document: {str(e)}"


with gr.Blocks(title="Quiz Generator") as app:
    gr.Markdown("Quiz Generator")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste your document text here...",
                lines=10
            )
            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )
            analyze_btn = gr.Button("Analyze Document")
        with gr.Column():
            output_results = gr.Textbox(
                label="Analysis Results",
                lines=20
            )
    analyze_btn.click(
        fn=analyze_document,
        inputs=[input_text, api_key],
        outputs=output_results
    )

if __name__ == "__main__":
    app.launch()
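# A minimal sketch of running the pipeline without the Gradio UI (assumes
# GOOGLE_API_KEY is already set in the environment and that "transcript.txt"
# is a plain-text file you provide; both are illustrative, not part of the app):
#
#   with open("transcript.txt", encoding="utf-8") as f:
#       results = process_document_with_quiz(f.read())
#   save_results_to_file(results)
#   print(format_quiz_for_display(results))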