import gradio as gr
import sqlite3
import json
import os
from datetime import datetime

import torch
import torch.nn as nn
import nltk
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    ElectraTokenizer,
    ElectraForTokenClassification
)

# Download NLTK sentence tokenizer data if missing
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')


class HuggingFaceT5GEDInference:
    def __init__(self, model_name="Zlovoblachko/REAlEC_2step_model_testing",
                 ged_model_name="Zlovoblachko/11tag-electra-grammar-stage2",
                 device=None):
        """
        Initialize the inference class for the T5-GED model from HuggingFace.

        Args:
            model_name: HuggingFace model name/path for the T5-GED model
            ged_model_name: HuggingFace model name/path for the GED model
            device: Device to run inference on (cuda/cpu)
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load GED model and tokenizer (same as training)
        print(f"Loading GED model from HuggingFace: {ged_model_name}...")
        self.ged_model, self.ged_tokenizer = self._load_ged_model(ged_model_name)

        # Load T5 model and tokenizer from HuggingFace
        print(f"Loading T5 model from HuggingFace: {model_name}...")
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.t5_model.to(self.device)

        # Create GED encoder (a copy of the T5 encoder)
        self.ged_encoder = T5ForConditionalGeneration.from_pretrained(model_name).encoder
        self.ged_encoder.to(self.device)

        # Create the gating mechanism that mixes source and GED encoder states
        encoder_hidden_size = self.t5_model.config.d_model
        self.gate = nn.Linear(2 * encoder_hidden_size, 1)
        self.gate.to(self.device)

        # Try to load the trained GED components from the HuggingFace repo
        try:
            print("Loading GED components...")
            from huggingface_hub import hf_hub_download
            ged_components_path = hf_hub_download(
                repo_id=model_name,
                filename="ged_components.pt",
                cache_dir=None
            )
            ged_components = torch.load(ged_components_path, map_location=self.device)
            self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
            self.gate.load_state_dict(ged_components["gate"])
            print("GED components loaded successfully!")
        except Exception as e:
            print(f"Warning: Could not load GED components: {e}")
            print("Using default initialization for GED encoder and gate.")

        # Set all modules to evaluation mode
        self.t5_model.eval()
        self.ged_encoder.eval()
        self.gate.eval()

    def _load_ged_model(self, model_name):
        """Load the GED model and tokenizer from HuggingFace."""
        tokenizer = ElectraTokenizer.from_pretrained(model_name)
        model = ElectraForTokenClassification.from_pretrained(model_name)
        model.to(self.device)
        model.eval()
        return model, tokenizer

    def _get_ged_predictions(self, text):
        """Get GED predictions for the input text (same as the training preprocessing)."""
        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)
        with torch.no_grad():
            outputs = self.ged_model(**inputs)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=2)
        token_predictions = predictions[0].cpu().numpy().tolist()
        tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        # Keep one tag per word piece head, skipping subwords and special tokens
        ged_tags = []
        for token, pred in zip(tokens, token_predictions):
            if token.startswith("##") or token in ["[CLS]", "[SEP]", "[PAD]"]:
                continue
            ged_tags.append(str(pred))
        return " ".join(ged_tags), tokens, token_predictions

    def _get_error_spans(self, text):
        """Extract error spans with simplified categories for display."""
        ged_tags_str, tokens, predictions = self._get_ged_predictions(text)
        error_spans = []
        clean_tokens = []
        for token, pred in zip(tokens, predictions):
            if token.startswith("##") or token in ["[CLS]", "[SEP]", "[PAD]"]:
                continue
            clean_tokens.append(token)
            if pred != 0:  # 0 is correct; the other tags mark various error types
                # Collapse the 11-tag system into basic categories for user display
                if pred in [1, 2, 3, 4]:    # replacement/substitution errors
                    error_type = "Grammar"
                elif pred in [5, 6]:        # missing elements
                    error_type = "Missing"
                elif pred in [7, 8]:        # unnecessary elements
                    error_type = "Unnecessary"
                elif pred in [9, 10]:       # other error types
                    error_type = "Usage"
                else:
                    error_type = "Error"
                error_spans.append({
                    "token": token,
                    "type": error_type,
                    "position": len(clean_tokens) - 1
                })
        return error_spans

    def _preprocess_inputs(self, text, max_length=128):
        """Preprocess the input text exactly as during training."""
        # Get GED predictions
        ged_tags, _, _ = self._get_ged_predictions(text)
        # Tokenize the source text (same as training)
        src_tokens = self.t5_tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        # Tokenize the GED tag sequence (same as training)
        ged_tokens = self.t5_tokenizer(
            ged_tags,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        return {
            "input_ids": src_tokens.input_ids.to(self.device),
            "attention_mask": src_tokens.attention_mask.to(self.device),
            "ged_input_ids": ged_tokens.input_ids.to(self.device),
            "ged_attention_mask": ged_tokens.attention_mask.to(self.device)
        }

    def _forward_with_ged(self, input_ids, attention_mask, ged_input_ids,
                          ged_attention_mask, max_length=200):
        """Forward pass with GED integration; replicates the T5WithGED.forward() logic."""
        # Encode the source sentence
        src_encoder_outputs = self.t5_model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # Encode the GED tag sequence
        ged_encoder_outputs = self.ged_encoder(
            input_ids=ged_input_ids,
            attention_mask=ged_attention_mask,
            return_dict=True
        )
        src_hidden_states = src_encoder_outputs.last_hidden_state
        ged_hidden_states = ged_encoder_outputs.last_hidden_state
        # Combine hidden states up to the shorter sequence length (same as training)
        min_len = min(src_hidden_states.size(1), ged_hidden_states.size(1))
        combined = torch.cat([
            src_hidden_states[:, :min_len, :],
            ged_hidden_states[:, :min_len, :]
        ], dim=2)
        # The gate decides, per position, how much GED signal to mix in
        gate_scores = torch.sigmoid(self.gate(combined))
        combined_hidden = (
            gate_scores * src_hidden_states[:, :min_len, :]
            + (1 - gate_scores) * ged_hidden_states[:, :min_len, :]
        )
        # Feed the gated encoder states to the T5 decoder
        src_encoder_outputs.last_hidden_state = combined_hidden
        decoder_outputs = self.t5_model.generate(
            encoder_outputs=src_encoder_outputs,
            max_length=max_length,
            do_sample=False,
            num_beams=1
        )
        return decoder_outputs

    def correct_text(self, text, max_length=200):
        """
        Correct grammatical errors in the input text.

        Args:
            text: Input text to correct
            max_length: Maximum length for generation

        Returns:
            Corrected text as a string
        """
        # Preprocess inputs exactly as during training
        inputs = self._preprocess_inputs(text)
        # Generate the correction using the GED-enhanced model
        with torch.no_grad():
            generated_ids = self._forward_with_ged(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                ged_input_ids=inputs["ged_input_ids"],
                ged_attention_mask=inputs["ged_attention_mask"],
                max_length=max_length
            )
        # Decode the output
        corrected_text = self.t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return corrected_text

    def analyze_text(self, text):
        """Enhanced analysis method for Gradio integration."""
        if not text.strip():
            return "Model not available or empty text", ""
        try:
            # Get the corrected text and error spans, then render the HTML report
            corrected_text = self.correct_text(text)
            error_spans = self._get_error_spans(text)
            html_output = self.generate_html_analysis(text, corrected_text, error_spans)
            return corrected_text, html_output
        except Exception as e:
            return f"Error during analysis: {str(e)}", ""

    def generate_html_analysis(self, original, corrected, error_spans):
        """Generate the enhanced HTML analysis output."""
        highlighted_original = original
        if error_spans:
            # Sort by position in reverse to avoid index shifting
            sorted_spans = sorted(error_spans, key=lambda x: x['position'], reverse=True)
            # Color coding for the different error types
            color_map = {
                "Grammar": "#ffebee",      # light red
                "Missing": "#e8f5e8",      # light green
                "Unnecessary": "#fff3e0",  # light orange
                "Usage": "#e3f2fd"         # light blue
            }
            # Simple highlighting; a more sophisticated version would map
            # token positions to character positions
            for span in sorted_spans:
                token = span['token']
                error_type = span['type']
                color = color_map.get(error_type, "#f5f5f5")
                # Simple first-occurrence token replacement (basic highlighting)
                if token in highlighted_original:
                    highlighted_original = highlighted_original.replace(
                        token,
                        f"<span style='background-color: {color};' title='{error_type}'>{token}</span>",
                        1
                    )
        html = f"""
        <div style="font-family: Arial, sans-serif; padding: 20px; border: 1px solid #ddd; border-radius: 8px;">
            <h3>Grammar Analysis Results</h3>
            <h4>Original Text with Error Highlighting:</h4>
            <p style="line-height: 1.8;">{highlighted_original}</p>
            <h4>Corrected Text:</h4>
            <p>{corrected}</p>
            <h4>Error Summary:</h4>
            <p>Found {len(error_spans)} potential issues</p>
            <p>
                <span style="background-color: #ffebee; padding: 2px 6px;">Grammar</span>
                <span style="background-color: #e8f5e8; padding: 2px 6px;">Missing</span>
                <span style="background-color: #fff3e0; padding: 2px 6px;">Unnecessary</span>
                <span style="background-color: #e3f2fd; padding: 2px 6px;">Usage</span>
            </p>
        </div>
        """
        return html
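# A minimal standalone usage sketch for the class above (illustrative only;
# the example sentence and its output are assumptions, not fixtures from this
# repo). Kept as a comment so importing this module stays side-effect free:
#
#   checker = HuggingFaceT5GEDInference()
#   corrected = checker.correct_text("She go to school every days.")
#   spans = checker._get_error_spans("She go to school every days.")
#   # `corrected` holds the model's rewrite; `spans` lists flagged tokens with
#   # simplified categories such as "Grammar" or "Missing".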
""" return html # Initialize SQLite database for storing submissions and exercises def init_database(): conn = sqlite3.connect('language_app.db') c = conn.cursor() # Users table c.execute('''CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT UNIQUE NOT NULL, email TEXT UNIQUE NOT NULL, role TEXT NOT NULL, password_hash TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )''') # Tasks table c.execute('''CREATE TABLE IF NOT EXISTS tasks ( id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL, description TEXT NOT NULL, image_url TEXT, creator_id INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )''') # Submissions table c.execute('''CREATE TABLE IF NOT EXISTS submissions ( id INTEGER PRIMARY KEY AUTOINCREMENT, task_id INTEGER, student_name TEXT NOT NULL, content TEXT NOT NULL, analysis_result TEXT, analysis_html TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )''') # Exercises table c.execute('''CREATE TABLE IF NOT EXISTS exercises ( id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL, instructions TEXT NOT NULL, sentences TEXT NOT NULL, image_url TEXT, submission_id INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )''') # Exercise attempts table c.execute('''CREATE TABLE IF NOT EXISTS exercise_attempts ( id INTEGER PRIMARY KEY AUTOINCREMENT, exercise_id INTEGER, student_name TEXT NOT NULL, responses TEXT NOT NULL, score REAL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )''') conn.commit() conn.close() # Initialize database and components init_database() print("Initializing enhanced grammar checker...") grammar_checker = HuggingFaceT5GEDInference() print("Grammar checker initialized successfully!") # Gradio Interface Functions def analyze_student_writing(text, student_name, task_title="General Writing Task"): """Analyze student writing and store in database""" if not text.strip(): return "Please enter some text to analyze.", "" if not student_name.strip(): return "Please enter your name.", "" # Analyze text with enhanced model corrected_text, html_analysis = grammar_checker.analyze_text(text) # Store in database conn = sqlite3.connect('language_app.db') c = conn.cursor() # Insert task if not exists c.execute("INSERT OR IGNORE INTO tasks (title, description) VALUES (?, ?)", (task_title, f"Writing task: {task_title}")) c.execute("SELECT id FROM tasks WHERE title = ?", (task_title,)) task_id = c.fetchone()[0] # Insert submission analysis_data = { "corrected_text": corrected_text, "original_text": text, "html_output": html_analysis } c.execute("""INSERT INTO submissions (task_id, student_name, content, analysis_result, analysis_html) VALUES (?, ?, ?, ?, ?)""", (task_id, student_name, text, json.dumps(analysis_data), html_analysis)) submission_id = c.lastrowid conn.commit() conn.close() return corrected_text, html_analysis def create_exercise_from_text(text, exercise_title="Grammar Exercise"): """Create an exercise from text with errors using enhanced analysis""" if not text.strip(): return "Please enter text to create an exercise.", "" # Analyze text to find sentences with errors sentences = nltk.sent_tokenize(text) exercise_sentences = [] for sentence in sentences: corrected, _ = grammar_checker.analyze_text(sentence) if sentence.strip() != corrected.strip(): # Has errors exercise_sentences.append({ "original": sentence.strip(), "corrected": corrected.strip() }) if not exercise_sentences: return "No errors found in the text. 
Cannot create exercise.", "" # Store exercise in database conn = sqlite3.connect('language_app.db') c = conn.cursor() c.execute("""INSERT INTO exercises (title, instructions, sentences) VALUES (?, ?, ?)""", (exercise_title, "Correct the grammatical errors in the following sentences:", json.dumps(exercise_sentences))) exercise_id = c.lastrowid conn.commit() conn.close() # Generate exercise HTML exercise_html = f"""

{exercise_title}

Exercise ID: {exercise_id}

Instructions: Correct the grammatical errors in the following sentences:

    """ for i, sentence_data in enumerate(exercise_sentences, 1): exercise_html += f"
  1. {sentence_data['original']}
  2. " exercise_html += "
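# For reference, the shape of the `sentences` JSON stored above (field names
# come from create_exercise_from_text; the example values are made up):
#
#   [
#     {"original": "She go to school.", "corrected": "She goes to school."},
#     ...
#   ]
#
# A hypothetical read helper, should the raw record be needed elsewhere:
#
#   def fetch_exercise(exercise_id):
#       conn = sqlite3.connect('language_app.db')
#       row = conn.execute(
#           "SELECT title, instructions, sentences FROM exercises WHERE id = ?",
#           (exercise_id,)).fetchone()
#       conn.close()
#       return None if row is None else (row[0], row[1], json.loads(row[2]))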
" return f"Exercise created with {len(exercise_sentences)} sentences! Exercise ID: {exercise_id}", exercise_html def attempt_exercise(exercise_id, student_responses, student_name): """Submit exercise attempt and get score using enhanced analysis""" if not student_name.strip(): return "Please enter your name.", "" try: exercise_id = int(exercise_id) except: return "Please enter a valid exercise ID.", "" # Get exercise from database conn = sqlite3.connect('language_app.db') c = conn.cursor() c.execute("SELECT sentences FROM exercises WHERE id = ?", (exercise_id,)) result = c.fetchone() if not result: return "Exercise not found.", "" exercise_sentences = json.loads(result[0]) # Parse student responses responses = [r.strip() for r in student_responses.split('\n') if r.strip()] if len(responses) != len(exercise_sentences): return f"Please provide exactly {len(exercise_sentences)} responses (one per line).", "" # Calculate score using enhanced analysis correct_count = 0 feedback = [] for i, (sentence_data, response) in enumerate(zip(exercise_sentences, responses), 1): correct_answer = sentence_data['corrected'] # Use the model to check if the response is correct response_corrected, _ = grammar_checker.analyze_text(response) is_correct = response_corrected.strip() == response.strip() # No further corrections needed if is_correct: correct_count += 1 feedback.append(f"✅ Sentence {i}: Excellent! No errors detected.") else: feedback.append(f"❌ Sentence {i}: Your answer: '{response}' | Suggested improvement: '{response_corrected}' | Expected: '{correct_answer}'") score = (correct_count / len(exercise_sentences)) * 100 # Store attempt attempt_data = { "responses": responses, "score": score, "feedback": feedback } c.execute("""INSERT INTO exercise_attempts (exercise_id, student_name, responses, score) VALUES (?, ?, ?, ?)""", (exercise_id, student_name, json.dumps(attempt_data), score)) conn.commit() conn.close() feedback_html = f"""

Exercise Results

Score: {score:.1f}% ({correct_count}/{len(exercise_sentences)} correct)

{'
'.join(feedback)}
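# Design note: the scoring above accepts any response the model leaves
# unchanged, so a student can pass with a rewrite that differs from the stored
# answer. If strict matching against the stored correction is preferred, a
# one-line alternative (assumption: exact string equality is an acceptable
# criterion) would be:
#
#   is_correct = response.strip() == sentence_data['corrected'].strip()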
""" return f"Score: {score:.1f}%", feedback_html def get_student_progress(student_name): """Get student's submission and exercise history""" if not student_name.strip(): return "Please enter a student name." conn = sqlite3.connect('language_app.db') c = conn.cursor() # Get submissions c.execute("""SELECT s.id, s.content, s.created_at, t.title FROM submissions s JOIN tasks t ON s.task_id = t.id WHERE s.student_name = ? ORDER BY s.created_at DESC""", (student_name,)) submissions = c.fetchall() # Get exercise attempts c.execute("""SELECT ea.score, ea.created_at, e.title FROM exercise_attempts ea JOIN exercises e ON ea.exercise_id = e.id WHERE ea.student_name = ? ORDER BY ea.created_at DESC""", (student_name,)) attempts = c.fetchall() conn.close() progress_html = f"""

Progress for {student_name}

Writing Submissions ({len(submissions)})

Exercise Attempts ({len(attempts)})

" return progress_html # Create Gradio Interface with gr.Blocks(title="Language Learning App - Enhanced Grammar Checker", theme=gr.themes.Soft()) as app: gr.Markdown("# 🎓 Language Learning Application") gr.Markdown("### AI-Powered Grammar Checking and Exercise Generation") gr.Markdown("*Now featuring advanced T5-GED neural network with enhanced error detection*") with gr.Tabs(): # Student Writing Analysis Tab with gr.TabItem("📝 Writing Analysis"): gr.Markdown("## Submit Your Writing for Analysis") with gr.Row(): with gr.Column(): student_name_input = gr.Textbox(label="Your Name", placeholder="Enter your name") task_title_input = gr.Textbox(label="Assignment Title", value="General Writing Task") writing_input = gr.Textbox( label="Your Writing", lines=8, placeholder="Paste your writing here for grammar analysis..." ) analyze_btn = gr.Button("Analyze Writing", variant="primary") with gr.Column(): corrected_output = gr.Textbox(label="Corrected Text", lines=6) analysis_output = gr.HTML(label="Detailed Analysis") analyze_btn.click( analyze_student_writing, inputs=[writing_input, student_name_input, task_title_input], outputs=[corrected_output, analysis_output] ) # Exercise Creation Tab with gr.TabItem("🏋️ Exercise Creation"): gr.Markdown("## Create Grammar Exercises") with gr.Row(): with gr.Column(): exercise_title_input = gr.Textbox(label="Exercise Title", value="Grammar Exercise") exercise_text_input = gr.Textbox( label="Text with Errors", lines=6, placeholder="Enter text containing grammatical errors to create an exercise..." ) create_exercise_btn = gr.Button("Create Exercise", variant="primary") with gr.Column(): exercise_result = gr.Textbox(label="Result") exercise_display = gr.HTML(label="Generated Exercise") create_exercise_btn.click( create_exercise_from_text, inputs=[exercise_text_input, exercise_title_input], outputs=[exercise_result, exercise_display] ) # Exercise Attempt Tab with gr.TabItem("✏️ Exercise Practice"): gr.Markdown("## Practice Grammar Exercises") with gr.Row(): with gr.Column(): exercise_id_input = gr.Textbox(label="Exercise ID", placeholder="Enter exercise ID") student_name_exercise = gr.Textbox(label="Your Name", placeholder="Enter your name") responses_input = gr.Textbox( label="Your Answers", lines=6, placeholder="Enter your corrected sentences (one per line)..." ) submit_exercise_btn = gr.Button("Submit Answers", variant="primary") with gr.Column(): score_output = gr.Textbox(label="Your Score") feedback_output = gr.HTML(label="Detailed Feedback") submit_exercise_btn.click( attempt_exercise, inputs=[exercise_id_input, responses_input, student_name_exercise], outputs=[score_output, feedback_output] ) # Progress Tracking Tab with gr.TabItem("📊 Student Progress"): gr.Markdown("## View Student Progress") with gr.Row(): with gr.Column(scale=1): progress_student_name = gr.Textbox(label="Student Name", placeholder="Enter student name") get_progress_btn = gr.Button("Get Progress", variant="primary") with gr.Column(scale=2): progress_output = gr.HTML(label="Student Progress") get_progress_btn.click( get_student_progress, inputs=[progress_student_name], outputs=[progress_output] ) gr.Markdown(""" --- ### How to Use: 1. **Writing Analysis**: Submit your writing to get grammar corrections and detailed error analysis 2. **Exercise Creation**: Teachers can create exercises from text containing errors 3. **Exercise Practice**: Students can practice with generated exercises and get scored feedback 4. 
**Progress Tracking**: View student progress across submissions and exercises *Powered by advanced T5-GED neural networks for enhanced grammar error detection and correction* """) if __name__ == "__main__": app.launch(share=True)
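# Deployment note (an assumption, not part of the original app): for multiple
# concurrent users you would typically enable Gradio's request queue before
# launching, e.g.
#
#   app.queue().launch(share=True)
#
# Whether queuing is needed depends on how heavy T5 inference is on your
# hardware; on CPU, each analysis call can take several seconds.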