File size: 2,017 Bytes
b334778
19f79d5
b334778
 
9fc880b
19f79d5
043440f
fdeaa3e
4146933
9908a7a
19f79d5
 
fdeaa3e
19f79d5
 
 
9fc880b
9908a7a
 
 
 
b334778
9908a7a
 
19f79d5
9908a7a
 
 
9fc880b
b334778
 
9908a7a
b334778
 
9908a7a
b334778
9908a7a
b334778
 
 
 
 
 
 
9908a7a
b334778
 
 
9fc880b
b334778
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
import torch
import asyncio
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the tokenizer and the seq2seq grammar-correction model by name.
model_name = "hassaanik/grammar-correction-model"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.to(device)

# On the GPU, cast the weights to half precision (FP16) for faster inference.
if device == "cuda":
    model.half()

# Async grammar correction over a batch of texts.
async def correct_grammar_async(texts, *, max_length=512, num_beams=5):
    """Correct the grammar of a batch of texts without blocking the event loop.

    The (blocking) ``model.generate`` call is executed in a worker thread via
    ``asyncio.to_thread`` so the calling event loop stays responsive.

    Args:
        texts: A list (or other sequence) of input strings. A single ``str`` is
            treated as a one-element batch.
        max_length: Token limit used both to truncate the inputs and to cap
            the length of the generated outputs. Defaults to 512.
        num_beams: Beam width passed to ``model.generate``. Defaults to 5.

    Returns:
        A list of corrected strings, one per input, in input order. An empty
        batch returns ``[]`` without touching the model.
    """
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        # Tokenizing/generating an empty batch would fail; nothing to do.
        return []

    # Tokenize the whole batch (padded to a common length) and move the
    # tensors to the same device as the model.
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)

    # Forward the attention mask together with the ids: with padding=True the
    # shorter texts in a batch contain pad tokens, and without the mask the
    # model would attend to them and produce corrupted corrections.
    outputs = await asyncio.to_thread(
        model.generate,
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Gradio callback: correct a single text from the UI.
def correct_grammar_interface(text):
    """Return the grammar-corrected version of ``text``.

    Blank input (``None``, ``""`` or whitespace only) is returned unchanged,
    with ``None`` mapped to ``""``, and the model is not invoked for it.
    Otherwise the text is sent through ``correct_grammar_async`` as a
    one-element batch and the single result is returned.
    """
    if text is None or not text.strip():
        # Skip beam-search generation entirely when there is nothing to fix.
        return text or ""
    # This is a synchronous function, so it drives the coroutine on a fresh
    # event loop; asyncio.run raises RuntimeError if it is called from a
    # thread whose event loop is already running.
    return asyncio.run(correct_grammar_async([text]))[0]

# Gradio Blocks UI: a title, an input/output textbox pair and a submit button.
# Components are created inside the Blocks context in the order below, which
# is the order they are laid out, so keep statement order intact.
with gr.Blocks() as grammar_app:
    gr.Markdown("<h1>Fast Async Grammar Correction</h1>")
    
    # Input and output textboxes placed side by side in a single row.
    with gr.Row():
        input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
        output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=4)

    submit_button = gr.Button("Correct Grammar")
    
    # On click, pass the input box's text to correct_grammar_interface (a
    # synchronous function handling one text at a time; no batch input here)
    # and display its return value in the output box.
    submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)

# Start the web server only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    grammar_app.launch()