# Hugging Face Space (status: Running)
import html
import json
import os
import sqlite3
from contextlib import closing
from datetime import datetime

import gradio as gr
import nltk
import torch
import torch.nn as nn
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    ElectraForTokenClassification,
    ElectraTokenizer,
)
# Download NLTK data
# nltk.sent_tokenize() needs the Punkt sentence tokenizer. Older NLTK releases
# load the "punkt" resource, while newer ones (3.9+) load "punkt_tab", so make
# sure both are present; only the missing ones are downloaded.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_resource}")
    except LookupError:
        nltk.download(_resource)
# Initialize SQLite database for storing submissions and exercises
def init_database(db_path='language_app.db'):
    """Create the application's tables if they do not exist yet.

    Safe to call repeatedly: every statement is ``CREATE TABLE IF NOT EXISTS``.

    Args:
        db_path: Path of the SQLite file to initialise. Defaults to
            ``language_app.db``, the file this function has always used.

    Raises:
        sqlite3.Error: If the database cannot be opened or a statement fails.
    """
    table_statements = (
        # Users table
        '''CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            email TEXT UNIQUE NOT NULL,
            role TEXT NOT NULL,
            password_hash TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )''',
        # Tasks table
        '''CREATE TABLE IF NOT EXISTS tasks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            description TEXT NOT NULL,
            image_url TEXT,
            creator_id INTEGER,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )''',
        # Submissions table
        '''CREATE TABLE IF NOT EXISTS submissions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_id INTEGER,
            student_name TEXT NOT NULL,
            content TEXT NOT NULL,
            analysis_result TEXT,
            analysis_html TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )''',
        # Exercises table
        '''CREATE TABLE IF NOT EXISTS exercises (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            instructions TEXT NOT NULL,
            sentences TEXT NOT NULL,
            image_url TEXT,
            submission_id INTEGER,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )''',
        # Exercise attempts table
        '''CREATE TABLE IF NOT EXISTS exercise_attempts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            exercise_id INTEGER,
            student_name TEXT NOT NULL,
            responses TEXT NOT NULL,
            score REAL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )''',
    )

    # closing() releases the connection even when a statement raises; the
    # previous version leaked the handle on any error before conn.close().
    with closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()
        for statement in table_statements:
            cursor.execute(statement)
        conn.commit()
# Neural Network Model (simplified version of your existing model)
class SimpleGrammarChecker:
    """Grammar correction plus token-level grammatical error detection (GED).

    Two pretrained checkpoints are loaded on construction: a seq2seq model
    that rewrites the input into a corrected version, and an ELECTRA token
    classifier that tags individual tokens as erroneous. If loading fails,
    ``model`` and ``ged_model`` are set to ``None`` and ``analyze_text``
    reports that the model is unavailable instead of raising.
    """

    # Class index -> tag emitted by the GED model; index 0 means "correct".
    # NOTE(review): R/M/U presumably stand for replace/missing/unnecessary;
    # confirm against the model card.
    ERROR_TAGS = ("C", "R", "M", "U")

    # Tokenizer bookkeeping tokens that never correspond to user text.
    SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")

    def __init__(self):
        self.model_name = "Zlovoblachko/Realec-2step-ft-realec"
        self.ged_model_name = "Zlovoblachko/4tag-electra-grammar-error-detection"
        self.load_models()

    def load_models(self):
        """Load both checkpoints; on any failure disable the checker."""
        try:
            # Load T5 model
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            # Load GED model
            self.ged_tokenizer = ElectraTokenizer.from_pretrained(self.ged_model_name)
            self.ged_model = ElectraForTokenClassification.from_pretrained(self.ged_model_name)
            print("Models loaded successfully!")
        except Exception as e:
            # Best effort: the app still starts, it just cannot analyse text.
            print(f"Error loading models: {e}")
            self.model = None
            self.ged_model = None

    def analyze_text(self, text):
        """Correct ``text`` and render an HTML report.

        Args:
            text: The raw text to analyse.

        Returns:
            A ``(corrected_text, html)`` tuple. When the model is missing,
            the text is empty/``None``, or analysis raises, the first element
            is a human-readable message and the second is ``""``.
        """
        if self.model is None or not text or not text.strip():
            return "Model not available or empty text", ""

        try:
            # Tokenize and generate correction
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=512,
                    num_beams=4,
                    early_stopping=True,
                )
            corrected_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Get GED predictions if available
            error_spans = []
            if self.ged_model is not None:
                error_spans = self.get_error_spans(text)

            # Generate HTML output
            html_output = self.generate_html_analysis(text, corrected_text, error_spans)
            return corrected_text, html_output
        except Exception as e:
            return f"Error during analysis: {str(e)}", ""

    def get_error_spans(self, text):
        """Return the tokens the GED model flags as erroneous.

        Args:
            text: The raw text to tag.

        Returns:
            A list of ``{"token", "type", "position"}`` dicts, where
            ``position`` is the index in the tokenizer's output sequence.
            Returns ``[]`` if detection fails for any reason.
        """
        try:
            inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                outputs = self.ged_model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=2)
            tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
            token_predictions = predictions[0].tolist()

            error_spans = []
            for position, (token, pred) in enumerate(zip(tokens, token_predictions)):
                # Word-piece continuations and special tokens are not reported.
                if token.startswith("##") or token in self.SPECIAL_TOKENS:
                    continue
                # 0 is correct, 1=R, 2=M, 3=U. An index outside the table is
                # skipped rather than aborting the whole pass.
                if 0 < pred < len(self.ERROR_TAGS):
                    error_spans.append({
                        "token": token,
                        "type": self.ERROR_TAGS[pred],
                        "position": position,
                    })
            return error_spans
        except Exception as e:
            # Detection is optional: report the problem and carry on with the
            # correction-only result instead of failing the whole analysis.
            print(f"Error during error detection: {e}")
            return []

    def generate_html_analysis(self, original, corrected, error_spans):
        """Render the original text, the correction and the error count.

        Args:
            original: Text as submitted by the user.
            corrected: Text produced by the correction model.
            error_spans: Sequence of detected errors; only its length is used.

        Returns:
            An HTML fragment as a string.
        """
        # Both strings derive from user input, so escape them before they
        # are placed inside markup.
        safe_original = html.escape(original)
        safe_corrected = html.escape(corrected)
        return f"""
        <div style='font-family: Arial, sans-serif; line-height: 1.6; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;'>
        <h3 style='color: #333; margin-top: 0;'>Grammar Analysis Results</h3>
        <div style='margin: 15px 0;'>
        <h4 style='color: #555;'>Original Text:</h4>
        <p style='padding: 10px; background-color: #fff; border: 1px solid #ddd; border-radius: 4px;'>{safe_original}</p>
        </div>
        <div style='margin: 15px 0;'>
        <h4 style='color: #28a745;'>Corrected Text:</h4>
        <p style='padding: 10px; background-color: #d4edda; border: 1px solid #c3e6cb; border-radius: 4px;'>{safe_corrected}</p>
        </div>
        <div style='margin: 15px 0;'>
        <h4 style='color: #333;'>Error Analysis:</h4>
        <p style='color: #666;'>Found {len(error_spans)} potential errors</p>
        </div>
        </div>
        """
# Initialize components
# Both statements run at import time. init_database() creates any missing
# tables; constructing SimpleGrammarChecker calls load_models(), which tries
# to load both pretrained checkpoints and falls back to a disabled checker if
# that fails.
init_database()
grammar_checker = SimpleGrammarChecker()
# Gradio Interface Functions
def analyze_student_writing(text, student_name, task_title="General Writing Task"):
    """Analyze student writing and store the submission in the database.

    Args:
        text: The student's writing.
        student_name: Name recorded with the submission.
        task_title: Title of the assignment; a task row with this title is
            created on first use and reused afterwards.

    Returns:
        A ``(corrected_text, html_analysis)`` tuple as produced by
        ``grammar_checker.analyze_text``. If ``text`` or ``student_name`` is
        blank, returns a prompt message and ``""`` without touching the
        database.
    """
    if not text or not text.strip():
        return "Please enter some text to analyze.", ""
    if not student_name or not student_name.strip():
        return "Please enter your name.", ""

    # Analyze text
    corrected_text, html_analysis = grammar_checker.analyze_text(text)

    analysis_data = {
        "corrected_text": corrected_text,
        "original_text": text,
        "html_output": html_analysis,
    }

    # Store in database; closing() releases the handle even if a query fails.
    with closing(sqlite3.connect('language_app.db')) as conn:
        c = conn.cursor()

        # tasks.title has no UNIQUE constraint, so "INSERT OR IGNORE" would
        # add a duplicate task row on every submission. Look the task up
        # first and insert only when it is missing.
        c.execute("SELECT id FROM tasks WHERE title = ? ORDER BY id LIMIT 1", (task_title,))
        row = c.fetchone()
        if row is None:
            c.execute("INSERT INTO tasks (title, description) VALUES (?, ?)",
                      (task_title, f"Writing task: {task_title}"))
            task_id = c.lastrowid
        else:
            task_id = row[0]

        # Insert submission
        c.execute("""INSERT INTO submissions (task_id, student_name, content, analysis_result, analysis_html)
                     VALUES (?, ?, ?, ?, ?)""",
                  (task_id, student_name, text, json.dumps(analysis_data), html_analysis))
        conn.commit()

    return corrected_text, html_analysis
def create_exercise_from_text(text, exercise_title="Grammar Exercise"):
    """Create an exercise from text with errors.

    The text is split into sentences; each sentence the grammar checker
    changes becomes one exercise item (original plus corrected form). The
    exercise is stored in the ``exercises`` table.

    Args:
        text: Source text expected to contain grammatical errors.
        exercise_title: Title stored with, and shown above, the exercise.

    Returns:
        A ``(status_message, exercise_html)`` tuple. ``exercise_html`` is
        ``""`` when no exercise was created.
    """
    if not text or not text.strip():
        return "Please enter text to create an exercise.", ""

    # Analyze text to find sentences with errors
    exercise_sentences = []
    failed_sentences = 0
    for sentence in nltk.sent_tokenize(text):
        corrected, analysis_html = grammar_checker.analyze_text(sentence)
        # analyze_text signals failure (model missing, exception) with an
        # empty HTML report and a message in place of the correction; that
        # message must not be stored as the "correct answer".
        if not analysis_html:
            failed_sentences += 1
            continue
        if sentence.strip() != corrected.strip():  # Has errors
            exercise_sentences.append({
                "original": sentence.strip(),
                "corrected": corrected.strip(),
            })

    if not exercise_sentences:
        if failed_sentences:
            return "Could not analyze the text. Cannot create exercise.", ""
        return "No errors found in the text. Cannot create exercise.", ""

    # Store exercise in database
    with closing(sqlite3.connect('language_app.db')) as conn:
        c = conn.cursor()
        c.execute("""INSERT INTO exercises (title, instructions, sentences)
                     VALUES (?, ?, ?)""",
                  (exercise_title,
                   "Correct the grammatical errors in the following sentences:",
                   json.dumps(exercise_sentences)))
        exercise_id = c.lastrowid
        conn.commit()

    # Generate exercise HTML; title and sentences are user input, so escape.
    items = "".join(
        "<li style='margin: 10px 0; padding: 10px; background-color: #f8f9fa; "
        f"border-radius: 4px;'>{html.escape(item['original'])}</li>"
        for item in exercise_sentences
    )
    exercise_html = (
        "<div style='font-family: Arial, sans-serif; padding: 20px; "
        "border: 1px solid #ddd; border-radius: 8px;'>"
        f"<h3>{html.escape(exercise_title)}</h3>"
        "<p><strong>Instructions:</strong> Correct the grammatical errors in "
        "the following sentences:</p>"
        f"<ol>{items}</ol></div>"
    )

    # The practice tab asks for the exercise ID, so it has to be reported.
    status = (f"Exercise created with {len(exercise_sentences)} sentences! "
              f"Exercise ID: {exercise_id}")
    return status, exercise_html
def attempt_exercise(exercise_id, student_responses, student_name):
    """Submit exercise attempt and get score.

    Args:
        exercise_id: ID of a stored exercise; anything ``int()`` accepts.
        student_responses: The student's answers, one per line, in the same
            order as the exercise sentences. Blank lines are ignored.
        student_name: Name recorded with the attempt.

    Returns:
        A ``(score_message, feedback_html)`` tuple. On invalid input the
        first element is an explanatory message and the second is ``""``;
        nothing is stored in that case.
    """
    if not student_name or not student_name.strip():
        return "Please enter your name.", ""
    try:
        exercise_id = int(exercise_id)
    except (TypeError, ValueError, OverflowError):
        return "Please enter a valid exercise ID.", ""

    # closing() also covers the early returns below, which previously left
    # the connection open.
    with closing(sqlite3.connect('language_app.db')) as conn:
        c = conn.cursor()

        # Get exercise from database
        c.execute("SELECT sentences FROM exercises WHERE id = ?", (exercise_id,))
        result = c.fetchone()
        if not result:
            return "Exercise not found.", ""
        exercise_sentences = json.loads(result[0])
        if not exercise_sentences:
            # Guards the percentage computation against division by zero.
            return "Exercise has no sentences.", ""

        # Parse student responses
        responses = [r.strip() for r in (student_responses or "").split('\n') if r.strip()]
        if len(responses) != len(exercise_sentences):
            return f"Please provide exactly {len(exercise_sentences)} responses (one per line).", ""

        # Calculate score (comparison ignores case and surrounding spaces)
        correct_count = 0
        feedback = []
        for i, (sentence_data, response) in enumerate(zip(exercise_sentences, responses), 1):
            correct_answer = sentence_data['corrected']
            if response.lower().strip() == correct_answer.lower().strip():
                correct_count += 1
                feedback.append(f"✅ Sentence {i}: Correct!")
            else:
                feedback.append(
                    f"❌ Sentence {i}: Your answer: '{response}' | Correct answer: '{correct_answer}'"
                )
        score = (correct_count / len(exercise_sentences)) * 100

        # Store attempt
        attempt_data = {
            "responses": responses,
            "score": score,
            "feedback": feedback,
        }
        c.execute("""INSERT INTO exercise_attempts (exercise_id, student_name, responses, score)
                     VALUES (?, ?, ?, ?)""",
                  (exercise_id, student_name, json.dumps(attempt_data), score))
        conn.commit()

    # Feedback lines quote the student's own input, so escape before display.
    feedback_lines = '<br>'.join(html.escape(line) for line in feedback)
    feedback_html = f"""
    <div style='font-family: Arial, sans-serif; padding: 20px; border: 1px solid #ddd; border-radius: 8px;'>
    <h3>Exercise Results</h3>
    <p><strong>Score: {score:.1f}%</strong> ({correct_count}/{len(exercise_sentences)} correct)</p>
    <div style='margin-top: 15px;'>
    {feedback_lines}
    </div>
    </div>
    """
    return f"Score: {score:.1f}%", feedback_html
def get_student_progress(student_name):
    """Get student's submission and exercise history.

    Args:
        student_name: Exact name the student used when submitting.

    Returns:
        An HTML fragment listing the student's writing submissions and
        exercise attempts, newest first, or a prompt message if the name is
        blank.
    """
    if not student_name or not student_name.strip():
        return "Please enter a student name."

    # closing() releases the connection even if one of the queries raises.
    with closing(sqlite3.connect('language_app.db')) as conn:
        c = conn.cursor()

        # Get submissions
        c.execute("""SELECT s.id, s.content, s.created_at, t.title
                     FROM submissions s JOIN tasks t ON s.task_id = t.id
                     WHERE s.student_name = ? ORDER BY s.created_at DESC""", (student_name,))
        submissions = c.fetchall()

        # Get exercise attempts
        c.execute("""SELECT ea.score, ea.created_at, e.title
                     FROM exercise_attempts ea JOIN exercises e ON ea.exercise_id = e.id
                     WHERE ea.student_name = ? ORDER BY ea.created_at DESC""", (student_name,))
        attempts = c.fetchall()

    # Names and titles are user-supplied, so escape them before embedding
    # them in markup. Timestamps are cut to "YYYY-MM-DD HH:MM".
    submission_items = "".join(
        f"<li><strong>{html.escape(title)}</strong> - {created_at[:16]} - {len(content)} characters</li>"
        for _sub_id, content, created_at, title in submissions
    )
    attempt_items = "".join(
        f"<li><strong>{html.escape(title)}</strong> - Score: {score:.1f}% - {created_at[:16]}</li>"
        for score, created_at, title in attempts
    )

    return (
        "<div style='font-family: Arial, sans-serif; padding: 20px;'>"
        f"<h3>Progress for {html.escape(student_name)}</h3>"
        f"<h4>Writing Submissions ({len(submissions)})</h4>"
        f"<ul>{submission_items}</ul>"
        f"<h4>Exercise Attempts ({len(attempts)})</h4>"
        f"<ul>{attempt_items}</ul></div>"
    )
# Create Gradio Interface
# Four tabs, each wiring a group of inputs to one of the handler functions
# defined above and showing its return values in the output components.
# NOTE(review): the emoji in the headings and tab labels were reconstructed
# from mis-encoded text ("π", "β"); confirm the intended glyphs.
_HOW_TO_USE_MD = "\n".join([
    "---",
    "### How to Use:",
    "1. **Writing Analysis**: Submit your writing to get grammar corrections and detailed error analysis",
    "2. **Exercise Creation**: Teachers can create exercises from text containing errors",
    "3. **Exercise Practice**: Students can practice with generated exercises and get scored feedback",
    "4. **Progress Tracking**: View student progress across submissions and exercises",
    "",  # blank line so the closing note is not absorbed into list item 4
    "*Powered by advanced neural networks for grammar error detection and correction*",
])

with gr.Blocks(title="Language Learning App - Grammar Checker", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 📚 Language Learning Application")
    gr.Markdown("### AI-Powered Grammar Checking and Exercise Generation")

    with gr.Tabs():
        # Student Writing Analysis Tab
        with gr.TabItem("📝 Writing Analysis"):
            gr.Markdown("## Submit Your Writing for Analysis")
            with gr.Row():
                with gr.Column():
                    student_name_input = gr.Textbox(label="Your Name", placeholder="Enter your name")
                    task_title_input = gr.Textbox(label="Assignment Title", value="General Writing Task")
                    writing_input = gr.Textbox(
                        label="Your Writing",
                        lines=8,
                        placeholder="Paste your writing here for grammar analysis...",
                    )
                    analyze_btn = gr.Button("Analyze Writing", variant="primary")
                with gr.Column():
                    corrected_output = gr.Textbox(label="Corrected Text", lines=6)
                    analysis_output = gr.HTML(label="Detailed Analysis")
            # Input order must match analyze_student_writing(text, student_name, task_title).
            analyze_btn.click(
                analyze_student_writing,
                inputs=[writing_input, student_name_input, task_title_input],
                outputs=[corrected_output, analysis_output],
            )

        # Exercise Creation Tab
        with gr.TabItem("🏋️ Exercise Creation"):
            gr.Markdown("## Create Grammar Exercises")
            with gr.Row():
                with gr.Column():
                    exercise_title_input = gr.Textbox(label="Exercise Title", value="Grammar Exercise")
                    exercise_text_input = gr.Textbox(
                        label="Text with Errors",
                        lines=6,
                        placeholder="Enter text containing grammatical errors to create an exercise...",
                    )
                    create_exercise_btn = gr.Button("Create Exercise", variant="primary")
                with gr.Column():
                    exercise_result = gr.Textbox(label="Result")
                    exercise_display = gr.HTML(label="Generated Exercise")
            # Input order must match create_exercise_from_text(text, exercise_title).
            create_exercise_btn.click(
                create_exercise_from_text,
                inputs=[exercise_text_input, exercise_title_input],
                outputs=[exercise_result, exercise_display],
            )

        # Exercise Attempt Tab
        with gr.TabItem("✏️ Exercise Practice"):
            gr.Markdown("## Practice Grammar Exercises")
            with gr.Row():
                with gr.Column():
                    exercise_id_input = gr.Textbox(label="Exercise ID", placeholder="Enter exercise ID")
                    student_name_exercise = gr.Textbox(label="Your Name", placeholder="Enter your name")
                    responses_input = gr.Textbox(
                        label="Your Answers",
                        lines=6,
                        placeholder="Enter your corrected sentences (one per line)...",
                    )
                    submit_exercise_btn = gr.Button("Submit Answers", variant="primary")
                with gr.Column():
                    score_output = gr.Textbox(label="Your Score")
                    feedback_output = gr.HTML(label="Detailed Feedback")
            # Input order must match attempt_exercise(exercise_id, student_responses, student_name).
            submit_exercise_btn.click(
                attempt_exercise,
                inputs=[exercise_id_input, responses_input, student_name_exercise],
                outputs=[score_output, feedback_output],
            )

        # Progress Tracking Tab
        with gr.TabItem("📊 Student Progress"):
            gr.Markdown("## View Student Progress")
            with gr.Row():
                with gr.Column(scale=1):
                    progress_student_name = gr.Textbox(label="Student Name", placeholder="Enter student name")
                    get_progress_btn = gr.Button("Get Progress", variant="primary")
                with gr.Column(scale=2):
                    progress_output = gr.HTML(label="Student Progress")
            get_progress_btn.click(
                get_student_progress,
                inputs=[progress_student_name],
                outputs=[progress_output],
            )

    gr.Markdown(_HOW_TO_USE_MD)
# Start the Gradio app only when this file is run as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    app.launch()