import gradio as gr
import sqlite3
import json
import torch
import torch.nn as nn
import nltk
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    ElectraTokenizer,
    ElectraForTokenClassification
)
# Download NLTK data
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
nltk.download('punkt')
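# Newer NLTK releases (>= 3.8.2) ship the sentence tokenizer as "punkt_tab";
# fetching it as well is a defensive measure and a no-op on versions where
# the resource does not exist.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    try:
        nltk.download('punkt_tab')
    except Exception:
        pass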
class HuggingFaceT5GEDInference:
def __init__(self, model_name="Zlovoblachko/REAlEC_2step_model_testing",
ged_model_name="Zlovoblachko/11tag-electra-grammar-stage2", device=None):
"""
Initialize the inference class for T5-GED model from HuggingFace
Args:
model_name: HuggingFace model name/path for the T5-GED model
ged_model_name: HuggingFace model name/path for the GED model
device: Device to run inference on (cuda/cpu)
"""
self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load GED model and tokenizer (same as training)
print(f"Loading GED model from HuggingFace: {ged_model_name}...")
self.ged_model, self.ged_tokenizer = self._load_ged_model(ged_model_name)
# Load T5 model and tokenizer from HuggingFace
print(f"Loading T5 model from HuggingFace: {model_name}...")
self.t5_tokenizer = T5Tokenizer.from_pretrained(model_name)
self.t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
self.t5_model.to(self.device)
# Create GED encoder (copy of T5 encoder)
self.ged_encoder = T5ForConditionalGeneration.from_pretrained(model_name).encoder
self.ged_encoder.to(self.device)
# Create gating mechanism
encoder_hidden_size = self.t5_model.config.d_model
self.gate = nn.Linear(2 * encoder_hidden_size, 1)
self.gate.to(self.device)
# Try to load GED components from HuggingFace
try:
print("Loading GED components...")
from huggingface_hub import hf_hub_download
            ged_components_path = hf_hub_download(
                repo_id=model_name,
                filename="ged_components.pt"
            )
ged_components = torch.load(ged_components_path, map_location=self.device)
self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
self.gate.load_state_dict(ged_components["gate"])
print("GED components loaded successfully!")
except Exception as e:
print(f"Warning: Could not load GED components: {e}")
print("Using default initialization for GED encoder and gate.")
# Set to evaluation mode
self.t5_model.eval()
self.ged_encoder.eval()
self.gate.eval()
def _load_ged_model(self, model_name):
"""Load GED model and tokenizer from HuggingFace"""
tokenizer = ElectraTokenizer.from_pretrained(model_name)
model = ElectraForTokenClassification.from_pretrained(model_name)
model.to(self.device)
model.eval()
return model, tokenizer
def _get_ged_predictions(self, text):
"""Get GED predictions for input text - exact same as training preprocessing"""
inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)
with torch.no_grad():
outputs = self.ged_model(**inputs)
logits = outputs.logits
predictions = torch.argmax(logits, dim=2)
token_predictions = predictions[0].cpu().numpy().tolist()
tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
ged_tags = []
for token, pred in zip(tokens, token_predictions):
if token.startswith("##") or token in ["[CLS]", "[SEP]", "[PAD]"]:
continue
ged_tags.append(str(pred))
return " ".join(ged_tags), tokens, token_predictions
def _get_error_spans(self, text):
"""Extract error spans with simplified categories for display"""
ged_tags_str, tokens, predictions = self._get_ged_predictions(text)
error_spans = []
clean_tokens = []
for token, pred in zip(tokens, predictions):
if token.startswith("##") or token in ["[CLS]", "[SEP]", "[PAD]"]:
continue
clean_tokens.append(token)
if pred != 0: # 0 is correct, others are various error types
# Simplify the 11-tag system to basic categories for user display
if pred in [1, 2, 3, 4]: # Various replacement/substitution errors
error_type = "Grammar"
elif pred in [5, 6]: # Missing elements
error_type = "Missing"
elif pred in [7, 8]: # Unnecessary elements
error_type = "Unnecessary"
elif pred in [9, 10]: # Other error types
error_type = "Usage"
else:
error_type = "Error"
error_spans.append({
"token": token,
"type": error_type,
"position": len(clean_tokens) - 1
})
return error_spans
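    # Illustrative span format (actual tags depend on the GED model):
    #   [{"token": "go", "type": "Grammar", "position": 1}]
    # where "position" indexes into the cleaned token list with subword
    # continuations and special tokens filtered out.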
def _preprocess_inputs(self, text, max_length=128):
"""Preprocess input text exactly as during training"""
# Get GED predictions
ged_tags, _, _ = self._get_ged_predictions(text)
# Tokenize source text (same as training)
src_tokens = self.t5_tokenizer(
text,
truncation=True,
max_length=max_length,
return_tensors="pt"
)
# Tokenize GED tags (same as training)
ged_tokens = self.t5_tokenizer(
ged_tags,
truncation=True,
max_length=max_length,
return_tensors="pt"
)
return {
"input_ids": src_tokens.input_ids.to(self.device),
"attention_mask": src_tokens.attention_mask.to(self.device),
"ged_input_ids": ged_tokens.input_ids.to(self.device),
"ged_attention_mask": ged_tokens.attention_mask.to(self.device)
}
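    # The returned dict mirrors the training-time inputs: "input_ids" /
    # "attention_mask" encode the source text and "ged_input_ids" /
    # "ged_attention_mask" encode the space-separated GED tag string, each
    # as a [1, seq_len] tensor already moved to self.device.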
def _forward_with_ged(self, input_ids, attention_mask, ged_input_ids, ged_attention_mask, max_length=200):
"""
Forward pass with GED integration - replicates T5WithGED.forward() logic
"""
# Get source encoder outputs
src_encoder_outputs = self.t5_model.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# Get GED encoder outputs
ged_encoder_outputs = self.ged_encoder(
input_ids=ged_input_ids,
attention_mask=ged_attention_mask,
return_dict=True
)
# Get hidden states
src_hidden_states = src_encoder_outputs.last_hidden_state
ged_hidden_states = ged_encoder_outputs.last_hidden_state
# Combine hidden states (same as training)
min_len = min(src_hidden_states.size(1), ged_hidden_states.size(1))
combined = torch.cat([
src_hidden_states[:, :min_len, :],
ged_hidden_states[:, :min_len, :]
], dim=2)
# Apply gating mechanism
gate_scores = torch.sigmoid(self.gate(combined))
combined_hidden = (
gate_scores * src_hidden_states[:, :min_len, :] +
(1 - gate_scores) * ged_hidden_states[:, :min_len, :]
)
# Update encoder outputs
src_encoder_outputs.last_hidden_state = combined_hidden
        # Generate using the T5 decoder; the attention mask is sliced to
        # min_len so it matches the truncated encoder hidden states during
        # cross-attention
        decoder_outputs = self.t5_model.generate(
            encoder_outputs=src_encoder_outputs,
            attention_mask=attention_mask[:, :min_len],
            max_length=max_length,
            do_sample=False,
            num_beams=1
        )
return decoder_outputs
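    # Gating recap: for each aligned position t,
    #   g_t = sigmoid(W [h_src_t ; h_ged_t])
    #   h_t = g_t * h_src_t + (1 - g_t) * h_ged_t
    # so a scalar gate in (0, 1) decides, per position, how much the decoder
    # attends to source semantics versus the GED error signal.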
def correct_text(self, text, max_length=200):
"""
Correct grammatical errors in input text
Args:
text: Input text to correct
max_length: Maximum length for generation
Returns:
Corrected text as string
"""
# Preprocess inputs exactly as training
inputs = self._preprocess_inputs(text)
# Generate correction using GED-enhanced model
with torch.no_grad():
generated_ids = self._forward_with_ged(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
ged_input_ids=inputs["ged_input_ids"],
ged_attention_mask=inputs["ged_attention_mask"],
max_length=max_length
)
# Decode output
corrected_text = self.t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
return corrected_text
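    # Minimal usage sketch (illustrative output, not a recorded run):
    #   checker = HuggingFaceT5GEDInference()
    #   checker.correct_text("She go to school every days.")
    #   -> e.g. "She goes to school every day."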
def analyze_text(self, text):
"""Enhanced analysis method for Gradio integration"""
        if not text.strip():
            return "Empty input text.", ""
try:
# Get corrected text
corrected_text = self.correct_text(text)
# Get error spans
error_spans = self._get_error_spans(text)
# Generate HTML output
html_output = self.generate_html_analysis(text, corrected_text, error_spans)
return corrected_text, html_output
except Exception as e:
return f"Error during analysis: {str(e)}", ""
def generate_html_analysis(self, original, corrected, error_spans):
"""Generate enhanced HTML analysis output"""
# Create highlighted original text
highlighted_original = original
if error_spans:
# Sort by position in reverse to avoid index shifting
sorted_spans = sorted(error_spans, key=lambda x: x['position'], reverse=True)
# Simple highlighting - in a more sophisticated version, you'd map token positions to character positions
for span in sorted_spans:
token = span['token']
error_type = span['type']
# Color coding for different error types
color_map = {
"Grammar": "#ffebee", # Light red
"Missing": "#e8f5e8", # Light green
"Unnecessary": "#fff3e0", # Light orange
"Usage": "#e3f2fd" # Light blue
}
color = color_map.get(error_type, "#f5f5f5")
                # Wrap the first occurrence of the token in a colored, titled
                # span (basic highlighting; subword tokens from the GED
                # tokenizer may not always match the raw text exactly)
                if token in highlighted_original:
                    highlighted_original = highlighted_original.replace(
                        token,
                        f"<span style='background-color: {color};' title='{error_type}'>{token}</span>",
                        1
                    )
html = f"""
Grammar Analysis Results
Original Text with Error Highlighting:
{highlighted_original}
Corrected Text:
{corrected}
Error Summary:
Found {len(error_spans)} potential issues
Grammar
Missing
Unnecessary
Usage
"""
return html
# Initialize SQLite database for storing submissions and exercises
def init_database():
conn = sqlite3.connect('language_app.db')
c = conn.cursor()
# Users table
c.execute('''CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
email TEXT UNIQUE NOT NULL,
role TEXT NOT NULL,
password_hash TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)''')
# Tasks table
c.execute('''CREATE TABLE IF NOT EXISTS tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
description TEXT NOT NULL,
image_url TEXT,
creator_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)''')
# Submissions table
c.execute('''CREATE TABLE IF NOT EXISTS submissions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id INTEGER,
student_name TEXT NOT NULL,
content TEXT NOT NULL,
analysis_result TEXT,
analysis_html TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)''')
# Exercises table
c.execute('''CREATE TABLE IF NOT EXISTS exercises (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
instructions TEXT NOT NULL,
sentences TEXT NOT NULL,
image_url TEXT,
submission_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)''')
# Exercise attempts table
c.execute('''CREATE TABLE IF NOT EXISTS exercise_attempts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
exercise_id INTEGER,
student_name TEXT NOT NULL,
responses TEXT NOT NULL,
score REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)''')
conn.commit()
conn.close()
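# Example: the stored data can be inspected from a shell with the sqlite3
# CLI (assumes sqlite3 is installed):
#   sqlite3 language_app.db "SELECT id, student_name, created_at FROM submissions;"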
# Initialize database and components
init_database()
print("Initializing enhanced grammar checker...")
grammar_checker = HuggingFaceT5GEDInference()
print("Grammar checker initialized successfully!")
# Gradio Interface Functions
def analyze_student_writing(text, student_name, task_title="General Writing Task"):
"""Analyze student writing and store in database"""
if not text.strip():
return "Please enter some text to analyze.", ""
if not student_name.strip():
return "Please enter your name.", ""
# Analyze text with enhanced model
corrected_text, html_analysis = grammar_checker.analyze_text(text)
# Store in database
conn = sqlite3.connect('language_app.db')
c = conn.cursor()
    # Look up the task by title, creating it if it does not exist.
    # (tasks.title has no UNIQUE constraint, so INSERT OR IGNORE would
    # silently insert a duplicate row on every submission)
    c.execute("SELECT id FROM tasks WHERE title = ?", (task_title,))
    row = c.fetchone()
    if row:
        task_id = row[0]
    else:
        c.execute("INSERT INTO tasks (title, description) VALUES (?, ?)",
                  (task_title, f"Writing task: {task_title}"))
        task_id = c.lastrowid
# Insert submission
analysis_data = {
"corrected_text": corrected_text,
"original_text": text,
"html_output": html_analysis
}
c.execute("""INSERT INTO submissions (task_id, student_name, content, analysis_result, analysis_html)
VALUES (?, ?, ?, ?, ?)""",
(task_id, student_name, text, json.dumps(analysis_data), html_analysis))
submission_id = c.lastrowid
conn.commit()
conn.close()
return corrected_text, html_analysis
def create_exercise_from_text(text, exercise_title="Grammar Exercise"):
"""Create an exercise from text with errors using enhanced analysis"""
if not text.strip():
return "Please enter text to create an exercise.", ""
# Analyze text to find sentences with errors
sentences = nltk.sent_tokenize(text)
exercise_sentences = []
for sentence in sentences:
corrected, _ = grammar_checker.analyze_text(sentence)
if sentence.strip() != corrected.strip(): # Has errors
exercise_sentences.append({
"original": sentence.strip(),
"corrected": corrected.strip()
})
if not exercise_sentences:
return "No errors found in the text. Cannot create exercise.", ""
# Store exercise in database
conn = sqlite3.connect('language_app.db')
c = conn.cursor()
c.execute("""INSERT INTO exercises (title, instructions, sentences)
VALUES (?, ?, ?)""",
(exercise_title,
"Correct the grammatical errors in the following sentences:",
json.dumps(exercise_sentences)))
exercise_id = c.lastrowid
conn.commit()
conn.close()
# Generate exercise HTML
    exercise_html = f"""
    <h3>{exercise_title}</h3>
    <p><strong>Exercise ID:</strong> {exercise_id}</p>
    <p><strong>Instructions:</strong> Correct the grammatical errors in the following sentences:</p>
    <ol>
    """
    for sentence_data in exercise_sentences:
        exercise_html += f"<li>{sentence_data['original']}</li>"
    exercise_html += "</ol>"
    return f"Exercise created with {len(exercise_sentences)} sentences! Exercise ID: {exercise_id}", exercise_html
def attempt_exercise(exercise_id, student_responses, student_name):
"""Submit exercise attempt and get score using enhanced analysis"""
if not student_name.strip():
return "Please enter your name.", ""
    try:
        exercise_id = int(exercise_id)
    except (TypeError, ValueError):
        return "Please enter a valid exercise ID.", ""
# Get exercise from database
conn = sqlite3.connect('language_app.db')
c = conn.cursor()
c.execute("SELECT sentences FROM exercises WHERE id = ?", (exercise_id,))
result = c.fetchone()
if not result:
return "Exercise not found.", ""
exercise_sentences = json.loads(result[0])
# Parse student responses
responses = [r.strip() for r in student_responses.split('\n') if r.strip()]
    if len(responses) != len(exercise_sentences):
        conn.close()
        return f"Please provide exactly {len(exercise_sentences)} responses (one per line).", ""
# Calculate score using enhanced analysis
correct_count = 0
feedback = []
for i, (sentence_data, response) in enumerate(zip(exercise_sentences, responses), 1):
correct_answer = sentence_data['corrected']
# Use the model to check if the response is correct
response_corrected, _ = grammar_checker.analyze_text(response)
is_correct = response_corrected.strip() == response.strip() # No further corrections needed
if is_correct:
correct_count += 1
feedback.append(f"✅ Sentence {i}: Excellent! No errors detected.")
else:
feedback.append(f"❌ Sentence {i}: Your answer: '{response}' | Suggested improvement: '{response_corrected}' | Expected: '{correct_answer}'")
score = (correct_count / len(exercise_sentences)) * 100
# Store attempt
attempt_data = {
"responses": responses,
"score": score,
"feedback": feedback
}
c.execute("""INSERT INTO exercise_attempts (exercise_id, student_name, responses, score)
VALUES (?, ?, ?, ?)""",
(exercise_id, student_name, json.dumps(attempt_data), score))
conn.commit()
conn.close()
    feedback_html = f"""
    <h3>Exercise Results</h3>
    <p><strong>Score:</strong> {score:.1f}% ({correct_count}/{len(exercise_sentences)} correct)</p>
    <p>{'<br>'.join(feedback)}</p>
    """
return f"Score: {score:.1f}%", feedback_html
def get_student_progress(student_name):
"""Get student's submission and exercise history"""
if not student_name.strip():
return "Please enter a student name."
conn = sqlite3.connect('language_app.db')
c = conn.cursor()
# Get submissions
c.execute("""SELECT s.id, s.content, s.created_at, t.title
FROM submissions s JOIN tasks t ON s.task_id = t.id
WHERE s.student_name = ? ORDER BY s.created_at DESC""", (student_name,))
submissions = c.fetchall()
# Get exercise attempts
c.execute("""SELECT ea.score, ea.created_at, e.title
FROM exercise_attempts ea JOIN exercises e ON ea.exercise_id = e.id
WHERE ea.student_name = ? ORDER BY ea.created_at DESC""", (student_name,))
attempts = c.fetchall()
conn.close()
    progress_html = f"""
    <h3>Progress for {student_name}</h3>
    <h4>Writing Submissions ({len(submissions)})</h4>
    <ul>
    """
    for sub in submissions:
        progress_html += f"<li>{sub[3]} - {sub[2][:16]} - {len(sub[1])} characters</li>"
    progress_html += f"""
    </ul>
    <h4>Exercise Attempts ({len(attempts)})</h4>
    <ul>
    """
    for att in attempts:
        progress_html += f"<li>{att[2]} - Score: {att[0]:.1f}% - {att[1][:16]}</li>"
    progress_html += "</ul>"
return progress_html
# Create Gradio Interface
with gr.Blocks(title="Language Learning App - Enhanced Grammar Checker", theme=gr.themes.Soft()) as app:
gr.Markdown("# 🎓 Language Learning Application")
gr.Markdown("### AI-Powered Grammar Checking and Exercise Generation")
gr.Markdown("*Now featuring advanced T5-GED neural network with enhanced error detection*")
with gr.Tabs():
# Student Writing Analysis Tab
with gr.TabItem("📝 Writing Analysis"):
gr.Markdown("## Submit Your Writing for Analysis")
with gr.Row():
with gr.Column():
student_name_input = gr.Textbox(label="Your Name", placeholder="Enter your name")
task_title_input = gr.Textbox(label="Assignment Title", value="General Writing Task")
writing_input = gr.Textbox(
label="Your Writing",
lines=8,
placeholder="Paste your writing here for grammar analysis..."
)
analyze_btn = gr.Button("Analyze Writing", variant="primary")
with gr.Column():
corrected_output = gr.Textbox(label="Corrected Text", lines=6)
analysis_output = gr.HTML(label="Detailed Analysis")
analyze_btn.click(
analyze_student_writing,
inputs=[writing_input, student_name_input, task_title_input],
outputs=[corrected_output, analysis_output]
)
# Exercise Creation Tab
with gr.TabItem("🏋️ Exercise Creation"):
gr.Markdown("## Create Grammar Exercises")
with gr.Row():
with gr.Column():
exercise_title_input = gr.Textbox(label="Exercise Title", value="Grammar Exercise")
exercise_text_input = gr.Textbox(
label="Text with Errors",
lines=6,
placeholder="Enter text containing grammatical errors to create an exercise..."
)
create_exercise_btn = gr.Button("Create Exercise", variant="primary")
with gr.Column():
exercise_result = gr.Textbox(label="Result")
exercise_display = gr.HTML(label="Generated Exercise")
create_exercise_btn.click(
create_exercise_from_text,
inputs=[exercise_text_input, exercise_title_input],
outputs=[exercise_result, exercise_display]
)
# Exercise Attempt Tab
with gr.TabItem("✏️ Exercise Practice"):
gr.Markdown("## Practice Grammar Exercises")
with gr.Row():
with gr.Column():
exercise_id_input = gr.Textbox(label="Exercise ID", placeholder="Enter exercise ID")
student_name_exercise = gr.Textbox(label="Your Name", placeholder="Enter your name")
responses_input = gr.Textbox(
label="Your Answers",
lines=6,
placeholder="Enter your corrected sentences (one per line)..."
)
submit_exercise_btn = gr.Button("Submit Answers", variant="primary")
with gr.Column():
score_output = gr.Textbox(label="Your Score")
feedback_output = gr.HTML(label="Detailed Feedback")
submit_exercise_btn.click(
attempt_exercise,
inputs=[exercise_id_input, responses_input, student_name_exercise],
outputs=[score_output, feedback_output]
)
# Progress Tracking Tab
with gr.TabItem("📊 Student Progress"):
gr.Markdown("## View Student Progress")
with gr.Row():
with gr.Column(scale=1):
progress_student_name = gr.Textbox(label="Student Name", placeholder="Enter student name")
get_progress_btn = gr.Button("Get Progress", variant="primary")
with gr.Column(scale=2):
progress_output = gr.HTML(label="Student Progress")
get_progress_btn.click(
get_student_progress,
inputs=[progress_student_name],
outputs=[progress_output]
)
gr.Markdown("""
---
### How to Use:
1. **Writing Analysis**: Submit your writing to get grammar corrections and detailed error analysis
2. **Exercise Creation**: Teachers can create exercises from text containing errors
3. **Exercise Practice**: Students can practice with generated exercises and get scored feedback
4. **Progress Tracking**: View student progress across submissions and exercises
*Powered by advanced T5-GED neural networks for enhanced grammar error detection and correction*
""")
if __name__ == "__main__":
app.launch(share=True)