sashdev commited on
Commit
b334778
·
verified ·
1 Parent(s): 19f79d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -11
app.py CHANGED
@@ -1,5 +1,7 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
 
 
3
 
4
  # Load model and tokenizer
5
  model_name = "hassaanik/grammar-correction-model"
@@ -13,21 +15,36 @@ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
13
  if torch.cuda.is_available():
14
  model.half()
15
 
16
- # Function to correct grammar for a single text input
17
- def correct_grammar(text):
18
  # Tokenize input and move it to the correct device (CPU/GPU)
19
  inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)
20
-
21
- # Generate corrected output with beam search
22
- outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)
23
 
24
  # Decode output and return corrected text
25
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
  return corrected_text
27
 
28
- # Example usage of the grammar correction function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if __name__ == "__main__":
30
- sample_text = "He go to the market yesturday."
31
- corrected_text = correct_grammar(sample_text)
32
- print("Original Text:", sample_text)
33
- print("Corrected Text:", corrected_text)
 
1
+ import gradio as gr
2
  import torch
3
+ import asyncio
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
  # Load model and tokenizer
7
  model_name = "hassaanik/grammar-correction-model"
 
15
  if torch.cuda.is_available():
16
  model.half()
17
 
18
+ # Async grammar correction function
19
+ async def correct_grammar_async(text):
20
  # Tokenize input and move it to the correct device (CPU/GPU)
21
  inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)
22
+
23
+ # Asynchronous operation to run grammar correction
24
+ outputs = await asyncio.to_thread(model.generate, inputs, max_length=512, num_beams=5, early_stopping=True)
25
 
26
  # Decode output and return corrected text
27
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
  return corrected_text
29
 
30
+ # Gradio interface function to handle input and output
31
+ def correct_grammar_interface(text):
32
+ corrected_text = asyncio.run(correct_grammar_async(text))
33
+ return corrected_text
34
+
35
+ # Create Gradio Interface
36
+ with gr.Blocks() as grammar_app:
37
+ gr.Markdown("<h1>Async Grammar Correction App</h1>")
38
+
39
+ with gr.Row():
40
+ input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
41
+ output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=4)
42
+
43
+ submit_button = gr.Button("Correct Grammar")
44
+
45
+ # When the button is clicked, run the correction process
46
+ submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
47
+
48
+ # Launch the app
49
  if __name__ == "__main__":
50
+ grammar_app.launch()