Spaces:
Sleeping
Sleeping
File size: 1,671 Bytes
9f814ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import streamlit as st
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import time
@st.cache_resource
def load_model():
    """Load the mT5 Romanian diacritics checkpoint.

    Weights are fetched into (or reused from) the local ``cache/`` directory.
    The ``st.cache_resource`` decorator makes repeated calls return the same
    model object instead of loading it again.
    """
    return MT5ForConditionalGeneration.from_pretrained(
        'iliemihai/mt5-base-romanian-diacritics',
        cache_dir='cache/',
    )
@st.cache_resource
def load_tokenizer():
    """Load the T5 tokenizer that matches the diacritics checkpoint.

    ``legacy=False`` selects the non-legacy T5 tokenization behaviour, and
    files are stored in (or reused from) ``cache/``. The tokenizer is cached
    with ``st.cache_resource`` so it is constructed only once.
    """
    return T5Tokenizer.from_pretrained(
        'iliemihai/mt5-base-romanian-diacritics',
        legacy=False,
        cache_dir='cache/',
    )
def initialize_app():
    """Configure the Streamlit page and draw the title and caption.

    ``main()`` calls this before any other Streamlit command. Historically,
    ``st.set_page_config`` had to be the first Streamlit call on a page.
    """
    # NOTE(review): "[email protected]" looks like an email-obfuscation
    # placeholder rather than a real address -- confirm the intended contact.
    about_menu = {"About": "### Contact\n ✉️[email protected]"}
    st.set_page_config(
        page_title="Dia-critic",
        page_icon="public/favicon.ico",
        menu_items=about_menu,
    )
    st.title("🖋️Dia-critic")
    st.caption("Made with :heart: by NEBO Technologies")
def generate_text(text, max_length=256):
    """Run the cached mT5 diacritics model on ``text`` and return its output.

    Args:
        text: Romanian text to correct.
        max_length: Token budget. The tokenized input is truncated to this
            length, and generation is capped at the same length. Defaults
            to 256, matching the previous hard-coded input limit.

    Returns:
        The decoded model output with special tokens removed.
    """
    model = load_model()
    tokenizer = load_tokenizer()
    inputs = tokenizer(text, max_length=max_length, truncation=True, return_tensors="pt")
    # Without an explicit limit, ``generate`` uses the checkpoint's generation
    # config, whose library default max_length is only 20 tokens. That would
    # cut the corrected text short, so the output gets the same budget as
    # the input.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def main():
    """Render the Dia-critic UI and correct the text when the button is pressed.

    Shows the input area and its character count. On click, it either warns
    that the field is empty or runs ``generate_text`` under a spinner and
    shows the result in a bordered container.
    """
    initialize_app()
    input_text = st.text_area("Introduceți textul mai jos")
    st.write(f'{len(input_text)} caractere.')
    if not st.button("Corectează"):
        return
    # Whitespace-only input has nothing to correct; treat it like an empty
    # field instead of spending a full model generation pass on it.
    if not input_text.strip():
        st.warning("Câmpul este gol!")
        return
    with st.spinner('Sarcină în desfășurare...'):
        res = generate_text(input_text)
    with st.container(border=True):
        st.markdown(res)
# Script entry point. The stray trailing "|" (page-scrape residue) that made
# this line a syntax error has been removed.
if __name__ == "__main__":
    main()