# Wootang01 — "Update app.py" (commit 019744c, 1.38 kB)
import streamlit as st

st.title("Grammar Corrector")
st.write("Paste or type text, submit and the machine will attempt to correct your text's grammar.")
# Example sentence containing a grammar error, shown as the initial input.
default_text = "This should working"
# Newer Streamlit releases reject st.text_area heights below 68 px
# (height=40 raises StreamlitAPIException), so use a height that is valid
# across versions.
sent = st.text_area("Text", default_text, height=100)
# All numeric arguments are ints, so the widget returns an int that can be
# used directly as a count of correction candidates.
num_correct_options = st.number_input('Number of Correction Options', min_value=1, max_value=3, value=1, step=1)
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

# Pretrained checkpoint identifier passed to `from_pretrained`.
MODEL_NAME = 'deep-learning-analytics/GrammarCorrector'

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Streamlit re-executes this whole script on every widget interaction, so
# without caching the tokenizer and model weights would be reloaded each time.
# `st.cache_resource` is the current API; releases that predate it only offer
# `st.cache`, where `allow_output_mutation=True` skips hashing the model.
_cache_resource = getattr(st, 'cache_resource', None) or st.cache(allow_output_mutation=True)


@_cache_resource
def _load_grammar_model(model_name, device):
    """Load the tokenizer and the model (moved to ``device``) once and reuse them."""
    loaded_tokenizer = T5Tokenizer.from_pretrained(model_name)
    loaded_model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_grammar_model(MODEL_NAME, torch_device)
def correct_grammar(input_text, num_correct_options=num_correct_options):
    """Generate grammar-corrected candidates for ``input_text``.

    Args:
        input_text: The text to correct. It is truncated/padded to 64 tokens.
        num_correct_options: How many alternative corrections to generate.
            The default is the value of the "Number of Correction Options"
            widget at the time this function is defined.

    Returns:
        The output of ``model.generate``: one generated token-id sequence
        per requested correction, to be turned back into text with the
        tokenizer's ``decode``/``batch_decode``.
    """
    # Beam search can return at most as many sequences as it keeps beams,
    # so widen the beam when more options are requested than the base of 2.
    num_beams = max(2, num_correct_options)
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='max_length',
        max_length=64,
        return_tensors='pt',
    ).to(torch_device)
    # `num_return_sequences` is the generate() argument that sets how many
    # candidates come back. No `temperature` is passed: it only affects
    # sampling (do_sample=True), not beam search.
    results = model.generate(
        **batch,
        max_length=64,
        num_beams=num_beams,
        num_return_sequences=num_correct_options,
    )
    return results
results = correct_grammar(sent, num_correct_options)
# Turn each generated id sequence back into text. Special tokens (padding,
# end-of-sequence) are skipped so only the corrected sentence remains.
generated_options = tokenizer.batch_decode(
    results,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
st.write(generated_options)