sashdev commited on
Commit
ec4bfe0
·
verified ·
1 Parent(s): c4b2fd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -1,11 +1,11 @@
 
1
  import gradio as gr
2
  import torch
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration
4
 
5
- # Load T5 model and tokenizer
6
- model_name = "t5-base" # Use a smaller model for faster inference
7
- tokenizer = T5Tokenizer.from_pretrained(model_name)
8
- model = T5ForConditionalGeneration.from_pretrained(model_name)
9
 
10
  # Use GPU if available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -13,13 +13,14 @@ model.to(device)
13
 
14
  # Grammar correction function
15
  def correct_grammar(text):
16
- input_text = f"correct: {text}"
17
- input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
18
 
19
  # Generate corrected text
20
- output_ids = model.generate(input_ids, max_length=512, num_beams=5, early_stopping=True)
21
- corrected_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
22
 
 
 
23
  return corrected_text
24
 
25
  # Gradio interface function
@@ -27,9 +28,9 @@ def correct_grammar_interface(text):
27
  corrected_text = correct_grammar(text)
28
  return corrected_text
29
 
30
- # Gradio interface
31
  with gr.Blocks() as grammar_app:
32
- gr.Markdown("<h1>Fast Grammar Correction with T5</h1>")
33
 
34
  with gr.Row():
35
  input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
@@ -37,7 +38,7 @@ with gr.Blocks() as grammar_app:
37
 
38
  submit_button = gr.Button("Correct Grammar")
39
 
40
- # Button click event
41
  submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
42
 
43
  # Launch the app
 
1
+ # Imports
2
  import gradio as gr
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ # Load the tokenizer and model
7
+ tokenizer = AutoTokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
8
+ model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/grammar_error_correcter_v1")
 
9
 
10
  # Use GPU if available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
13
 
14
  # Grammar correction function
15
  def correct_grammar(text):
16
+ # Tokenize input text
17
+ inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True).to(device)
18
 
19
  # Generate corrected text
20
+ outputs = model.generate(**inputs, max_length=512, num_beams=5, early_stopping=True)
 
21
 
22
+ # Decode the output and return the corrected text
23
+ corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
  return corrected_text
25
 
26
  # Gradio interface function
 
28
  corrected_text = correct_grammar(text)
29
  return corrected_text
30
 
31
+ # Gradio app interface
32
  with gr.Blocks() as grammar_app:
33
+ gr.Markdown("<h1>Grammar Correction App</h1>")
34
 
35
  with gr.Row():
36
  input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
 
38
 
39
  submit_button = gr.Button("Correct Grammar")
40
 
41
+ # Bind the button click to the grammar correction function
42
  submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
43
 
44
  # Launch the app