sashdev commited on
Commit
b5d0fef
·
verified ·
1 Parent(s): 97c782c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -18
app.py CHANGED
@@ -1,20 +1,50 @@
1
- from transformers import T5Tokenizer, T5ForConditionalGeneration
2
  import torch
 
 
3
 
4
- # Load the T5 tokenizer and model
5
- model_name = "t5-small" # You can use any T5 model available
6
- tokenizer = T5Tokenizer.from_pretrained(model_name)
7
- model = T5ForConditionalGeneration.from_pretrained(model_name)
8
-
9
- # Example function to use the model
10
- def summarize(text):
11
- # Tokenize the input text
12
- inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
13
- # Generate summary
14
- outputs = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
15
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
16
- return summary
17
-
18
- # Example usage
19
- text_to_summarize = "Your input text goes here."
20
- print(summarize(text_to_summarize))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import torch
3
+ import asyncio
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ # Load model and tokenizer
7
+ model_name = "hassaanik/grammar-correction-model"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+
10
+ # Use GPU if available, otherwise fall back to CPU
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
13
+
14
+ # Use FP16 for faster inference on GPU
15
+ if torch.cuda.is_available():
16
+ model.half()
17
+
18
+ # Async grammar correction function
19
+ async def correct_grammar_async(text):
20
+ # Tokenize input and move it to the correct device (CPU/GPU)
21
+ inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)
22
+
23
+ # Asynchronous operation to run grammar correction
24
+ outputs = await asyncio.to_thread(model.generate, inputs, max_length=512, num_beams=5, early_stopping=True)
25
+
26
+ # Decode output and return corrected text
27
+ corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
+ return corrected_text
29
+
30
+ # Gradio interface function to handle input and output
31
+ def correct_grammar_interface(text):
32
+ corrected_text = asyncio.run(correct_grammar_async(text))
33
+ return corrected_text
34
+
35
+ # Create Gradio Interface
36
+ with gr.Blocks() as grammar_app:
37
+ gr.Markdown("<h1>Async Grammar Correction App</h1>")
38
+
39
+ with gr.Row():
40
+ input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
41
+ output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=4)
42
+
43
+ submit_button = gr.Button("Correct Grammar")
44
+
45
+ # When the button is clicked, run the correction process
46
+ submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
47
+
48
+ # Launch the app
49
+ if __name__ == "__main__":
50
+ grammar_app.launch()