QuizGenerator / app.py
import json
import os
import re

import gradio as gr
import numpy as np
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from transformers import AutoTokenizer

# ModernBERT tokenizer is used for token counting and token-level splitting;
# the MiniLM sentence transformer provides embeddings for clustering.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Maximum number of tokens allowed in a single chunk sent to the LLM.
max_tokens = 3000

def clean_text(text):
    # Remove speaker tags like [speaker_1] and collapse repeated whitespace.
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
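
# Illustrative behaviour (example input is hypothetical): the two substitutions above
# would turn "[speaker_1]  Hello   world. " into "Hello world.".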

def split_text_with_modernbert_tokenizer(text):
    text = clean_text(text)
    # Rough sentence split on ., ! or ? followed by whitespace.
    rough_splits = re.split(r'(?<=[.!?])\s+', text)

    segments = []
    current_segment = ""
    current_token_count = 0
    for sentence in rough_splits:
        if not sentence.strip():
            continue
        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        if (current_token_count + sentence_tokens > 100 or
                re.search(r'[.!?]$', current_segment.strip())):
            if current_segment:
                segments.append(current_segment.strip())
            current_segment = sentence
            current_token_count = sentence_tokens
        else:
            current_segment += (" " + sentence) if current_segment else sentence
            current_token_count += sentence_tokens
    if current_segment:
        segments.append(current_segment.strip())

    # Refinement pass: merge very short fragments into the previous segment and
    # split overly long segments at a punctuation token near their middle.
    refined_segments = []
    for segment in segments:
        if len(segment.split()) < 3:
            if refined_segments:
                refined_segments[-1] += ' ' + segment
            else:
                refined_segments.append(segment)
            continue
        tokens = tokenizer.tokenize(segment)
        if len(tokens) < 50:
            refined_segments.append(segment)
            continue
        break_indices = [i for i, token in enumerate(tokens)
                         if ('.' in token or ',' in token or '?' in token or '!' in token)
                         and i < len(tokens) - 1]
        if not break_indices or break_indices[-1] < len(tokens) * 0.7:
            refined_segments.append(segment)
            continue
        mid_idx = break_indices[len(break_indices) // 2]
        first_half = tokenizer.convert_tokens_to_string(tokens[:mid_idx + 1])
        second_half = tokenizer.convert_tokens_to_string(tokens[mid_idx + 1:])
        refined_segments.append(first_half.strip())
        refined_segments.append(second_half.strip())
    return refined_segments
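
# Segmentation summary: a new segment starts whenever the running segment would
# exceed 100 ModernBERT tokens or already ends in sentence punctuation; the
# refinement pass then merges fragments of fewer than three words into the
# previous segment and splits segments of 50+ tokens at a punctuation token
# near their middle.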

def semantic_chunking(text):
    segments = split_text_with_modernbert_tokenizer(text)
    segment_embeddings = sentence_model.encode(segments)
    distances = cosine_distances(segment_embeddings)

    # Cluster segments by semantic similarity using average-linkage
    # agglomerative clustering on the precomputed cosine-distance matrix.
    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=1,
        metric='precomputed',
        linkage='average'
    )
    clusters = agg_clustering.fit_predict(distances)

    # Group segments by cluster
    cluster_groups = {}
    for i, cluster_id in enumerate(clusters):
        if cluster_id not in cluster_groups:
            cluster_groups[cluster_id] = []
        cluster_groups[cluster_id].append(segments[i])

    # Pack each cluster's segments into chunks of at most max_tokens tokens.
    chunks = []
    for cluster_id in sorted(cluster_groups.keys()):
        cluster_segments = cluster_groups[cluster_id]
        current_chunk = []
        current_token_count = 0
        for segment in cluster_segments:
            segment_tokens = len(tokenizer.encode(segment, truncation=True, add_special_tokens=True))
            if segment_tokens > max_tokens:
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                    current_chunk = []
                    current_token_count = 0
                chunks.append(segment)
                continue
            if current_token_count + segment_tokens > max_tokens and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [segment]
                current_token_count = segment_tokens
            else:
                current_chunk.append(segment)
                current_token_count += segment_tokens
        if current_chunk:
            chunks.append(" ".join(current_chunk))

    # Merge adjacent chunks whose cosine similarity exceeds 0.75, as long as the
    # merged chunk still fits within max_tokens.
    if len(chunks) > 1:
        chunk_embeddings = sentence_model.encode(chunks)
        chunk_similarities = 1 - cosine_distances(chunk_embeddings)
        i = 0
        while i < len(chunks) - 1:
            j = i + 1
            if chunk_similarities[i, j] > 0.75:
                combined = chunks[i] + " " + chunks[j]
                combined_tokens = len(tokenizer.encode(combined, truncation=True, add_special_tokens=True))
                if combined_tokens <= max_tokens:
                    # Merge chunks and recompute embeddings for the shortened list.
                    chunks[i] = combined
                    chunks.pop(j)
                    chunk_embeddings = sentence_model.encode(chunks)
                    chunk_similarities = 1 - cosine_distances(chunk_embeddings)
                else:
                    i += 1
            else:
                i += 1
    return chunks
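
# Illustrative standalone use of the chunker (hypothetical file name; assumes
# the tokenizer and sentence model above have finished loading):
#
#     with open("transcript.txt", encoding="utf-8") as f:
#         chunks = semantic_chunking(f.read())
#     print(f"{len(chunks)} chunks of at most {max_tokens} tokens each")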

def analyze_segment_with_gemini(cluster_text, is_full_text=False):
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.7,
        max_tokens=None,
        timeout=None,
        max_retries=3
    )

    if is_full_text:
        prompt = f"""
        FIRST ASSESS THE TEXT:
        - Check if it's primarily self-introduction, biographical information, or conclusion
        - Check if it's too short or lacks meaningful content (less than 100 words of substance)
        - If either case is true, respond with a simple JSON: {{"status": "insufficient", "reason": "Brief explanation"}}

        Analyze the following text (likely a transcript or document) and:
        1. First, do text segmentation and identify DISTINCT key topics within the text
        2. For each segment/topic you identify:
           - Provide a SPECIFIC and UNIQUE topic name (3-5 words) that clearly differentiates it from other segments
           - List 3-5 key concepts discussed in that segment
           - Write a brief summary of that segment (3-5 sentences)
           - Create 5 quiz questions based DIRECTLY on the content in that segment

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated

        Text:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "segments": [
            {{
              "topic_name": "Name of segment 1",
              "key_concepts": ["concept1", "concept2", "concept3"],
              "summary": "Brief summary of this segment.",
              "quiz_questions": [
                {{
                  "question": "Question text?",
                  "options": [
                    {{
                      "text": "Option A",
                      "correct": false
                    }},
                    {{
                      "text": "Option B",
                      "correct": true
                    }},
                    {{
                      "text": "Option C",
                      "correct": false
                    }}
                  ]
                }},
                // More questions...
              ]
            }},
            // More segments...
          ]
        }}

        OR if the text is just introductory, concluding, or insufficient:
        {{
          "status": "insufficient",
          "reason": "Brief explanation of why (e.g., 'Text is primarily self-introduction', 'Text is too short', etc.)"
        }}
        """
    else:
        prompt = f"""
        FIRST ASSESS THE TEXT:
        - Check if it's primarily self-introduction, biographical information, or conclusion
        - Check if it's too short or lacks meaningful content (less than 100 words of substance)
        - If either case is true, respond with a simple JSON: {{"status": "insufficient", "reason": "Brief explanation"}}

        Analyze the following text segment and provide:
        1. A SPECIFIC and DESCRIPTIVE topic name (3-5 words) that precisely captures the main focus
        2. 3-5 key concepts discussed
        3. A brief summary (6-7 sentences)
        4. Create 5 quiz questions based DIRECTLY on the text content (not from your summary)

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT and STRICTLY: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated

        Text segment:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "topic_name": "Name of the topic",
          "key_concepts": ["concept1", "concept2", "concept3"],
          "summary": "Brief summary of the text segment.",
          "quiz_questions": [
            {{
              "question": "Question text?",
              "options": [
                {{
                  "text": "Option A",
                  "correct": false
                }},
                {{
                  "text": "Option B",
                  "correct": true
                }},
                {{
                  "text": "Option C",
                  "correct": false
                }}
              ]
            }},
            // More questions...
          ]
        }}

        OR if the text is just introductory, concluding, or insufficient:
        {{
          "status": "insufficient",
          "reason": "Brief explanation of why (e.g., 'Text is primarily self-introduction', 'Text is too short', etc.)"
        }}
        """

    response = llm.invoke(prompt)
    response_text = response.content

    try:
        # Grab the outermost {...} block in case the model wraps its JSON in
        # extra prose or code fences; otherwise parse the full reply.
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if json_match:
            response_json = json.loads(json_match.group(0))
        else:
            response_json = json.loads(response_text)
        return response_json
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        print(f"Raw response: {response_text}")
        if is_full_text:
            return {
                "segments": [
                    {
                        "topic_name": "JSON Parsing Error",
                        "key_concepts": ["Error in response format"],
                        "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                        "quiz_questions": []
                    }
                ]
            }
        else:
            return {
                "topic_name": "JSON Parsing Error",
                "key_concepts": ["Error in response format"],
                "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                "quiz_questions": []
            }
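
# On success, a per-chunk call (is_full_text=False) returns a dict with
# "topic_name", "key_concepts", "summary" and "quiz_questions", while a
# full-text call returns {"segments": [...]}. Either call may instead return
# {"status": "insufficient", "reason": ...} when the model judges the text to
# be introductory or too short, as requested in the prompts above.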

def process_document_with_quiz(text):
    token_count = len(tokenizer.encode(text))
    print(f"Text contains {token_count} tokens")

    # Short documents go to Gemini in a single call; longer ones are split into
    # semantic chunks and analyzed chunk by chunk.
    if token_count < 8000:
        print("Text is short enough to analyze directly without text segmentation")
        full_analysis = analyze_segment_with_gemini(text, is_full_text=True)
        results = []
        if "segments" in full_analysis:
            for i, segment in enumerate(full_analysis["segments"]):
                segment["segment_number"] = i + 1
                segment["segment_text"] = "Segment identified by Gemini"
                results.append(segment)
            print(f"Gemini identified {len(results)} segments in the text")
        else:
            print("Unexpected response format from Gemini")
            results = [full_analysis]
        return results

    chunks = semantic_chunking(text)
    print(f"{len(chunks)} semantic chunks were found\n")
    results = []
    for i, chunk in enumerate(chunks):
        print(f"Analyzing segment {i+1}/{len(chunks)}...")
        analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
        analysis["segment_number"] = i + 1
        analysis["segment_text"] = chunk
        results.append(analysis)
        # .get() avoids a KeyError when the model returned an "insufficient" response.
        print(f"Completed analysis of segment {i+1}: {analysis.get('topic_name', 'insufficient')}")
    return results
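
# Illustrative direct use outside the UI (placeholder key and text; the Gemini
# client reads the key from the GOOGLE_API_KEY environment variable):
#
#     os.environ["GOOGLE_API_KEY"] = "YOUR_GEMINI_API_KEY"
#     results = process_document_with_quiz("...your document text...")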

def save_results_to_file(results, output_file="analysis_results.json"):
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {output_file}")

def format_quiz_for_display(results):
    output = []
    for segment_result in results:
        # Segments flagged as "insufficient" may lack quiz fields, so fall back
        # to defaults instead of raising KeyError.
        segment_num = segment_result.get("segment_number", "?")
        topic = segment_result.get("topic_name", "Insufficient content")
        output.append(f"\n\n{'='*40}")
        output.append(f"SEGMENT {segment_num}: {topic}")
        output.append(f"{'='*40}\n")
        output.append("KEY CONCEPTS:")
        for concept in segment_result.get("key_concepts", []):
            output.append(f"• {concept}")
        output.append("\nSUMMARY:")
        output.append(segment_result.get("summary", segment_result.get("reason", "")))
        output.append("\nQUIZ QUESTIONS:")
        for i, q in enumerate(segment_result.get("quiz_questions", [])):
            output.append(f"\n{i+1}. {q['question']}")
            for j, option in enumerate(q['options']):
                letter = chr(97 + j).upper()
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"   {letter}. {option['text']}{correct_marker}")
    return "\n".join(output)

def analyze_document(document_text: str, api_key: str) -> tuple:
    # The langchain Google GenAI client reads the key from this environment variable.
    os.environ["GOOGLE_API_KEY"] = api_key
    try:
        results = process_document_with_quiz(document_text)
        formatted_output = format_quiz_for_display(results)

        json_path = "analysis_results.json"
        txt_path = "analysis_results.txt"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(formatted_output)

        return formatted_output, json_path, txt_path
    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        return error_msg, None, None
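
# Illustrative programmatic call that mirrors what the Gradio button does
# (placeholder text and key); it also writes analysis_results.json and
# analysis_results.txt next to the script:
#
#     formatted, json_path, txt_path = analyze_document("...document text...", "YOUR_GEMINI_API_KEY")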

with gr.Blocks(title="Quiz Generator") as app:
    gr.Markdown("# Quiz Generator")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste your document text here...",
                lines=10
            )
            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )
            analyze_btn = gr.Button("Analyze Document")
        with gr.Column():
            output_results = gr.Textbox(
                label="Analysis Results",
                lines=20
            )
            json_file_output = gr.File(label="Download JSON")
            txt_file_output = gr.File(label="Download TXT")

    analyze_btn.click(
        fn=analyze_document,
        inputs=[input_text, api_key],
        outputs=[output_results, json_file_output, txt_file_output]
    )

if __name__ == "__main__":
    app.launch()