File size: 2,017 Bytes
b334778 19f79d5 b334778 9fc880b 19f79d5 043440f fdeaa3e 4146933 9908a7a 19f79d5 fdeaa3e 19f79d5 9fc880b 9908a7a b334778 9908a7a 19f79d5 9908a7a 9fc880b b334778 9908a7a b334778 9908a7a b334778 9908a7a b334778 9908a7a b334778 9fc880b b334778 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import gradio as gr
import torch
import asyncio
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub checkpoint used for grammar correction.
model_name = "hassaanik/grammar-correction-model"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the matching tokenizer and seq2seq model; the model is moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

# On GPU, cast the weights to FP16 to reduce memory use and speed up inference.
if device == "cuda":
    model.half()
async def correct_grammar_async(texts, *, max_length=512, num_beams=5):
    """Correct the grammar of a batch of texts with the seq2seq model.

    The blocking ``model.generate`` call is offloaded to a worker thread via
    ``asyncio.to_thread`` so the awaiting event loop is not blocked while the
    model runs.

    Args:
        texts: Input strings to correct; they are tokenized as one padded
            batch.
        max_length: Token limit used both to truncate the inputs and as the
            maximum generated length (default 512, as before).
        num_beams: Beam-search width passed to ``model.generate``
            (default 5, as before).

    Returns:
        A list of corrected strings, one per input, in input order. An empty
        input yields an empty list without touching the tokenizer or model.
    """
    if not texts:
        return []
    # Tokenize as a padded batch and move every tensor (input_ids and
    # attention_mask) to the model's device.
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)
    # Pass the attention mask explicitly so pad positions of shorter
    # sequences in the batch are ignored, rather than relying on generate()
    # to infer the mask from pad_token_id.
    outputs = await asyncio.to_thread(
        model.generate,
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    # batch_decode decodes each generated sequence in turn (sequentially).
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
def correct_grammar_interface(text):
    """Correct the grammar of a single text entered in the Gradio UI.

    This is a plain synchronous function, so the async batch helper is driven
    to completion with ``asyncio.run``, which creates a fresh event loop for
    each call. NOTE: ``asyncio.run`` raises ``RuntimeError`` if called from a
    thread that already has a running event loop.

    Args:
        text: Raw text from the input textbox; may be empty or ``None``.

    Returns:
        The corrected text. Empty or whitespace-only input is returned
        unchanged (``None`` becomes ``""``) without invoking the model.
    """
    # Nothing to correct: skip the tokenizer/beam-search pass entirely.
    if not text or not text.strip():
        return text or ""
    # Send a one-element batch and unwrap the single corrected result.
    return asyncio.run(correct_grammar_async([text]))[0]
# Gradio UI: an input box and an output box side by side, plus a button that
# sends the single entered text through the correction handler. Each click
# processes one string.
with gr.Blocks() as grammar_app:
    gr.Markdown("<h1>Fast Async Grammar Correction</h1>")

    with gr.Row():
        input_box = gr.Textbox(
            label="Input Text",
            placeholder="Enter text to be corrected",
            lines=4,
        )
        output_box = gr.Textbox(
            label="Corrected Text",
            placeholder="Corrected text will appear here",
            lines=4,
        )

    submit_button = gr.Button("Correct Grammar")

    # Wire the button: read input_box, run the handler, write to output_box.
    submit_button.click(
        fn=correct_grammar_interface,
        inputs=input_box,
        outputs=output_box,
    )

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    grammar_app.launch()
|