|
import streamlit as st |
|
import transformers |
|
import tensorflow |
|
|
|
|
|
from transformers import AutoTokenizer |
|
from transformers import TFAutoModelForSeq2SeqLM |
|
|
|
# Hugging Face checkpoint: T5-base fine-tuned on the JFLEG grammar-correction
# dataset. The same checkpoint id is used for the tokenizer here and for the
# model loaded by load_model() below.
model_checkpoint = "Modfiededition/t5-base-fine-tuned-on-jfleg"

# Loaded eagerly at import time (fast, small); the heavyweight model load is
# cached separately via load_model().
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
|
|
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load and cache the TF seq2seq model for ``model_name``.

    ``allow_output_mutation=True`` is required here: a TF/Keras model is not
    hashable and is mutated in place by subsequent ``generate`` calls, so a
    bare ``@st.cache`` would fail (or re-load the model) on every Streamlit
    rerun while trying to hash the cached return value.

    NOTE(review): on Streamlit >= 1.18 the preferred decorator is
    ``@st.cache_resource`` — confirm the deployed Streamlit version before
    switching.

    Args:
        model_name: Hugging Face checkpoint id or local path.

    Returns:
        A ``TFAutoModelForSeq2SeqLM`` instance ready for ``generate``.
    """
    model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model
|
|
|
# --- Streamlit page: grammar-correction front end ------------------------

model = load_model(model_checkpoint)

default_value = "Write your text here!"

# The original title string ended in mojibake ("π¦") left over from a
# mis-encoded emoji; it is dropped here rather than guessing the intended
# character.
st.title("Writing Assistant for you")

sent = st.text_area("Text", default_value, height=275)

# The checkpoint was fine-tuned with this task prefix; it must be kept
# byte-identical for the model to treat the input as a correction task.
inputs = tokenizer("Grammar: " + sent, return_tensors="tf")

# Decode with skip_special_tokens=True instead of the fragile [1:-1] slice,
# which assumed exactly one leading decoder-start token and one trailing EOS
# token and would silently drop real text whenever that assumption broke.
output_ids = model.generate(inputs["input_ids"]).numpy()[0]
generated_sequences = tokenizer.decode(output_ids, skip_special_tokens=True)

st.write(generated_sequences)
|
|