sashdev commited on
Commit
c2f522f
·
verified ·
1 Parent(s): ac1c8c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -25
app.py CHANGED
@@ -1,48 +1,44 @@
 
1
  import gradio as gr
2
  import torch
3
- import asyncio
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
- # Load model and tokenizer
7
- model_name = "hassaanik/grammar-correction-model"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
 
10
- # Use GPU if available, otherwise fall back to CPU
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
13
 
14
- # Use FP16 for faster inference on GPU
15
- if torch.cuda.is_available():
16
- model.half()
17
-
18
- # Async grammar correction function
19
- async def correct_grammar_async(text):
20
- # Tokenize input and move it to the correct device (CPU/GPU)
21
- inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)
22
-
23
- # Asynchronous operation to run grammar correction
24
- outputs = await asyncio.to_thread(model.generate, inputs, max_length=512, num_beams=5, early_stopping=True)
25
 
26
- # Decode output and return corrected text
27
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
  return corrected_text
29
 
30
- # Gradio interface function to handle input and output
31
  def correct_grammar_interface(text):
32
- corrected_text = asyncio.run(correct_grammar_async(text))
33
  return corrected_text
34
 
35
- # Create Gradio Interface
36
  with gr.Blocks() as grammar_app:
37
- gr.Markdown("<h1>Async Grammar Correction App</h1>")
38
 
39
  with gr.Row():
40
- input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
41
- output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=4)
42
 
43
  submit_button = gr.Button("Correct Grammar")
44
 
45
- # When the button is clicked, run the correction process
46
  submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
47
 
48
  # Launch the app
 
1
+ # Imports
2
  import gradio as gr
3
  import torch
 
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ # Load the tokenizer and model
7
+ tokenizer = AutoTokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
8
+ model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/grammar_error_correcter_v1")
9
 
10
+ # Use GPU if available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model.to(device)
13
 
14
+ # Grammar correction function
15
+ def correct_grammar(text):
16
+ # Tokenize input text with an increased max_length for handling larger input
17
+ inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
18
+
19
+ # Generate corrected text with increased max_length and num_beams
20
+ outputs = model.generate(**inputs, max_length=1024, num_beams=5, early_stopping=True)
 
 
 
 
21
 
22
+ # Decode the output and return the corrected text
23
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
  return corrected_text
25
 
26
+ # Gradio interface function
27
  def correct_grammar_interface(text):
28
+ corrected_text = correct_grammar(text)
29
  return corrected_text
30
 
31
+ # Gradio app interface
32
  with gr.Blocks() as grammar_app:
33
+ gr.Markdown("<h1>Grammar Correction App (up to 300 words)</h1>")
34
 
35
  with gr.Row():
36
+ input_box = gr.Textbox(label="Input Text", placeholder="Enter text (up to 300 words)", lines=10)
37
+ output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=10)
38
 
39
  submit_button = gr.Button("Correct Grammar")
40
 
41
+ # Bind the button click to the grammar correction function
42
  submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
43
 
44
  # Launch the app