Zlovoblachko committed on
Commit
756883e
·
1 Parent(s): c444d4f

initial commit

Browse files
Files changed (1) hide show
  1. app.py +324 -114
app.py CHANGED
@@ -5,8 +5,14 @@ import os
5
  from datetime import datetime
6
  import torch
7
  import nltk
8
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, ElectraTokenizer, ElectraForTokenClassification
 
 
 
 
 
9
  import torch.nn as nn
 
10
 
11
  # Download NLTK data
12
  try:
@@ -14,6 +20,304 @@ try:
14
  except LookupError:
15
  nltk.download('punkt')
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # Initialize SQLite database for storing submissions and exercises
18
  def init_database():
19
  conn = sqlite3.connect('language_app.db')
@@ -74,110 +378,11 @@ def init_database():
74
  conn.commit()
75
  conn.close()
76
 
77
- # Neural Network Model (simplified version of your existing model)
78
- class SimpleGrammarChecker:
79
- def __init__(self):
80
- self.model_name = "Zlovoblachko/Realec-2step-ft-realec"
81
- self.ged_model_name = "Zlovoblachko/4tag-electra-grammar-error-detection"
82
- self.load_models()
83
-
84
- def load_models(self):
85
- try:
86
- # Load T5 model
87
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
88
- self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
89
-
90
- # Load GED model
91
- self.ged_tokenizer = ElectraTokenizer.from_pretrained(self.ged_model_name)
92
- self.ged_model = ElectraForTokenClassification.from_pretrained(self.ged_model_name)
93
-
94
- print("Models loaded successfully!")
95
- except Exception as e:
96
- print(f"Error loading models: {e}")
97
- self.model = None
98
- self.ged_model = None
99
-
100
- def analyze_text(self, text):
101
- if not self.model or not text.strip():
102
- return "Model not available or empty text", ""
103
-
104
- try:
105
- # Tokenize and generate correction
106
- inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
107
-
108
- with torch.no_grad():
109
- outputs = self.model.generate(
110
- input_ids=inputs.input_ids,
111
- attention_mask=inputs.attention_mask,
112
- max_length=512,
113
- num_beams=4,
114
- early_stopping=True
115
- )
116
-
117
- corrected_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
118
-
119
- # Get GED predictions if available
120
- error_spans = []
121
- if self.ged_model:
122
- error_spans = self.get_error_spans(text)
123
-
124
- # Generate HTML output
125
- html_output = self.generate_html_analysis(text, corrected_text, error_spans)
126
-
127
- return corrected_text, html_output
128
-
129
- except Exception as e:
130
- return f"Error during analysis: {str(e)}", ""
131
-
132
- def get_error_spans(self, text):
133
- try:
134
- inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
135
-
136
- with torch.no_grad():
137
- outputs = self.ged_model(**inputs)
138
- predictions = torch.argmax(outputs.logits, dim=2)
139
-
140
- tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
141
- token_predictions = predictions[0].cpu().numpy().tolist()
142
-
143
- error_spans = []
144
- for i, (token, pred) in enumerate(zip(tokens, token_predictions)):
145
- if token.startswith("##") or token in ["[CLS]", "[SEP]", "[PAD]"]:
146
- continue
147
- if pred != 0: # 0 is correct, 1=R, 2=M, 3=U
148
- error_type = ["C", "R", "M", "U"][pred]
149
- error_spans.append({"token": token, "type": error_type, "position": i})
150
-
151
- return error_spans
152
- except:
153
- return []
154
-
155
- def generate_html_analysis(self, original, corrected, error_spans):
156
- html = f"""
157
- <div style='font-family: Arial, sans-serif; line-height: 1.6; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;'>
158
- <h3 style='color: #333; margin-top: 0;'>Grammar Analysis Results</h3>
159
-
160
- <div style='margin: 15px 0;'>
161
- <h4 style='color: #555;'>Original Text:</h4>
162
- <p style='padding: 10px; background-color: #fff; border: 1px solid #ddd; border-radius: 4px;'>{original}</p>
163
- </div>
164
-
165
- <div style='margin: 15px 0;'>
166
- <h4 style='color: #28a745;'>Corrected Text:</h4>
167
- <p style='padding: 10px; background-color: #d4edda; border: 1px solid #c3e6cb; border-radius: 4px;'>{corrected}</p>
168
- </div>
169
-
170
- <div style='margin: 15px 0;'>
171
- <h4 style='color: #333;'>Error Analysis:</h4>
172
- <p style='color: #666;'>Found {len(error_spans)} potential errors</p>
173
- </div>
174
- </div>
175
- """
176
- return html
177
-
178
- # Initialize components
179
  init_database()
180
- grammar_checker = SimpleGrammarChecker()
 
 
181
 
182
  # Gradio Interface Functions
183
  def analyze_student_writing(text, student_name, task_title="General Writing Task"):
@@ -188,7 +393,7 @@ def analyze_student_writing(text, student_name, task_title="General Writing Task
188
  if not student_name.strip():
189
  return "Please enter your name.", ""
190
 
191
- # Analyze text
192
  corrected_text, html_analysis = grammar_checker.analyze_text(text)
193
 
194
  # Store in database
@@ -220,7 +425,7 @@ def analyze_student_writing(text, student_name, task_title="General Writing Task
220
  return corrected_text, html_analysis
221
 
222
  def create_exercise_from_text(text, exercise_title="Grammar Exercise"):
223
- """Create an exercise from text with errors"""
224
  if not text.strip():
225
  return "Please enter text to create an exercise.", ""
226
 
@@ -257,6 +462,7 @@ def create_exercise_from_text(text, exercise_title="Grammar Exercise"):
257
  exercise_html = f"""
258
  <div style='font-family: Arial, sans-serif; padding: 20px; border: 1px solid #ddd; border-radius: 8px;'>
259
  <h3>{exercise_title}</h3>
 
260
  <p><strong>Instructions:</strong> Correct the grammatical errors in the following sentences:</p>
261
  <ol>
262
  """
@@ -266,10 +472,10 @@ def create_exercise_from_text(text, exercise_title="Grammar Exercise"):
266
 
267
  exercise_html += "</ol></div>"
268
 
269
- return f"Exercise created with {len(exercise_sentences)} sentences!", exercise_html
270
 
271
  def attempt_exercise(exercise_id, student_responses, student_name):
272
- """Submit exercise attempt and get score"""
273
  if not student_name.strip():
274
  return "Please enter your name.", ""
275
 
@@ -296,19 +502,22 @@ def attempt_exercise(exercise_id, student_responses, student_name):
296
  if len(responses) != len(exercise_sentences):
297
  return f"Please provide exactly {len(exercise_sentences)} responses (one per line).", ""
298
 
299
- # Calculate score
300
  correct_count = 0
301
  feedback = []
302
 
303
  for i, (sentence_data, response) in enumerate(zip(exercise_sentences, responses), 1):
304
  correct_answer = sentence_data['corrected']
305
- is_correct = response.lower().strip() == correct_answer.lower().strip()
 
 
 
306
 
307
  if is_correct:
308
  correct_count += 1
309
- feedback.append(f"βœ… Sentence {i}: Correct!")
310
  else:
311
- feedback.append(f"❌ Sentence {i}: Your answer: '{response}' | Correct answer: '{correct_answer}'")
312
 
313
  score = (correct_count / len(exercise_sentences)) * 100
314
 
@@ -386,9 +595,10 @@ def get_student_progress(student_name):
386
  return progress_html
387
 
388
  # Create Gradio Interface
389
- with gr.Blocks(title="Language Learning App - Grammar Checker", theme=gr.themes.Soft()) as app:
390
  gr.Markdown("# πŸŽ“ Language Learning Application")
391
  gr.Markdown("### AI-Powered Grammar Checking and Exercise Generation")
 
392
 
393
  with gr.Tabs():
394
  # Student Writing Analysis Tab
@@ -491,7 +701,7 @@ with gr.Blocks(title="Language Learning App - Grammar Checker", theme=gr.themes.
491
  3. **Exercise Practice**: Students can practice with generated exercises and get scored feedback
492
  4. **Progress Tracking**: View student progress across submissions and exercises
493
 
494
- *Powered by advanced neural networks for grammar error detection and correction*
495
  """)
496
 
497
  if __name__ == "__main__":
 
5
  from datetime import datetime
6
  import torch
7
  import nltk
8
+ from transformers import (
9
+ T5Tokenizer,
10
+ T5ForConditionalGeneration,
11
+ ElectraTokenizer,
12
+ ElectraForTokenClassification
13
+ )
14
  import torch.nn as nn
15
+ from tqdm import tqdm
16
 
17
  # Download NLTK data
18
  try:
 
20
  except LookupError:
21
  nltk.download('punkt')
22
 
23
class HuggingFaceT5GEDInference:
    """Grammar correction with a T5 model guided by an ELECTRA error detector.

    The pipeline, as implemented below:

    1. An ELECTRA token classifier tags every token of the input (GED).
    2. The source text and the space-joined tag string are each encoded by a
       T5 encoder (the main encoder and a separate "GED encoder").
    3. A learned sigmoid gate mixes the two encoder outputs position-wise.
    4. The T5 decoder generates the corrected text from the mixed states.

    ``analyze_text`` is the entry point used by the Gradio callbacks; it
    returns the corrected text plus an HTML report.
    """

    # Tokens that never correspond to a word of their own: WordPiece
    # continuation pieces (``##...``) are filtered separately by prefix.
    _SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    # Collapses the detector's non-zero tag ids into display categories.
    # NOTE(review): the grouping is carried over from the previous inline
    # if/elif chain; what each id means depends on the GED model's label
    # map, which is not visible here -- confirm against the model card.
    _TAG_CATEGORIES = {
        1: "Grammar", 2: "Grammar", 3: "Grammar", 4: "Grammar",
        5: "Missing", 6: "Missing",
        7: "Unnecessary", 8: "Unnecessary",
        9: "Usage", 10: "Usage",
    }

    # Background colour per display category (also used in the legend).
    _HIGHLIGHT_COLORS = {
        "Grammar": "#ffebee",      # Light red
        "Missing": "#e8f5e8",      # Light green
        "Unnecessary": "#fff3e0",  # Light orange
        "Usage": "#e3f2fd",        # Light blue
    }
    _DEFAULT_HIGHLIGHT = "#f5f5f5"

    def __init__(self, model_name="Zlovoblachko/REAlEC_2step_model_testing",
                 ged_model_name="Zlovoblachko/11tag-electra-grammar-stage2", device=None):
        """Load every model component and switch them to evaluation mode.

        Args:
            model_name: HuggingFace model name/path for the T5-GED model.
            ged_model_name: HuggingFace model name/path for the GED model.
            device: Device to run inference on (cuda/cpu). When falsy, CUDA
                is used if available, otherwise CPU.
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load GED model and tokenizer
        print(f"Loading GED model from HuggingFace: {ged_model_name}...")
        self.ged_model, self.ged_tokenizer = self._load_ged_model(ged_model_name)

        # Load T5 model and tokenizer from HuggingFace
        print(f"Loading T5 model from HuggingFace: {model_name}...")
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.t5_model.to(self.device)

        # The GED encoder starts as an independent copy of the T5 encoder;
        # its weights are overwritten below if ged_components.pt is found.
        self.ged_encoder = T5ForConditionalGeneration.from_pretrained(model_name).encoder
        self.ged_encoder.to(self.device)

        # Gate: maps the concatenated [source ; GED] state to one scalar.
        encoder_hidden_size = self.t5_model.config.d_model
        self.gate = nn.Linear(2 * encoder_hidden_size, 1)
        self.gate.to(self.device)

        # Best effort: without the trained components the gate stays
        # randomly initialised, so only warn instead of failing start-up.
        try:
            print("Loading GED components...")
            from huggingface_hub import hf_hub_download
            ged_components_path = hf_hub_download(
                repo_id=model_name,
                filename="ged_components.pt",
                cache_dir=None
            )
            # SECURITY NOTE: torch.load unpickles the downloaded file; only
            # point model_name at repositories you trust.
            ged_components = torch.load(ged_components_path, map_location=self.device)
            self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
            self.gate.load_state_dict(ged_components["gate"])
            print("GED components loaded successfully!")
        except Exception as e:
            print(f"Warning: Could not load GED components: {e}")
            print("Using default initialization for GED encoder and gate.")

        # Set to evaluation mode
        self.t5_model.eval()
        self.ged_encoder.eval()
        self.gate.eval()

    def _load_ged_model(self, model_name):
        """Load the GED classifier and its tokenizer, moved to ``self.device``.

        Returns:
            ``(model, tokenizer)`` with the model already in eval mode.
        """
        tokenizer = ElectraTokenizer.from_pretrained(model_name)
        model = ElectraForTokenClassification.from_pretrained(model_name)
        model.to(self.device)
        model.eval()
        return model, tokenizer

    def _is_skipped_token(self, token):
        """Return True for WordPiece continuations and special tokens."""
        return token.startswith("##") or token in self._SPECIAL_TOKENS

    def _get_ged_predictions(self, text):
        """Run the GED classifier over ``text``.

        Returns:
            A 3-tuple ``(tag_string, tokens, token_predictions)`` where
            ``tag_string`` is the space-joined predicted tag id of every
            token that is neither a ``##`` continuation nor a special
            token, ``tokens`` is the full tokenizer output and
            ``token_predictions`` the argmax tag id for each of them.
        """
        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)
        with torch.no_grad():
            outputs = self.ged_model(**inputs)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
        token_predictions = predictions[0].cpu().numpy().tolist()
        tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

        ged_tags = [
            str(pred)
            for token, pred in zip(tokens, token_predictions)
            if not self._is_skipped_token(token)
        ]

        return " ".join(ged_tags), tokens, token_predictions

    def _get_error_spans(self, text):
        """Extract flagged tokens with simplified categories for display.

        Returns:
            A list of dicts with keys ``token`` (the tokenizer piece),
            ``type`` (display category) and ``position`` (index among the
            tokens that are kept after filtering).
        """
        _, tokens, predictions = self._get_ged_predictions(text)

        error_spans = []
        position = -1
        for token, pred in zip(tokens, predictions):
            if self._is_skipped_token(token):
                continue
            position += 1

            if pred == 0:  # tag 0 means "no error"
                continue

            error_spans.append({
                "token": token,
                "type": self._TAG_CATEGORIES.get(pred, "Error"),
                "position": position
            })

        return error_spans

    def _preprocess_inputs(self, text, max_length=128):
        """Build the tensors for both encoders.

        Args:
            text: Source sentence(s) to correct.
            max_length: Truncation length applied to both the source text
                and the GED tag string.

        Returns:
            Dict with ``input_ids``/``attention_mask`` for the source and
            ``ged_input_ids``/``ged_attention_mask`` for the tag string,
            all on ``self.device``.
        """
        # Get GED predictions
        ged_tags, _, _ = self._get_ged_predictions(text)

        # Tokenize source text
        src_tokens = self.t5_tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )

        # The tag string is tokenized with the T5 tokenizer as well.
        ged_tokens = self.t5_tokenizer(
            ged_tags,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )

        return {
            "input_ids": src_tokens.input_ids.to(self.device),
            "attention_mask": src_tokens.attention_mask.to(self.device),
            "ged_input_ids": ged_tokens.input_ids.to(self.device),
            "ged_attention_mask": ged_tokens.attention_mask.to(self.device)
        }

    def _forward_with_ged(self, input_ids, attention_mask, ged_input_ids, ged_attention_mask, max_length=200):
        """Encode both inputs, gate them together and greedily decode.

        Returns:
            The token ids produced by ``self.t5_model.generate``.
        """
        # Get source encoder outputs
        src_encoder_outputs = self.t5_model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Get GED encoder outputs
        ged_encoder_outputs = self.ged_encoder(
            input_ids=ged_input_ids,
            attention_mask=ged_attention_mask,
            return_dict=True
        )

        src_hidden_states = src_encoder_outputs.last_hidden_state
        ged_hidden_states = ged_encoder_outputs.last_hidden_state

        # The two sequences generally differ in length; both are cut to the
        # shorter one so they can be combined position by position.
        min_len = min(src_hidden_states.size(1), ged_hidden_states.size(1))
        src_trimmed = src_hidden_states[:, :min_len, :]
        ged_trimmed = ged_hidden_states[:, :min_len, :]
        combined = torch.cat([src_trimmed, ged_trimmed], dim=2)

        # Gate in (0, 1): 1 keeps the source state, 0 keeps the GED state.
        gate_scores = torch.sigmoid(self.gate(combined))
        combined_hidden = gate_scores * src_trimmed + (1 - gate_scores) * ged_trimmed

        # Reuse the encoder output object so generate() sees the mix.
        src_encoder_outputs.last_hidden_state = combined_hidden

        # NOTE(review): no attention mask is forwarded to generate(); that
        # is harmless for the single, unpadded sequence produced by
        # _preprocess_inputs but would matter for padded batches.
        decoder_outputs = self.t5_model.generate(
            encoder_outputs=src_encoder_outputs,
            max_length=max_length,
            do_sample=False,
            num_beams=1
        )

        return decoder_outputs

    def correct_text(self, text, max_length=200):
        """Correct grammatical errors in input text.

        Args:
            text: Input text to correct.
            max_length: Maximum length for generation.

        Returns:
            Corrected text as string.
        """
        inputs = self._preprocess_inputs(text)

        # Generate correction using GED-enhanced model
        with torch.no_grad():
            generated_ids = self._forward_with_ged(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                ged_input_ids=inputs["ged_input_ids"],
                ged_attention_mask=inputs["ged_attention_mask"],
                max_length=max_length
            )

        return self.t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    def analyze_text(self, text):
        """Correct ``text`` and build the HTML report for the Gradio UI.

        Returns:
            ``(corrected_text, html)``. For empty, blank or ``None`` input
            and for any failure during analysis, the first element is a
            human-readable message and the second is an empty string.
        """
        # ``not text`` also covers None, which .strip() would crash on.
        if not text or not text.strip():
            return "Model not available or empty text", ""

        try:
            corrected_text = self.correct_text(text)
            error_spans = self._get_error_spans(text)
            html_output = self.generate_html_analysis(text, corrected_text, error_spans)
            return corrected_text, html_output

        except Exception as e:
            # Deliberate catch-all: the UI shows the message instead of
            # crashing the request.
            return f"Error during analysis: {str(e)}", ""

    def _highlight_errors(self, original, error_spans):
        """Return HTML-escaped ``original`` with flagged tokens wrapped in spans.

        Tokens are located by a left-to-right text search that starts where
        the previous match ended, so markup that was already emitted is
        never searched again and repeated words are matched in order. The
        search ignores case because the tokenizer may lower-case its
        pieces. This is still an approximation: tokens that cannot be found
        in the text (for example unknown-token placeholders) are left
        unhighlighted.
        """
        from html import escape

        folded = original.lower()
        # Lower-casing can change the length of some Unicode strings, which
        # would shift indices; fall back to an exact-case search then.
        case_insensitive = len(folded) == len(original)
        haystack = folded if case_insensitive else original

        pieces = []
        cursor = 0
        for span in sorted(error_spans, key=lambda item: item["position"]):
            token = span["token"]
            if not token:
                continue
            needle = token.lower() if case_insensitive else token
            start = haystack.find(needle, cursor)
            if start == -1:
                continue
            end = start + len(needle)

            error_type = span["type"]
            color = self._HIGHLIGHT_COLORS.get(error_type, self._DEFAULT_HIGHLIGHT)

            pieces.append(escape(original[cursor:start]))
            pieces.append(
                f"<span style='background-color: {color}; padding: 1px 3px; border-radius: 3px; margin: 0 1px;' "
                f"title='{escape(str(error_type), quote=True)}'>{escape(original[start:end])}</span>"
            )
            cursor = end

        pieces.append(escape(original[cursor:]))
        return "".join(pieces)

    def generate_html_analysis(self, original, corrected, error_spans):
        """Render the analysis report as an HTML fragment.

        Args:
            original: The text as submitted by the student.
            corrected: The model's corrected version.
            error_spans: Output of ``_get_error_spans``.

        Returns:
            An HTML string. Both texts are escaped before insertion because
            they originate from user input and model output.
        """
        from html import escape

        highlighted_original = self._highlight_errors(original, error_spans)
        safe_corrected = escape(corrected)

        report = f"""
        <div style='font-family: Arial, sans-serif; line-height: 1.6; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;'>
            <h3 style='color: #333; margin-top: 0;'>Grammar Analysis Results</h3>

            <div style='margin: 15px 0;'>
                <h4 style='color: #555;'>Original Text with Error Highlighting:</h4>
                <div style='padding: 10px; background-color: #fff; border: 1px solid #ddd; border-radius: 4px;'>{highlighted_original}</div>
            </div>

            <div style='margin: 15px 0;'>
                <h4 style='color: #28a745;'>Corrected Text:</h4>
                <p style='padding: 10px; background-color: #d4edda; border: 1px solid #c3e6cb; border-radius: 4px;'>{safe_corrected}</p>
            </div>

            <div style='margin: 15px 0;'>
                <h4 style='color: #333;'>Error Summary:</h4>
                <p style='color: #666;'>Found {len(error_spans)} potential issues</p>

                <div style='margin-top: 10px;'>
                    <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #ffebee; border-radius: 12px; font-size: 12px;'>Grammar</span>
                    <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #e8f5e8; border-radius: 12px; font-size: 12px;'>Missing</span>
                    <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #fff3e0; border-radius: 12px; font-size: 12px;'>Unnecessary</span>
                    <span style='display: inline-block; margin: 2px 5px; padding: 2px 8px; background-color: #e3f2fd; border-radius: 12px; font-size: 12px;'>Usage</span>
                </div>
            </div>
        </div>
        """
        return report
320
+
321
  # Initialize SQLite database for storing submissions and exercises
322
  def init_database():
323
  conn = sqlite3.connect('language_app.db')
 
378
  conn.commit()
379
  conn.close()
380
 
381
+ # Initialize database and components
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  init_database()
383
+ print("Initializing enhanced grammar checker...")
384
+ grammar_checker = HuggingFaceT5GEDInference()
385
+ print("Grammar checker initialized successfully!")
386
 
387
  # Gradio Interface Functions
388
  def analyze_student_writing(text, student_name, task_title="General Writing Task"):
 
393
  if not student_name.strip():
394
  return "Please enter your name.", ""
395
 
396
+ # Analyze text with enhanced model
397
  corrected_text, html_analysis = grammar_checker.analyze_text(text)
398
 
399
  # Store in database
 
425
  return corrected_text, html_analysis
426
 
427
  def create_exercise_from_text(text, exercise_title="Grammar Exercise"):
428
+ """Create an exercise from text with errors using enhanced analysis"""
429
  if not text.strip():
430
  return "Please enter text to create an exercise.", ""
431
 
 
462
  exercise_html = f"""
463
  <div style='font-family: Arial, sans-serif; padding: 20px; border: 1px solid #ddd; border-radius: 8px;'>
464
  <h3>{exercise_title}</h3>
465
+ <p><strong>Exercise ID: {exercise_id}</strong></p>
466
  <p><strong>Instructions:</strong> Correct the grammatical errors in the following sentences:</p>
467
  <ol>
468
  """
 
472
 
473
  exercise_html += "</ol></div>"
474
 
475
+ return f"Exercise created with {len(exercise_sentences)} sentences! Exercise ID: {exercise_id}", exercise_html
476
 
477
  def attempt_exercise(exercise_id, student_responses, student_name):
478
+ """Submit exercise attempt and get score using enhanced analysis"""
479
  if not student_name.strip():
480
  return "Please enter your name.", ""
481
 
 
502
  if len(responses) != len(exercise_sentences):
503
  return f"Please provide exactly {len(exercise_sentences)} responses (one per line).", ""
504
 
505
+ # Calculate score using enhanced analysis
506
  correct_count = 0
507
  feedback = []
508
 
509
  for i, (sentence_data, response) in enumerate(zip(exercise_sentences, responses), 1):
510
  correct_answer = sentence_data['corrected']
511
+
512
+ # Use the model to check if the response is correct
513
+ response_corrected, _ = grammar_checker.analyze_text(response)
514
+ is_correct = response_corrected.strip() == response.strip() # No further corrections needed
515
 
516
  if is_correct:
517
  correct_count += 1
518
+ feedback.append(f"βœ… Sentence {i}: Excellent! No errors detected.")
519
  else:
520
+ feedback.append(f"❌ Sentence {i}: Your answer: '{response}' | Suggested improvement: '{response_corrected}' | Expected: '{correct_answer}'")
521
 
522
  score = (correct_count / len(exercise_sentences)) * 100
523
 
 
595
  return progress_html
596
 
597
  # Create Gradio Interface
598
+ with gr.Blocks(title="Language Learning App - Enhanced Grammar Checker", theme=gr.themes.Soft()) as app:
599
  gr.Markdown("# πŸŽ“ Language Learning Application")
600
  gr.Markdown("### AI-Powered Grammar Checking and Exercise Generation")
601
+ gr.Markdown("*Now featuring advanced T5-GED neural network with enhanced error detection*")
602
 
603
  with gr.Tabs():
604
  # Student Writing Analysis Tab
 
701
  3. **Exercise Practice**: Students can practice with generated exercises and get scored feedback
702
  4. **Progress Tracking**: View student progress across submissions and exercises
703
 
704
+ *Powered by advanced T5-GED neural networks for enhanced grammar error detection and correction*
705
  """)
706
 
707
  if __name__ == "__main__":