File size: 1,763 Bytes
c2f522f
b5d0fef
19f79d5
ec4bfe0
9fc880b
c2f522f
 
 
b5d0fef
c2f522f
b5d0fef
c2f522f
b5d0fef
c2f522f
 
 
 
 
 
625eebf
b5d0fef
c2f522f
ec4bfe0
b5d0fef
 
c2f522f
b5d0fef
c2f522f
b5d0fef
 
c2f522f
b5d0fef
c2f522f
b5d0fef
 
c2f522f
 
b5d0fef
 
 
c2f522f
b5d0fef
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Imports
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Tokenizer and seq2seq model are both pulled from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(
    "prithivida/grammar_error_correcter_v1"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "prithivida/grammar_error_correcter_v1"
)

# Select CUDA when torch reports it as available, otherwise fall back to the
# CPU, then move the model onto the chosen device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Grammar correction function
def correct_grammar(text):
    """Return a grammar-corrected version of ``text``.

    The text is tokenized as a single-item batch (truncated to at most 1024
    tokens), moved to ``device``, and passed to ``model.generate`` using beam
    search with 5 beams. The first returned sequence is decoded with special
    tokens removed.

    Args:
        text: The passage to correct. ``None``, an empty string, or a
            whitespace-only string is treated as "nothing to correct".

    Returns:
        The corrected text, or ``""`` when there is nothing to correct.
    """
    # Blank input: skip the tokenizer/model round-trip instead of letting the
    # model generate output for an empty prompt.
    if text is None or not text.strip():
        return ""

    # NOTE(review): the upstream Gramformer wrapper for this checkpoint appears
    # to prepend a "gec: " task prefix to its inputs; this code does not.
    # Confirm whether the prefix should be added here.
    inputs = tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    ).to(device)

    # Beam search; max_length bounds the length of the generated sequence.
    outputs = model.generate(
        **inputs,
        max_length=2024,
        num_beams=5,
        early_stopping=True,
    )

    # Only the first generated sequence is used.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Gradio interface function
def correct_grammar_interface(text):
    """Callback used by the UI: forward ``text`` to ``correct_grammar``.

    Returns whatever ``correct_grammar`` returns, unchanged.
    """
    return correct_grammar(text)

# Gradio app interface
with gr.Blocks() as grammar_app:
    # Page heading.
    gr.Markdown("<h1>Grammar Correction App (up to 300 words)</h1>")

    # The input and output text areas are created inside the same row.
    with gr.Row():
        input_box = gr.Textbox(
            label="Input Text",
            placeholder="Enter text (up to 300 words)",
            lines=10,
        )
        output_box = gr.Textbox(
            label="Corrected Text",
            placeholder="Corrected text will appear here",
            lines=10,
        )

    submit_button = gr.Button("Correct Grammar")

    # A click passes the input box's value to correct_grammar_interface and
    # writes the returned string into the output box.
    submit_button.click(
        fn=correct_grammar_interface,
        inputs=input_box,
        outputs=output_box,
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    grammar_app.launch()