Spaces:
Running
Running
import gradio as gr | |
import sqlite3 | |
import json | |
import os | |
from datetime import datetime | |
import torch | |
import nltk | |
from transformers import ( | |
T5Tokenizer, | |
T5ForConditionalGeneration, | |
ElectraTokenizer, | |
ElectraForTokenClassification | |
) | |
import torch.nn as nn | |
from tqdm import tqdm | |
# Download NLTK sentence-tokenizer data on first run.
# Older NLTK releases load the pickled "punkt" model, while NLTK 3.8.2+
# ``sent_tokenize`` looks up "punkt_tab" instead; fetch whichever is missing
# so ``nltk.sent_tokenize`` works regardless of the installed NLTK version.
for _nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_nltk_resource}")
    except LookupError:
        nltk.download(_nltk_resource)
class HuggingFaceT5GEDInference:
    """T5 grammatical-error correction guided by ELECTRA error-detection tags.

    Pipeline (see ``correct_text``):

    1. An ELECTRA token classifier (the GED model) assigns an integer label to
       every word-initial token of the input; the labels are joined into a
       space-separated string such as ``"0 0 6 0"``.
    2. The source text is encoded by the T5 encoder and the label string by a
       separate copy of that encoder (the "GED encoder").
    3. A learned sigmoid gate mixes the two hidden-state sequences position by
       position, and the T5 decoder generates the correction greedily.
    """

    # Tokens dropped when turning ELECTRA word pieces into word-level GED tags
    # (continuation pieces starting with "##" are dropped as well). This must
    # stay identical to the training preprocessing, so e.g. "[UNK]" is kept.
    _SKIP_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    # Coarse, user-facing category for each GED label id (0 = correct is never
    # looked up). Ids not listed here fall back to "Error".
    # NOTE(review): this grouping does not obviously line up with _ID2LABEL
    # below (e.g. 5/6 are POS/VERB there but "Missing" here) -- confirm.
    _SIMPLE_CATEGORIES = {
        1: "Grammar", 2: "Grammar", 3: "Grammar", 4: "Grammar",
        5: "Missing", 6: "Missing",
        7: "Unnecessary", 8: "Unnecessary",
        9: "Usage", 10: "Usage",
    }

    # Detailed label name for each GED label id.
    _ID2LABEL = {
        0: "correct",
        1: "ORTH",
        2: "FORM",
        3: "MORPH",
        4: "DET",
        5: "POS",
        6: "VERB",
        7: "NUM",
        8: "WORD",
        9: "PUNCT",
        10: "RED",
        11: "MULTIWORD",
        12: "SPELL",
    }

    # Highlight background per coarse category; anything else is light grey.
    _COLOR_MAP = {
        "Grammar": "#ffebee",      # Light red
        "Missing": "#e8f5e8",      # Light green
        "Unnecessary": "#fff3e0",  # Light orange
        "Usage": "#e3f2fd",        # Light blue
    }

    def __init__(self, model_name="Zlovoblachko/REAlEC_2step_model_testing",
                 ged_model_name="Zlovoblachko/11tag-electra-grammar-stage2", device=None):
        """Load the GED tagger, the T5 model, the GED encoder and the gate.

        Args:
            model_name: HuggingFace repo of the T5-GED model. It may also hold
                ``ged_components.pt`` with fine-tuned GED-encoder/gate weights.
            ged_model_name: HuggingFace repo of the ELECTRA GED token classifier.
            device: Torch device for inference; defaults to CUDA when available,
                otherwise CPU.
        """
        import copy

        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # GED model and tokenizer (same as training).
        print(f"Loading GED model from HuggingFace: {ged_model_name}...")
        self.ged_model, self.ged_tokenizer = self._load_ged_model(ged_model_name)

        print(f"Loading T5 model from HuggingFace: {model_name}...")
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.t5_model.to(self.device)

        # The GED encoder starts as an independent copy of the freshly loaded
        # T5 encoder. Copying gives the same weights as reloading the
        # checkpoint, without a second full model load.
        self.ged_encoder = copy.deepcopy(self.t5_model.encoder)
        self.ged_encoder.to(self.device)

        # Gate: one sigmoid score per position from [src ; ged] hidden states.
        encoder_hidden_size = self.t5_model.config.d_model
        self.gate = nn.Linear(2 * encoder_hidden_size, 1)
        self.gate.to(self.device)

        # Best effort: without ged_components.pt the GED encoder keeps the T5
        # encoder weights and the gate keeps its random initialisation.
        try:
            print("Loading GED components...")
            from huggingface_hub import hf_hub_download
            ged_components_path = hf_hub_download(
                repo_id=model_name,
                filename="ged_components.pt",
                cache_dir=None,
            )
            # SECURITY NOTE(review): torch.load unpickles a file downloaded from
            # the hub; on torch versions where weights_only defaults to False
            # this can execute code from the repo. Consider weights_only=True
            # once the minimum supported torch version accepts it.
            ged_components = torch.load(ged_components_path, map_location=self.device)
            self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
            self.gate.load_state_dict(ged_components["gate"])
            print("GED components loaded successfully!")
        except Exception as e:
            print(f"Warning: Could not load GED components: {e}")
            print("Using default initialization for GED encoder and gate.")

        # Inference only.
        self.t5_model.eval()
        self.ged_encoder.eval()
        self.gate.eval()

    def _load_ged_model(self, model_name):
        """Load the ELECTRA GED classifier and its tokenizer onto ``self.device``.

        Returns:
            ``(model, tokenizer)`` with the model in eval mode.
        """
        tokenizer = ElectraTokenizer.from_pretrained(model_name)
        model = ElectraForTokenClassification.from_pretrained(model_name)
        model.to(self.device)
        model.eval()
        return model, tokenizer

    @classmethod
    def _is_word_start(cls, token):
        """Return True for tokens that count as words in the GED tag string."""
        return not token.startswith("##") and token not in cls._SKIP_TOKENS

    def _get_ged_predictions(self, text):
        """Run the GED tagger on *text* (same preprocessing as training).

        Returns:
            ``(ged_tags, tokens, predictions)``:
            - ``ged_tags`` is a space-separated string of label ids, one per
              word-initial token.
            - ``tokens`` is the full ELECTRA token list, including special
              tokens and "##" pieces.
            - ``predictions`` holds the label id for every entry of ``tokens``.
        """
        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)
        with torch.no_grad():
            logits = self.ged_model(**inputs).logits
        token_predictions = torch.argmax(logits, dim=2)[0].cpu().tolist()
        tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

        ged_tags = [
            str(pred)
            for token, pred in zip(tokens, token_predictions)
            if self._is_word_start(token)
        ]
        return " ".join(ged_tags), tokens, token_predictions

    def _word_predictions(self, text):
        """Return ``(token, label_id)`` for each word-initial GED token of *text*."""
        _, tokens, predictions = self._get_ged_predictions(text)
        return [
            (token, pred)
            for token, pred in zip(tokens, predictions)
            if self._is_word_start(token)
        ]

    def _get_error_spans(self, text):
        """Return erroneous words of *text* with coarse, user-facing categories.

        Returns:
            List of ``{"token", "type", "position"}`` dicts. ``position`` is
            the index among word-initial tokens, and ``type`` is one of
            "Grammar", "Missing", "Unnecessary", "Usage" or "Error".
        """
        error_spans = []
        for position, (token, pred) in enumerate(self._word_predictions(text)):
            if pred == 0:  # 0 = correct
                continue
            error_spans.append({
                "token": token,
                "type": self._SIMPLE_CATEGORIES.get(pred, "Error"),
                "position": position,
            })
        return error_spans

    def _get_error_spans_detailed(self, text):
        """Return erroneous words of *text* with detailed label names.

        Returns:
            ``(error_spans, error_types)``:
            - ``error_spans`` has the same shape as in ``_get_error_spans`` but
              uses ``_ID2LABEL`` names ("OTHER" for unknown ids).
            - ``error_types`` lists the distinct names in first-seen order.
        """
        error_spans = []
        for position, (token, pred) in enumerate(self._word_predictions(text)):
            if pred == 0:  # 0 = correct
                continue
            error_spans.append({
                "token": token,
                "type": self._ID2LABEL.get(pred, "OTHER"),
                "position": position,
            })
        # Deterministic de-duplication; a set would give an arbitrary order.
        error_types = list(dict.fromkeys(span["type"] for span in error_spans))
        return error_spans, error_types

    def _preprocess_inputs(self, text, max_length=128):
        """Tokenize *text* and its GED tag string for the two T5 encoders.

        Mirrors training: both sequences use the T5 tokenizer, truncated to
        *max_length* tokens.

        Returns:
            Dict with ``input_ids``/``attention_mask`` for the source and
            ``ged_input_ids``/``ged_attention_mask`` for the tag string, all
            on ``self.device``.
        """
        ged_tags, _, _ = self._get_ged_predictions(text)

        src_tokens = self.t5_tokenizer(
            text, truncation=True, max_length=max_length, return_tensors="pt"
        )
        ged_tokens = self.t5_tokenizer(
            ged_tags, truncation=True, max_length=max_length, return_tensors="pt"
        )
        return {
            "input_ids": src_tokens.input_ids.to(self.device),
            "attention_mask": src_tokens.attention_mask.to(self.device),
            "ged_input_ids": ged_tokens.input_ids.to(self.device),
            "ged_attention_mask": ged_tokens.attention_mask.to(self.device),
        }

    def _forward_with_ged(self, input_ids, attention_mask, ged_input_ids, ged_attention_mask, max_length=200):
        """Encode source and GED tags, gate-fuse the encodings, and decode.

        Replicates the ``T5WithGED.forward()`` fusion used in training:
        ``gate * src + (1 - gate) * ged``. The gate is a sigmoid of a linear
        layer over the concatenated hidden states. Both encodings are truncated
        to their common length first, so positions beyond the shorter sequence
        are dropped.

        Returns:
            Token ids from ``t5_model.generate`` (greedy, ``num_beams=1``).
        """
        src_encoder_outputs = self.t5_model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        ged_encoder_outputs = self.ged_encoder(
            input_ids=ged_input_ids,
            attention_mask=ged_attention_mask,
            return_dict=True,
        )
        src_hidden_states = src_encoder_outputs.last_hidden_state
        ged_hidden_states = ged_encoder_outputs.last_hidden_state

        # Combine hidden states (same as training).
        min_len = min(src_hidden_states.size(1), ged_hidden_states.size(1))
        src_part = src_hidden_states[:, :min_len, :]
        ged_part = ged_hidden_states[:, :min_len, :]
        gate_scores = torch.sigmoid(self.gate(torch.cat([src_part, ged_part], dim=2)))
        combined_hidden = gate_scores * src_part + (1 - gate_scores) * ged_part

        # Hand the fused states to generate() in place of the source encoding.
        src_encoder_outputs.last_hidden_state = combined_hidden
        return self.t5_model.generate(
            encoder_outputs=src_encoder_outputs,
            max_length=max_length,
            do_sample=False,
            num_beams=1,
        )

    def correct_text(self, text, max_length=200):
        """Return the grammar-corrected version of *text*.

        Args:
            text: Input text to correct.
            max_length: Maximum generation length in tokens.
        """
        inputs = self._preprocess_inputs(text)
        with torch.no_grad():
            generated_ids = self._forward_with_ged(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                ged_input_ids=inputs["ged_input_ids"],
                ged_attention_mask=inputs["ged_attention_mask"],
                max_length=max_length,
            )
        return self.t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    def analyze_text(self, text):
        """Correct *text* and build an HTML report of the detected errors.

        Returns:
            ``(corrected_text, html)``. For blank input this is
            ``("Model not available or empty text", "")``. If anything raises,
            it is ``("Error during analysis: <message>", "")``.
        """
        if not text.strip():
            return "Model not available or empty text", ""
        try:
            corrected_text = self.correct_text(text)
            error_spans = self._get_error_spans(text)
            html_output = self.generate_html_analysis(text, corrected_text, error_spans)
            return corrected_text, html_output
        except Exception as e:
            return f"Error during analysis: {str(e)}", ""

    def _highlight_errors(self, original, error_spans):
        """Return *original* as escaped HTML with error tokens wrapped in spans.

        Tokens are located left to right in ``position`` order and
        case-insensitively, since the GED tokenizer may lowercase. Each search
        starts after the previous match and runs on the raw text, so it never
        lands inside markup that was already inserted. A token beginning with a
        word character must also begin a word in the text, so "a" does not
        match inside "saw". Tokens that cannot be located are not highlighted.
        """
        import re
        from html import escape

        pieces = []
        cursor = 0
        for span in sorted(error_spans, key=lambda s: s["position"]):
            token = span["token"]
            if not token:
                continue
            pattern = re.escape(token)
            if re.match(r"\w", token):
                pattern = r"(?<!\w)" + pattern
            match = re.compile(pattern, re.IGNORECASE).search(original, cursor)
            if match is None:
                continue
            color = self._COLOR_MAP.get(span["type"], "#f5f5f5")
            pieces.append(escape(original[cursor:match.start()]))
            pieces.append(
                f"<span style='background-color: {color}; padding: 1px 3px; "
                f"border-radius: 3px; margin: 0 1px;' title='{escape(span['type'])}'>"
                f"{escape(match.group())}</span>"
            )
            cursor = match.end()
        pieces.append(escape(original[cursor:]))
        return "".join(pieces)

    def generate_html_analysis(self, original, corrected, error_spans):
        """Build the HTML report: highlighted original, correction, and summary.

        All user text is HTML-escaped before it is embedded.

        Args:
            original: The text that was analysed.
            corrected: The model's correction.
            error_spans: Output of ``_get_error_spans`` for *original*.
        """
        from html import escape

        highlighted_original = self._highlight_errors(original, error_spans)
        report = f"""
<div style='font-family: Arial, sans-serif; line-height: 1.6; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;'>
  <h3 style='color: #333; margin-top: 0;'>Grammar Analysis Results</h3>
  <div style='margin: 15px 0;'>
    <h4 style='color: #555;'>Original Text with Error Highlighting:</h4>
    <div style='padding: 10px; background-color: #fff; border: 1px solid #ddd; border-radius: 4px;'>{highlighted_original}</div>
  </div>
  <div style='margin: 15px 0;'>
    <h4 style='color: #28a745;'>Corrected Text:</h4>
    <p style='padding: 10px; background-color: #d4edda; border: 1px solid #c3e6cb; border-radius: 4px;'>{escape(corrected)}</p>
  </div>
  <div style='margin: 15px 0;'>
    <h4 style='color: #333;'>Error Summary:</h4>
    <p style='color: #666;'>Found {len(error_spans)} potential issues</p>
    <div style='margin-top: 10px;'>
      <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #ffebee; border-radius: 12px; font-size: 12px;'>Grammar</span>
      <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #e8f5e8; border-radius: 12px; font-size: 12px;'>Missing</span>
      <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #fff3e0; border-radius: 12px; font-size: 12px;'>Unnecessary</span>
      <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #e3f2fd; border-radius: 12px; font-size: 12px;'>Usage</span>
    </div>
  </div>
</div>
"""
        return report
def clear_and_reload_database():
    """Empty the ``sentence_database`` table and repopulate it from disk.

    Deleting the rows first defeats the "already loaded" short-circuit in
    ``load_sentence_database``, which forces a full re-import of its default
    JSONL file. The connection is closed even if the DELETE fails, for
    example when the table does not exist yet.
    """
    from contextlib import closing

    with closing(sqlite3.connect('language_app.db')) as conn:
        conn.execute("DELETE FROM sentence_database")
        conn.commit()
    print("Cleared existing sentence database")

    load_sentence_database()
# Initialize SQLite database for storing submissions and exercises
def init_database():
    """Create every application table in ``language_app.db`` if it is missing.

    Safe to call repeatedly, because each statement uses
    ``CREATE TABLE IF NOT EXISTS``. The connection is closed even if a
    statement fails.
    """
    from contextlib import closing

    schema = (
        # Users table
        """CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            email TEXT UNIQUE NOT NULL,
            role TEXT NOT NULL,
            password_hash TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )""",
        # Writing tasks
        """CREATE TABLE IF NOT EXISTS tasks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            description TEXT NOT NULL,
            image_url TEXT,
            creator_id INTEGER,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )""",
        # Analysed writing submissions
        """CREATE TABLE IF NOT EXISTS submissions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_id INTEGER,
            student_name TEXT NOT NULL,
            content TEXT NOT NULL,
            analysis_result TEXT,
            analysis_html TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )""",
        # Generated exercises (sentences stored as a JSON list)
        """CREATE TABLE IF NOT EXISTS exercises (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            instructions TEXT NOT NULL,
            sentences TEXT NOT NULL,
            image_url TEXT,
            submission_id INTEGER,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )""",
        # Exercise attempts (responses stored as JSON)
        """CREATE TABLE IF NOT EXISTS exercise_attempts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            exercise_id INTEGER,
            student_name TEXT NOT NULL,
            responses TEXT NOT NULL,
            score REAL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )""",
        # Example-sentence bank filled by load_sentence_database()
        """CREATE TABLE IF NOT EXISTS sentence_database (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            text TEXT NOT NULL,
            tags TEXT NOT NULL,
            error_types TEXT NOT NULL
        )""",
    )

    with closing(sqlite3.connect('language_app.db')) as conn:
        for statement in schema:
            conn.execute(statement)
        conn.commit()
def load_sentence_database(jsonl_file_path='sentencewise_full.jsonl'):
    """Populate ``sentence_database`` from a JSON-Lines file, once.

    If the table already has rows, nothing is read. Otherwise each non-blank
    line is parsed as an object with ``text`` and ``tags``. Only sentences
    with non-empty text and at least one tag carrying a ``second_level_tag``
    are stored. Each stored row holds the raw tags as JSON plus the distinct
    tag names in first-seen order. Malformed lines are reported and skipped.
    A missing file is reported together with the ``.json``/``.jsonl`` files
    found in the working directory.

    Args:
        jsonl_file_path: Path of the JSONL file to import.
    """
    from contextlib import closing

    print(f"Debug: Attempting to load from: {jsonl_file_path}")
    print(f"Debug: Current working directory: {os.getcwd()}")
    print(f"Debug: File exists: {os.path.exists(jsonl_file_path)}")

    with closing(sqlite3.connect('language_app.db')) as conn:
        c = conn.cursor()
        c.execute('''CREATE TABLE IF NOT EXISTS sentence_database (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            text TEXT NOT NULL,
            tags TEXT NOT NULL,
            error_types TEXT NOT NULL
        )''')

        # Skip the import when a previous run already filled the table.
        c.execute("SELECT COUNT(*) FROM sentence_database")
        current_count = c.fetchone()[0]
        if current_count > 0:
            print(f"Sentence database already loaded with {current_count} sentences")
            return

        try:
            print(f"Debug: Opening file {jsonl_file_path}")
            with open(jsonl_file_path, 'r', encoding='utf-8') as f:
                lines_processed = 0
                line_num = 0  # stays 0 for an empty file; used in the summary
                for line_num, line in enumerate(f, 1):
                    try:
                        line = line.strip()
                        if not line:  # Skip empty lines
                            continue

                        data = json.loads(line)
                        text = data.get('text', '')
                        tags = data.get('tags', [])
                        if not text or not tags:
                            print(f"Debug: Skipping line {line_num} - missing text or tags")
                            continue

                        # Distinct second-level tags in first-seen order
                        # (a set would store them in arbitrary order).
                        error_types = list(dict.fromkeys(
                            tag.get('second_level_tag', '') for tag in tags
                        ))
                        error_types = [name for name in error_types if name]

                        if line_num <= 3:
                            print(f"Debug line {line_num}: text='{text[:50]}...', error_types={error_types}")
                            print(f"Debug: Raw tags for line {line_num}: {tags}")

                        if error_types:  # Only insert if we have error types
                            c.execute("""INSERT INTO sentence_database (text, tags, error_types)
                                         VALUES (?, ?, ?)""",
                                      (text, json.dumps(tags), json.dumps(error_types)))
                            lines_processed += 1

                        if line_num % 1000 == 0:
                            print(f"Processed {line_num} lines, inserted {lines_processed} sentences...")
                    except json.JSONDecodeError as e:
                        print(f"JSON decode error on line {line_num}: {e}")
                        print(f"Line content: {line[:100]}...")
                        continue
                    except Exception as e:
                        print(f"Error processing line {line_num}: {e}")
                        continue

            conn.commit()
            print(f"Successfully loaded sentence database with {lines_processed} sentences from {line_num} total lines")
        except FileNotFoundError:
            print(f"Error: {jsonl_file_path} not found in {os.getcwd()}")
            print("Available files:")
            try:
                for name in os.listdir('.'):
                    if name.endswith(('.jsonl', '.json')):
                        print(f" - {name}")
            except OSError:
                print(" Could not list files")
        except Exception as e:
            print(f"Error loading sentence database: {e}")
def find_similar_sentences(error_types, limit=5):
    """Return up to *limit* distinct database sentences sharing an error type.

    Each requested type is queried in order. A query fetches up to *limit*
    random rows whose stored ``error_types`` JSON list contains that type as a
    whole element. SQLite LIKE is used, so the match is ASCII
    case-insensitive. Results are de-duplicated by text, with the first
    occurrence winning.

    Args:
        error_types: Iterable of second-level tag names.
        limit: Maximum number of sentences returned, and the per-type query cap.

    Returns:
        List of ``{"text": str, "tags": list}`` dicts, or ``[]`` when
        *error_types* is empty.
    """
    from contextlib import closing

    if not error_types:
        return []

    unique_sentences = []
    seen_texts = set()
    with closing(sqlite3.connect('language_app.db')) as conn:
        for error_type in error_types:
            # Match the JSON-encoded element ("VERB" including its quotes),
            # escaping LIKE wildcards so "_" or "%" in a tag name are literal.
            needle = json.dumps(error_type)
            needle = needle.replace("!", "!!").replace("%", "!%").replace("_", "!_")
            rows = conn.execute(
                """SELECT text, tags FROM sentence_database
                   WHERE error_types LIKE ? ESCAPE '!'
                   ORDER BY RANDOM()
                   LIMIT ?""",
                (f"%{needle}%", limit),
            ).fetchall()

            for text, tags_json in rows:
                if text in seen_texts:
                    continue
                seen_texts.add(text)
                unique_sentences.append({'text': text, 'tags': json.loads(tags_json)})
                if len(unique_sentences) >= limit:
                    return unique_sentences

    return unique_sentences
# Initialize database and components.
# Order matters: clear_and_reload_database() deletes from sentence_database,
# so init_database() must create the schema first. The grammar checker is
# built once here and shared by all the handler functions below.
init_database()

print("Clearing and loading sentence database...")
clear_and_reload_database()

print("Initializing enhanced grammar checker...")
grammar_checker = HuggingFaceT5GEDInference()
print("Grammar checker initialized successfully!")
# Gradio Interface Functions
def analyze_student_writing(text, student_name, task_title="General Writing Task"):
    """Analyse a student's text and store it as a submission.

    Args:
        text: The student's writing.
        student_name: Name recorded with the submission.
        task_title: Task the submission belongs to. The task row is created
            on first use and reused afterwards.

    Returns:
        ``(corrected_text, html_analysis)`` from the grammar checker, or a
        prompt message and "" when *text* or *student_name* is blank.
    """
    from contextlib import closing

    if not text.strip():
        return "Please enter some text to analyze.", ""
    if not student_name.strip():
        return "Please enter your name.", ""

    corrected_text, html_analysis = grammar_checker.analyze_text(text)

    analysis_data = {
        "corrected_text": corrected_text,
        "original_text": text,
        "html_output": html_analysis,
    }

    with closing(sqlite3.connect('language_app.db')) as conn:
        c = conn.cursor()
        # tasks.title has no UNIQUE constraint, so "INSERT OR IGNORE" never
        # ignored anything and every submission created a duplicate task.
        # Reuse the oldest task with this title; create it only if absent.
        c.execute("SELECT id FROM tasks WHERE title = ? ORDER BY id LIMIT 1", (task_title,))
        row = c.fetchone()
        if row is None:
            c.execute("INSERT INTO tasks (title, description) VALUES (?, ?)",
                      (task_title, f"Writing task: {task_title}"))
            task_id = c.lastrowid
        else:
            task_id = row[0]

        c.execute("""INSERT INTO submissions (task_id, student_name, content, analysis_result, analysis_html)
                     VALUES (?, ?, ?, ?, ?)""",
                  (task_id, student_name, text, json.dumps(analysis_data), html_analysis))
        conn.commit()

    return corrected_text, html_analysis
def create_exercise_from_text(text, exercise_title="Grammar Exercise"):
    """Build and store a correction exercise from the erroneous sentences in *text*.

    Each sentence in which the GED model finds errors becomes an item, with
    its model correction and detailed error types. Up to five sentences
    sharing those error types are added from the sentence database. The
    exercise is saved to ``exercises``.

    Args:
        text: Source text to mine for erroneous sentences.
        exercise_title: Title stored with the exercise.

    Returns:
        ``(status_message, exercise_html)``; the HTML is "" when no exercise
        was created. User- and database-provided text is HTML-escaped.
    """
    from contextlib import closing
    from html import escape

    if not text.strip():
        return "Please enter text to create an exercise.", ""

    exercise_sentences = []
    all_error_types = []
    for sentence in nltk.sent_tokenize(text):
        _, error_types = grammar_checker._get_error_spans_detailed(sentence)
        if not error_types:  # sentence judged correct
            continue
        corrected, _ = grammar_checker.analyze_text(sentence)
        exercise_sentences.append({
            "original": sentence.strip(),
            "corrected": corrected.strip(),
            "error_types": error_types,
        })
        all_error_types.extend(error_types)

    if not exercise_sentences:
        return "No errors found in the text. Cannot create exercise.", ""

    # First-seen order keeps both the DB queries and the displayed list stable
    # (set ordering of strings varies between interpreter runs).
    unique_error_types = list(dict.fromkeys(all_error_types))
    similar_sentences = find_similar_sentences(unique_error_types, limit=5)

    all_exercise_sentences = list(exercise_sentences)
    for similar in similar_sentences:
        corrected, _ = grammar_checker.analyze_text(similar['text'])
        # Drop blank and repeated tag names so hints don't read "VERB, , VERB".
        tag_names = (tag.get('second_level_tag', '') for tag in similar['tags'])
        all_exercise_sentences.append({
            "original": similar['text'],
            "corrected": corrected,
            "error_types": list(dict.fromkeys(name for name in tag_names if name)),
        })

    with closing(sqlite3.connect('language_app.db')) as conn:
        cur = conn.execute(
            """INSERT INTO exercises (title, instructions, sentences)
               VALUES (?, ?, ?)""",
            (exercise_title,
             "Correct the grammatical errors in the following sentences:",
             json.dumps(all_exercise_sentences)),
        )
        exercise_id = cur.lastrowid
        conn.commit()

    items = []
    for sentence_data in all_exercise_sentences:
        types = sentence_data.get('error_types')
        error_info = f" (Error types: {escape(', '.join(types))})" if types else ""
        items.append(
            "<li style='margin: 10px 0; padding: 10px; background-color: #f8f9fa; border-radius: 4px;'>"
            f"{escape(sentence_data['original'])}{error_info}</li>"
        )

    exercise_html = f"""
<div style='font-family: Arial, sans-serif; padding: 20px; border: 1px solid #ddd; border-radius: 8px;'>
  <h3>{escape(exercise_title)}</h3>
  <p><strong>Exercise ID: {exercise_id}</strong></p>
  <p><strong>Instructions:</strong> Correct the grammatical errors in the following sentences:</p>
  <p><em>Error types found: {escape(', '.join(unique_error_types))}</em></p>
  <ol>
""" + "".join(items) + "</ol></div>"

    return (
        f"Exercise created with {len(all_exercise_sentences)} sentences "
        f"({len(exercise_sentences)} original + {len(similar_sentences)} from database)! "
        f"Exercise ID: {exercise_id}",
        exercise_html,
    )
def _render_attempt_feedback(score, correct_count, total, detailed_results):
    """Return the HTML score report for a graded exercise attempt.

    Exercise sentences and student answers are HTML-escaped. Each result's
    ``analysis_html`` comes from the grammar checker and is embedded as-is.
    """
    from html import escape

    # NOTE(review): the emoji in the labels below look mojibake-damaged in this
    # copy of the file (e.g. "π", "β"); restore the intended glyphs.
    score_color = "#28a745" if score >= 70 else "#ffc107" if score >= 50 else "#dc3545"
    if score >= 90:
        badge_color, badge_text = "#28a745", "π Excellent Work!"
    elif score >= 70:
        badge_color, badge_text = "#17a2b8", "π Good Job!"
    elif score >= 50:
        badge_color, badge_text = "#ffc107", "π Keep Practicing!"
    else:
        badge_color, badge_text = "#dc3545", "πͺ Try Again!"

    parts = [f"""
<div style='font-family: Arial, sans-serif; max-width: 1000px; margin: 0 auto;'>
  <!-- Header Section -->
  <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center;'>
    <h2 style='margin: 0; font-size: 28px;'>π Exercise Results</h2>
    <div style='margin-top: 15px; font-size: 48px; font-weight: bold; color: {score_color};'>{score:.1f}%</div>
    <p style='margin: 10px 0 0 0; font-size: 18px; opacity: 0.9;'>{correct_count} out of {total} sentences correct</p>
  </div>
  <!-- Performance Badge -->
  <div style='background-color: #f8f9fa; padding: 20px; text-align: center; border-left: 1px solid #ddd; border-right: 1px solid #ddd;'>
    <span style='background-color: {badge_color}; color: white; padding: 8px 20px; border-radius: 20px; font-weight: bold;'>{badge_text}</span>
  </div>
  <!-- Detailed Results -->
  <div style='background-color: white; border: 1px solid #ddd; border-radius: 0 0 10px 10px;'>
"""]

    for result in detailed_results:
        if result['is_correct']:
            border_color, icon, status_bg, status_text = "#28a745", "β ", "#d4edda", "Correct!"
        else:
            border_color, icon, status_bg, status_text = "#dc3545", "β", "#f8d7da", "Needs Improvement"

        parts.append(f"""
    <div style='border-left: 4px solid {border_color}; margin: 20px; padding: 20px; background-color: #fafafa; border-radius: 8px;'>
      <div style='display: flex; align-items: center; margin-bottom: 15px;'>
        <span style='font-size: 24px; margin-right: 10px;'>{icon}</span>
        <h4 style='margin: 0; color: #333;'>Sentence {result['sentence_num']}</h4>
        <span style='margin-left: auto; background-color: {status_bg}; padding: 4px 12px; border-radius: 12px; font-size: 12px; font-weight: bold;'>{status_text}</span>
      </div>
      <div style='margin-bottom: 15px;'>
        <div style='margin-bottom: 10px;'>
          <strong style='color: #6c757d;'>π Original:</strong>
          <div style='background-color: #e9ecef; padding: 10px; border-radius: 6px; margin-top: 5px; font-style: italic;'>{escape(result['original'])}</div>
        </div>
        <div style='margin-bottom: 10px;'>
          <strong style='color: #007bff;'>βοΈ Your Answer:</strong>
          <div style='background-color: #e7f3ff; padding: 10px; border-radius: 6px; margin-top: 5px;'>{escape(result['student_response'])}</div>
        </div>
""")
        # Only show the model analysis when the checker still changed the answer.
        if not result['is_correct'] and result['analysis_html']:
            parts.append(f"""
        <div style='margin-top: 15px; padding: 15px; background-color: #fff3cd; border-radius: 6px; border-left: 3px solid #ffc107;'>
          <strong style='color: #856404;'>π Grammar Analysis of Your Response:</strong>
          <div style='margin-top: 10px; font-size: 14px;'>
            {result['analysis_html']}
          </div>
        </div>
""")
        parts.append("""
      </div>
    </div>
""")

    parts.append("""
  </div>
  <!-- Footer -->
  <div style='text-align: center; margin-top: 30px; color: #6c757d; font-size: 14px;'>
    <p>π‘ <strong>Tip:</strong> Review the grammar analysis above to understand common error patterns and improve your writing!</p>
  </div>
</div>
""")
    return "".join(parts)


def attempt_exercise(exercise_id, student_responses, student_name):
    """Grade a student's answers to a stored exercise and record the attempt.

    *student_responses* holds one corrected sentence per line; blank lines are
    ignored. It must contain exactly as many answers as the exercise has
    sentences. An answer counts as correct when the grammar model leaves it
    unchanged, ignoring surrounding whitespace.

    Args:
        exercise_id: Exercise id, anything accepted by ``int()``.
        student_responses: Newline-separated answers.
        student_name: Name recorded with the attempt.

    Returns:
        ``(status_message, feedback_html)``; the HTML is "" on validation errors.
    """
    from contextlib import closing

    if not student_name.strip():
        return "Please enter your name.", ""
    try:
        exercise_id = int(exercise_id)
    except (TypeError, ValueError, OverflowError):
        return "Please enter a valid exercise ID.", ""

    # Read the exercise, then release the connection before the slow grading.
    with closing(sqlite3.connect('language_app.db')) as conn:
        row = conn.execute("SELECT sentences FROM exercises WHERE id = ?", (exercise_id,)).fetchone()
    if not row:
        return "Exercise not found.", ""
    exercise_sentences = json.loads(row[0])
    total = len(exercise_sentences)

    responses = [r.strip() for r in student_responses.split('\n') if r.strip()]
    if len(responses) != total:
        return f"Please provide exactly {total} responses (one per line).", ""

    correct_count = 0
    detailed_results = []
    for i, (sentence_data, response) in enumerate(zip(exercise_sentences, responses), 1):
        response_corrected, response_analysis = grammar_checker.analyze_text(response)
        is_correct = response_corrected.strip() == response.strip()  # nothing left to fix
        if is_correct:
            correct_count += 1
        detailed_results.append({
            'sentence_num': i,
            'original': sentence_data['original'],
            'student_response': response,
            'expected': sentence_data['corrected'],
            'model_correction': response_corrected,
            'is_correct': is_correct,
            'analysis_html': response_analysis,
        })

    # An exercise stored without sentences would otherwise divide by zero.
    score = (correct_count / total) * 100 if total else 0.0

    attempt_data = {
        "responses": responses,
        "score": score,
        "detailed_results": detailed_results,
    }
    with closing(sqlite3.connect('language_app.db')) as conn:
        conn.execute("""INSERT INTO exercise_attempts (exercise_id, student_name, responses, score)
                        VALUES (?, ?, ?, ?)""",
                     (exercise_id, student_name, json.dumps(attempt_data), score))
        conn.commit()

    return f"Score: {score:.1f}%", _render_attempt_feedback(score, correct_count, total, detailed_results)
def preview_exercise(exercise_id):
    """Render a stored exercise's sentences so a student can read them first.

    Args:
        exercise_id: Exercise id as text. Numeric values are also accepted and
            converted with ``int()``.

    Returns:
        ``(status_message, preview_html)``; the HTML is "" on errors. Title,
        instructions, sentences and error-type hints are HTML-escaped.
    """
    from contextlib import closing
    from html import escape

    if exercise_id is None or not str(exercise_id).strip():
        return "Please enter an exercise ID.", ""
    try:
        exercise_id = int(exercise_id)
    except (TypeError, ValueError, OverflowError):
        return "Please enter a valid exercise ID.", ""

    with closing(sqlite3.connect('language_app.db')) as conn:
        row = conn.execute(
            "SELECT title, instructions, sentences FROM exercises WHERE id = ?",
            (exercise_id,),
        ).fetchone()
    if not row:
        return "Exercise not found.", ""

    title, instructions, sentences_json = row
    exercise_sentences = json.loads(sentences_json)

    parts = [f"""
<div style='font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto;'>
  <!-- Header -->
  <div style='background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%); color: white; padding: 25px; border-radius: 10px 10px 0 0; text-align: center;'>
    <h2 style='margin: 0; font-size: 24px;'>π {escape(title)}</h2>
    <p style='margin: 10px 0 0 0; font-size: 16px; opacity: 0.9;'>Exercise ID: {exercise_id}</p>
  </div>
  <!-- Instructions -->
  <div style='background-color: #e8f5e9; padding: 20px; border-left: 1px solid #ddd; border-right: 1px solid #ddd;'>
    <h3 style='margin: 0 0 10px 0; color: #2e7d32;'>π Instructions:</h3>
    <p style='margin: 0; font-size: 16px; line-height: 1.5;'>{escape(instructions)}</p>
    <p style='margin: 10px 0 0 0; font-size: 14px; color: #666; font-style: italic;'>
      π‘ Tip: Read each sentence carefully and identify grammatical errors before writing your corrections.
    </p>
  </div>
  <!-- Sentences -->
  <div style='background-color: white; border: 1px solid #ddd; border-radius: 0 0 10px 10px; padding: 20px;'>
    <h3 style='margin: 0 0 20px 0; color: #333;'>π Sentences to Correct ({len(exercise_sentences)} total):</h3>
    <ol style='padding-left: 20px;'>
"""]

    for sentence_data in exercise_sentences:
        error_types = sentence_data.get('error_types', [])
        # Optional hint naming the error categories to look for.
        error_hint = ""
        if error_types:
            error_hint = (
                "<br><small style='color: #666; font-style: italic;'>"
                f"π‘ Focus on: {escape(', '.join(error_types))}</small>"
            )
        parts.append(f"""
      <li style='margin: 15px 0; padding: 15px; background-color: #f8f9fa; border-radius: 6px; border-left: 3px solid #4CAF50;'>
        <div style='font-size: 16px; line-height: 1.5; margin-bottom: 5px;'>{escape(sentence_data['original'])}</div>
        {error_hint}
      </li>
""")

    parts.append("""
    </ol>
    <div style='margin-top: 30px; padding: 20px; background-color: #f0f8ff; border-radius: 8px; border: 1px solid #b3d9ff;'>
      <h4 style='margin: 0 0 10px 0; color: #0066cc;'>π― How to Complete This Exercise:</h4>
      <ol style='margin: 0; padding-left: 20px; color: #333;'>
        <li>Read each sentence carefully</li>
        <li>Identify grammatical errors (spelling, grammar, word choice, etc.)</li>
        <li>Write your corrected version of each sentence</li>
        <li>Enter all your answers in the text box below (one sentence per line)</li>
        <li>Submit to get immediate feedback and scoring</li>
      </ol>
    </div>
  </div>
</div>
""")

    return (
        f"Exercise '{title}' loaded successfully! {len(exercise_sentences)} sentences to correct.",
        "".join(parts),
    )
def get_student_progress(student_name):
    """Return an HTML progress report for one student.

    Looks up the student's writing submissions (joined to ``tasks`` for the
    assignment title) and exercise attempts (joined to ``exercises`` for the
    exercise title) in ``language_app.db``, newest first, and renders both
    as HTML lists.

    Args:
        student_name: Name entered in the progress tab. It is matched
            exactly (not stripped) against ``student_name`` in both tables.

    Returns:
        ``"Please enter a student name."`` when the name is ``None``, empty,
        or whitespace only; otherwise an HTML fragment. The name and every
        text value read from the database are HTML-escaped before being
        interpolated, so user-typed names/titles cannot inject markup.
    """
    import html
    from contextlib import closing

    if not student_name or not student_name.strip():
        return "Please enter a student name."

    # closing() releases the connection even if a query raises (e.g. a
    # missing table); sqlite3's own connection context manager only
    # commits/rolls back and would leave the handle open.
    with closing(sqlite3.connect('language_app.db')) as conn:
        cur = conn.cursor()
        # Writing submissions, newest first.
        cur.execute("""SELECT s.id, s.content, s.created_at, t.title
                     FROM submissions s JOIN tasks t ON s.task_id = t.id
                     WHERE s.student_name = ? ORDER BY s.created_at DESC""", (student_name,))
        submissions = cur.fetchall()
        # Exercise attempts, newest first.
        cur.execute("""SELECT ea.score, ea.created_at, e.title
                     FROM exercise_attempts ea JOIN exercises e ON ea.exercise_id = e.id
                     WHERE ea.student_name = ? ORDER BY ea.created_at DESC""", (student_name,))
        attempts = cur.fetchall()

    def _when(value):
        # First 16 characters of the stored timestamp ("YYYY-MM-DD HH:MM"
        # if it was stored as an ISO string - assumption, confirm with the
        # writers of these tables). NULL timestamps render as empty text.
        return "" if value is None else html.escape(str(value)[:16])

    parts = [
        "<div style='font-family: Arial, sans-serif; padding: 20px;'>",
        f"<h3>Progress for {html.escape(student_name)}</h3>",
        f"<h4>Writing Submissions ({len(submissions)})</h4>",
        "<ul>",
    ]
    for _sub_id, content, created_at, title in submissions:
        parts.append(
            f"<li><strong>{html.escape(str(title))}</strong> - {_when(created_at)}"
            f" - {len(content or '')} characters</li>"
        )

    parts += ["</ul>", f"<h4>Exercise Attempts ({len(attempts)})</h4>", "<ul>"]
    for score, created_at, title in attempts:
        # A NULL score would crash the ".1f" format, so show a placeholder.
        score_text = "N/A" if score is None else f"{score:.1f}%"
        parts.append(
            f"<li><strong>{html.escape(str(title))}</strong> - Score: {score_text}"
            f" - {_when(created_at)}</li>"
        )
    parts.append("</ul></div>")
    return "\n".join(parts)
# Create Gradio Interface
# Top-level UI: a gr.Blocks app with four tabs. Inside Blocks, components are
# laid out in the order they are created within the nested Row/Column
# contexts, so statement order here defines the page layout.
# Each `.click(fn, inputs=[...], outputs=[...])` passes the input components'
# values to `fn` positionally (in list order) and writes `fn`'s returned tuple
# into the output components in list order.
# NOTE(review): the emoji in several labels/strings below look mis-encoded in
# this copy (e.g. "π", "βοΈ"); confirm the file is saved/served as UTF-8.
with gr.Blocks(title="Language Learning App - Enhanced Grammar Checker", theme=gr.themes.Soft()) as app:
    gr.Markdown("# π Language Learning Application")
    gr.Markdown("### AI-Powered Grammar Checking and Exercise Generation")
    gr.Markdown("*Now featuring advanced T5-GED neural network with enhanced error detection*")
    with gr.Tabs():
        # Student Writing Analysis Tab
        # Left column: name, assignment title, free text, and an Analyze button.
        # Right column: corrected text plus an HTML analysis panel.
        with gr.TabItem("π Writing Analysis"):
            gr.Markdown("## Submit Your Writing for Analysis")
            with gr.Row():
                with gr.Column():
                    student_name_input = gr.Textbox(label="Your Name", placeholder="Enter your name")
                    task_title_input = gr.Textbox(label="Assignment Title", value="General Writing Task")
                    writing_input = gr.Textbox(
                        label="Your Writing",
                        lines=8,
                        placeholder="Paste your writing here for grammar analysis..."
                    )
                    analyze_btn = gr.Button("Analyze Writing", variant="primary")
                with gr.Column():
                    corrected_output = gr.Textbox(label="Corrected Text", lines=6)
                    analysis_output = gr.HTML(label="Detailed Analysis")
            # analyze_student_writing(text, name, title) -> (corrected text, analysis HTML)
            analyze_btn.click(
                analyze_student_writing,
                inputs=[writing_input, student_name_input, task_title_input],
                outputs=[corrected_output, analysis_output]
            )
        # Exercise Creation Tab
        # Teacher-facing: turn a text containing errors into a stored exercise.
        with gr.TabItem("ποΈ Exercise Creation"):
            gr.Markdown("## Create Grammar Exercises")
            with gr.Row():
                with gr.Column():
                    exercise_title_input = gr.Textbox(label="Exercise Title", value="Grammar Exercise")
                    exercise_text_input = gr.Textbox(
                        label="Text with Errors",
                        lines=6,
                        placeholder="Enter text containing grammatical errors to create an exercise..."
                    )
                    create_exercise_btn = gr.Button("Create Exercise", variant="primary")
                with gr.Column():
                    exercise_result = gr.Textbox(label="Result")
                    exercise_display = gr.HTML(label="Generated Exercise")
            # create_exercise_from_text(text, title) -> (status message, exercise HTML)
            create_exercise_btn.click(
                create_exercise_from_text,
                inputs=[exercise_text_input, exercise_title_input],
                outputs=[exercise_result, exercise_display]
            )
        # Exercise Attempt Tab
        # Student-facing: look up an exercise by ID, preview it, then submit
        # one corrected sentence per line for scoring.
        with gr.TabItem("βοΈ Exercise Practice"):
            gr.Markdown("## Practice Grammar Exercises")
            with gr.Row():
                with gr.Column():
                    exercise_id_input = gr.Textbox(label="Exercise ID", placeholder="Enter exercise ID")
                    # Preview section
                    with gr.Row():
                        preview_btn = gr.Button("π Preview Exercise", variant="secondary")
                        preview_result = gr.Textbox(label="Preview Status", lines=1)
                    preview_display = gr.HTML(label="Exercise Preview")
                    # Separator
                    gr.Markdown("---")
                    # Attempt section
                    gr.Markdown("### π Complete the Exercise")
                    student_name_exercise = gr.Textbox(label="Your Name", placeholder="Enter your name")
                    responses_input = gr.Textbox(
                        label="Your Answers",
                        lines=8,
                        placeholder="After previewing the exercise above, enter your corrected sentences here (one per line)..."
                    )
                    submit_exercise_btn = gr.Button("β Submit Answers", variant="primary")
                with gr.Column():
                    score_output = gr.Textbox(label="Your Score")
                    feedback_output = gr.HTML(label="Detailed Feedback")
            # Connect the buttons
            # preview_exercise(exercise_id) -> (status message, preview HTML)
            preview_btn.click(
                preview_exercise,
                inputs=[exercise_id_input],
                outputs=[preview_result, preview_display]
            )
            # attempt_exercise(exercise_id, responses, name) -> (score text, feedback HTML);
            # the same Exercise ID box feeds both preview and submission.
            submit_exercise_btn.click(
                attempt_exercise,
                inputs=[exercise_id_input, responses_input, student_name_exercise],
                outputs=[score_output, feedback_output]
            )
        # Progress Tracking Tab
        # Narrow input column (scale=1) beside a wider report column (scale=2).
        with gr.TabItem("π Student Progress"):
            gr.Markdown("## View Student Progress")
            with gr.Row():
                with gr.Column(scale=1):
                    progress_student_name = gr.Textbox(label="Student Name", placeholder="Enter student name")
                    get_progress_btn = gr.Button("Get Progress", variant="primary")
                with gr.Column(scale=2):
                    progress_output = gr.HTML(label="Student Progress")
            # get_student_progress(name) -> single HTML string
            get_progress_btn.click(
                get_student_progress,
                inputs=[progress_student_name],
                outputs=[progress_output]
            )
    # Footer shown below all tabs.
    gr.Markdown("""
    ---
    ### How to Use:
    1. **Writing Analysis**: Submit your writing to get grammar corrections and detailed error analysis
    2. **Exercise Creation**: Teachers can create exercises from text containing errors
    3. **Exercise Practice**: Students can practice with generated exercises and get scored feedback
    4. **Progress Tracking**: View student progress across submissions and exercises
    *Powered by advanced T5-GED neural networks for enhanced grammar error detection and correction*
    """)
# Script entry point: start the Gradio server for the `app` Blocks object.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link; per the
    # Gradio docs this is not supported when running on Hugging Face Spaces
    # (Gradio warns and ignores it there) - confirm it is wanted for local runs.
    app.launch(share=True)