# Imports
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
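
# Note: from_pretrained downloads the model weights from the Hugging Face Hub on
# the first run and reuses the local cache afterwards.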
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/grammar_error_correcter_v1")
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Grammar correction function
def correct_grammar(text):
    # Tokenize input text with an increased max_length for handling larger input
    inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)

    # Generate corrected text with increased max_length and num_beams
    outputs = model.generate(**inputs, max_length=2024, num_beams=5, early_stopping=True)

    # Decode the output and return the corrected text
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected_text
# Gradio interface function
def correct_grammar_interface(text):
    corrected_text = correct_grammar(text)
    return corrected_text
# Gradio app interface
with gr.Blocks() as grammar_app:
    gr.Markdown("<h1>Grammar Correction App (up to 300 words)</h1>")

    with gr.Row():
        input_box = gr.Textbox(label="Input Text", placeholder="Enter text (up to 300 words)", lines=10)
        output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=10)

    submit_button = gr.Button("Correct Grammar")

    # Bind the button click to the grammar correction function
    submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
# Launch the app
if __name__ == "__main__":
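    # launch() starts a local Gradio server (default http://127.0.0.1:7860);
    # pass share=True for a temporary public link.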
    grammar_app.launch()