import os
import re
import json

import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from langchain_google_genai import ChatGoogleGenerativeAI

# ModernBERT tokenizer for token counting/splitting, MiniLM encoder for segment embeddings.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Maximum number of tokens allowed in a single chunk sent to the LLM.
max_tokens = 3000

def clean_text(text):
    """Remove speaker tags such as [speaker_1] and collapse whitespace."""
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
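
# Example (illustrative only; the sample string is made up):
#   clean_text("[speaker_1] Hello   there. [speaker_2] Hi.")  ->  "Hello there. Hi."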

def split_text_with_modernbert_tokenizer(text):
    """Split text into small segments using sentence boundaries and ModernBERT token counts."""
    text = clean_text(text)
    rough_splits = re.split(r'(?<=[.!?])\s+', text)

    segments = []
    current_segment = ""
    current_token_count = 0

    for sentence in rough_splits:
        if not sentence.strip():
            continue

        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))

        # Start a new segment when the token budget is exceeded or the current
        # segment already ends at a sentence boundary.
        if (current_token_count + sentence_tokens > 100 or
                re.search(r'[.!?]$', current_segment.strip())):
            if current_segment:
                segments.append(current_segment.strip())
            current_segment = sentence
            current_token_count = sentence_tokens
        else:
            current_segment += " " + sentence if current_segment else sentence
            current_token_count += sentence_tokens

    if current_segment:
        segments.append(current_segment.strip())

    # Post-process: merge very short segments and split overly long ones.
    refined_segments = []
    for segment in segments:
        if len(segment.split()) < 3:
            if refined_segments:
                refined_segments[-1] += ' ' + segment
            else:
                refined_segments.append(segment)
            continue

        tokens = tokenizer.tokenize(segment)
        if len(tokens) < 50:
            refined_segments.append(segment)
            continue

        # Candidate split points: punctuation tokens that are not the final token.
        break_indices = [i for i, token in enumerate(tokens)
                         if ('.' in token or ',' in token or '?' in token or '!' in token)
                         and i < len(tokens) - 1]

        if not break_indices or break_indices[-1] < len(tokens) * 0.7:
            refined_segments.append(segment)
            continue

        # Split the segment roughly in half at a punctuation boundary.
        mid_idx = break_indices[len(break_indices) // 2]
        first_half = tokenizer.convert_tokens_to_string(tokens[:mid_idx + 1])
        second_half = tokenizer.convert_tokens_to_string(tokens[mid_idx + 1:])

        refined_segments.append(first_half.strip())
        refined_segments.append(second_half.strip())

    return refined_segments
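
# Example usage (sketch; `raw_transcript` is a hypothetical input string):
#   segments = split_text_with_modernbert_tokenizer(raw_transcript)
#   print(len(segments), "segments")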

def semantic_chunking(text):
    """Group tokenizer-level segments into semantically coherent chunks of at most max_tokens."""
    segments = split_text_with_modernbert_tokenizer(text)

    # Embed each segment and cluster on pairwise cosine distances.
    segment_embeddings = sentence_model.encode(segments)
    distances = cosine_distances(segment_embeddings)

    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=1,
        metric='precomputed',
        linkage='average'
    )
    clusters = agg_clustering.fit_predict(distances)

    # Group segments by cluster
    cluster_groups = {}
    for i, cluster_id in enumerate(clusters):
        if cluster_id not in cluster_groups:
            cluster_groups[cluster_id] = []
        cluster_groups[cluster_id].append(segments[i])

    # Within each cluster, pack segments into chunks that respect the token budget.
    chunks = []
    for cluster_id in sorted(cluster_groups.keys()):
        cluster_segments = cluster_groups[cluster_id]
        current_chunk = []
        current_token_count = 0

        for segment in cluster_segments:
            segment_tokens = len(tokenizer.encode(segment, truncation=True, add_special_tokens=True))

            # A single segment larger than the budget becomes its own chunk.
            if segment_tokens > max_tokens:
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                    current_chunk = []
                    current_token_count = 0
                chunks.append(segment)
                continue

            if current_token_count + segment_tokens > max_tokens and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [segment]
                current_token_count = segment_tokens
            else:
                current_chunk.append(segment)
                current_token_count += segment_tokens

        if current_chunk:
            chunks.append(" ".join(current_chunk))

    # Merge adjacent chunks that are highly similar, as long as the merge stays under the budget.
    if len(chunks) > 1:
        chunk_embeddings = sentence_model.encode(chunks)
        chunk_similarities = 1 - cosine_distances(chunk_embeddings)

        i = 0
        while i < len(chunks) - 1:
            j = i + 1
            if chunk_similarities[i, j] > 0.75:
                combined = chunks[i] + " " + chunks[j]
                combined_tokens = len(tokenizer.encode(combined, truncation=True, add_special_tokens=True))
                if combined_tokens <= max_tokens:
                    # Merge chunks and recompute similarities for the new chunk list.
                    chunks[i] = combined
                    chunks.pop(j)
                    chunk_embeddings = sentence_model.encode(chunks)
                    chunk_similarities = 1 - cosine_distances(chunk_embeddings)
                else:
                    i += 1
            else:
                i += 1

    return chunks
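
# Example usage (sketch; `document_text` is assumed to be a long input string):
#   chunks = semantic_chunking(document_text)
#   for chunk in chunks:
#       print(len(tokenizer.encode(chunk)), chunk[:80])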

def analyze_segment_with_gemini(cluster_text, is_full_text=False):
    """Send a chunk (or the full text) to Gemini and return the parsed JSON analysis."""
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.7,
        max_tokens=None,
        timeout=None,
        max_retries=3
    )

    if is_full_text:
        prompt = f"""
        FIRST ASSESS THE TEXT:
        - Check if it's primarily self-introduction, biographical information, or conclusion
        - Check if it's too short or lacks meaningful content (less than 100 words of substance)
        - If either case is true, respond with a simple JSON: {{"status": "insufficient", "reason": "Brief explanation"}}

        Analyze the following text (likely a transcript or document) and:
        1. First, segment the text and identify the DISTINCT key topics within it
        2. For each segment/topic you identify:
           - Provide a SPECIFIC and UNIQUE topic name (3-5 words) that clearly differentiates it from other segments
           - List 3-5 key concepts discussed in that segment
           - Write a brief summary of that segment (3-5 sentences)
           - Create 5 quiz questions based DIRECTLY on the content in that segment

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated

        Text:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "segments": [
            {{
              "topic_name": "Name of segment 1",
              "key_concepts": ["concept1", "concept2", "concept3"],
              "summary": "Brief summary of this segment.",
              "quiz_questions": [
                {{
                  "question": "Question text?",
                  "options": [
                    {{
                      "text": "Option A",
                      "correct": false
                    }},
                    {{
                      "text": "Option B",
                      "correct": true
                    }},
                    {{
                      "text": "Option C",
                      "correct": false
                    }}
                  ]
                }},
                // More questions...
              ]
            }},
            // More segments...
          ]
        }}

        OR if the text is just introductory, concluding, or insufficient:
        {{
          "status": "insufficient",
          "reason": "Brief explanation of why (e.g., 'Text is primarily self-introduction', 'Text is too short', etc.)"
        }}
        """
    else:
        prompt = f"""
        FIRST ASSESS THE TEXT:
        - Check if it's primarily self-introduction, biographical information, or conclusion
        - Check if it's too short or lacks meaningful content (less than 100 words of substance)
        - If either case is true, respond with a simple JSON: {{"status": "insufficient", "reason": "Brief explanation"}}

        Analyze the following text segment and provide:
        1. A SPECIFIC and DESCRIPTIVE topic name (3-5 words) that precisely captures the main focus
        2. 3-5 key concepts discussed
        3. A brief summary (6-7 sentences)
        4. 5 quiz questions based DIRECTLY on the text content (not on your summary)

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Strictly ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated

        Text segment:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "topic_name": "Name of the topic",
          "key_concepts": ["concept1", "concept2", "concept3"],
          "summary": "Brief summary of the text segment.",
          "quiz_questions": [
            {{
              "question": "Question text?",
              "options": [
                {{
                  "text": "Option A",
                  "correct": false
                }},
                {{
                  "text": "Option B",
                  "correct": true
                }},
                {{
                  "text": "Option C",
                  "correct": false
                }}
              ]
            }},
            // More questions...
          ]
        }}

        OR if the text is just introductory, concluding, or insufficient:
        {{
          "status": "insufficient",
          "reason": "Brief explanation of why (e.g., 'Text is primarily self-introduction', 'Text is too short', etc.)"
        }}
        """
    response = llm.invoke(prompt)
    response_text = response.content

    try:
        # Gemini sometimes wraps the JSON in extra text; extract the outermost braces first.
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if json_match:
            response_json = json.loads(json_match.group(0))
        else:
            response_json = json.loads(response_text)
        return response_json
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        print(f"Raw response: {response_text}")

        # Fall back to a placeholder structure so downstream formatting does not crash.
        if is_full_text:
            return {
                "segments": [
                    {
                        "topic_name": "JSON Parsing Error",
                        "key_concepts": ["Error in response format"],
                        "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                        "quiz_questions": []
                    }
                ]
            }
        else:
            return {
                "topic_name": "JSON Parsing Error",
                "key_concepts": ["Error in response format"],
                "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                "quiz_questions": []
            }
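
# Example usage (sketch; requires GOOGLE_API_KEY in the environment, and
# `some_segment_text` is a hypothetical chunk string):
#   analysis = analyze_segment_with_gemini(some_segment_text, is_full_text=False)
#   topic = analysis.get("topic_name") or analysis.get("status")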

def process_document_with_quiz(text):
    """Analyze a document: directly with Gemini when short, otherwise chunk it first."""
    token_count = len(tokenizer.encode(text))
    print(f"Text contains {token_count} tokens")

    # Short documents are sent to Gemini in one call and segmented by the model itself.
    if token_count < 8000:
        print("Text is short enough to analyze directly without text segmentation")
        full_analysis = analyze_segment_with_gemini(text, is_full_text=True)

        results = []
        if "segments" in full_analysis:
            for i, segment in enumerate(full_analysis["segments"]):
                segment["segment_number"] = i + 1
                segment["segment_text"] = "Segment identified by Gemini"
                results.append(segment)
            print(f"Gemini identified {len(results)} segments in the text")
        else:
            print("Unexpected response format from Gemini")
            results = [full_analysis]
        return results

    # Longer documents are split into semantic chunks that are analyzed one by one.
    chunks = semantic_chunking(text)
    print(f"{len(chunks)} semantic chunks were found\n")

    results = []
    for i, chunk in enumerate(chunks):
        print(f"Analyzing segment {i+1}/{len(chunks)}...")
        analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
        analysis["segment_number"] = i + 1
        analysis["segment_text"] = chunk
        results.append(analysis)
        print(f"Completed analysis of segment {i+1}: {analysis.get('topic_name', 'insufficient content')}")

    return results
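
# Example usage (sketch; assumes GOOGLE_API_KEY is set and `document_text` holds the input):
#   results = process_document_with_quiz(document_text)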

def save_results_to_file(results, output_file="analysis_results.json"):
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {output_file}")

def format_quiz_for_display(results):
    """Render the analysis results as plain text for the Gradio output box."""
    output = []
    for segment_result in results:
        segment_num = segment_result.get("segment_number", "?")
        topic = segment_result.get("topic_name", "Insufficient content")
        output.append(f"\n\n{'='*40}")
        output.append(f"SEGMENT {segment_num}: {topic}")
        output.append(f"{'='*40}\n")

        output.append("KEY CONCEPTS:")
        for concept in segment_result.get("key_concepts", []):
            output.append(f"• {concept}")

        output.append("\nSUMMARY:")
        output.append(segment_result.get("summary", segment_result.get("reason", "")))

        output.append("\nQUIZ QUESTIONS:")
        for i, q in enumerate(segment_result.get("quiz_questions", [])):
            output.append(f"\n{i+1}. {q['question']}")
            for j, option in enumerate(q['options']):
                letter = chr(97 + j).upper()
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"   {letter}. {option['text']}{correct_marker}")

    return "\n".join(output)
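
# Example usage (sketch; `results` as returned by process_document_with_quiz):
#   print(format_quiz_for_display(results))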

def analyze_document(document_text: str, api_key: str) -> tuple:
    """Gradio callback: run the analysis and return formatted text plus downloadable files."""
    os.environ["GOOGLE_API_KEY"] = api_key
    try:
        results = process_document_with_quiz(document_text)
        formatted_output = format_quiz_for_display(results)

        json_path = "analysis_results.json"
        txt_path = "analysis_results.txt"

        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(formatted_output)

        return formatted_output, json_path, txt_path
    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        return error_msg, None, None

with gr.Blocks(title="Quiz Generator") as app:
    gr.Markdown("# Quiz Generator")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste your document text here...",
                lines=10
            )
            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )
            analyze_btn = gr.Button("Analyze Document")
        with gr.Column():
            output_results = gr.Textbox(
                label="Analysis Results",
                lines=20
            )
            json_file_output = gr.File(label="Download JSON")
            txt_file_output = gr.File(label="Download TXT")

    analyze_btn.click(
        fn=analyze_document,
        inputs=[input_text, api_key],
        outputs=[output_results, json_file_output, txt_file_output]
    )

if __name__ == "__main__":
    app.launch()