Mongolian-GPT2 / app.py
bayartsogt's picture
wordcloud + history addition
8028fa1
raw
history blame
2.71 kB
import re
import time
import streamlit as st
import pandas as pd
from wordcloud import WordCloud
from googletrans import Translator
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline
from enums import MODEL_NAME, MESSAGES, DESCRIPTION
def iso2lang(iso):
    """Return the display label stored under ``MESSAGES["iso"]`` for *iso*.

    Passed as ``format_func`` to the language radio button, so it receives
    one of the option codes (``'mn'`` or ``'en'``).

    Raises:
        KeyError: if *iso* has no entry in ``MESSAGES["iso"]``.
    """
    labels_by_code = MESSAGES["iso"]
    return labels_by_code[iso]
def create_df_from_io(input, output):
    """Build a single-row history frame for one prompt/response pair.

    Args:
        input: the prompt that was sent to the generator.
        output: the text the generator produced for it.

    Returns:
        A one-row ``DataFrame`` with columns ``input``, ``output`` and
        ``timestamp``, where ``timestamp`` is ``time.time()`` at call time.
    """
    column_order = ["input", "output", "timestamp"]
    record = {
        "input": [input],
        "output": [output],
        "timestamp": [time.time()],
    }
    return pd.DataFrame(record, columns=column_order)
def simple_clean(text):
    """Lower-case *text* and blank out a fixed set of punctuation marks.

    Each of ``! @ # $ . , ? -`` and the newline character is replaced by a
    single space; nothing is collapsed or stripped, so the result has the
    same length as the input.

    The hyphen is placed last in the character class so it is matched
    literally. Written as ``\\n-?`` it formed a range from newline to ``?``,
    which also blanked digits and most other ASCII punctuation.

    Args:
        text: the string to normalise.

    Returns:
        The lower-cased string with the listed characters replaced by spaces.
    """
    return re.sub(r'[!@#$.,\n?-]', ' ', text.lower())
def load_tokenizer():
    """Load and return the pretrained tokenizer identified by ``MODEL_NAME``."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return tokenizer
@st.cache(allow_output_mutation=True)
def load_model():
    """Load and return the pretrained language model identified by ``MODEL_NAME``.

    Wrapped in ``st.cache`` so the result is reused instead of being
    reloaded on every call.
    """
    # NOTE(review): allow_output_mutation=True presumably stops Streamlit from
    # hashing the returned model to detect mutation - confirm against the
    # Streamlit version in use.
    model = AutoModelWithLMHead.from_pretrained(MODEL_NAME)
    return model
def load_pipe():
    """Return a ``text-generation`` pipeline built from the model and tokenizer.

    The model comes from ``load_model()`` and the tokenizer from
    ``load_tokenizer()``, requested in that order.
    """
    components = {
        "model": load_model(),
        "tokenizer": load_tokenizer(),
    }
    return pipeline("text-generation", **components)
# ---------------------------------------------------------------------- #
# Page header and language choice; `lang` selects every localized message.
st.write(DESCRIPTION)
lang = st.radio('Хэл / Language', ('mn', 'en'), format_func=iso2lang)
translator = Translator()

# Create the prompt/response history once and keep it in session state.
if "df" not in st.session_state:
    history_columns = ["input", "output", "timestamp"]
    st.session_state.df = pd.DataFrame(columns=history_columns)

# Build the generation pipeline behind a spinner, then report success.
loading_message = MESSAGES["loading_text"][lang]
with st.spinner(loading_message):
    pipe = load_pipe()
st.success(MESSAGES["success_model_load"][lang])

# Prompt box, pre-filled with the localized default prompt.
input_label = MESSAGES["input_description"][lang]
input_default = MESSAGES["input_default"][lang]
text = st.text_input(input_label, input_default)
# Generate a continuation for the prompt and display it.
with st.spinner(MESSAGES["loading_text"][lang]):
    if lang == "mn":
        result = pipe(text)[0]['generated_text']
        st.write(result)
    elif lang == "en":
        # Translate the English prompt to Mongolian for the model, then
        # translate the generated text back to English for display.
        # `text` is deliberately rebound: the history stores the Mongolian
        # prompt that was actually fed to the model.
        text = translator.translate(text, src='en', dest='mn').text
        result = pipe(text)[0]['generated_text']
        result_en = translator.translate(result, src='mn', dest='en').text
        st.write(f"*Translated:* {result_en}")
        st.write(f"> *Original:* {result}")
        st.warning('Translation is done by [`googletrans`](https://github.com/ssut/py-googletrans). Please check out the usage. https://github.com/ssut/py-googletrans#note-on-library-usage')

# Record this exchange in the session history.
# DataFrame.append was removed in pandas 2.0, so use pd.concat, which works on
# both old and new versions. The empty initial frame is replaced outright
# rather than concatenated, which avoids pandas' deprecation warning about
# concatenating empty frames. ignore_index gives the rows a running 0..n-1
# index instead of every row being labelled 0.
new_row = create_df_from_io(text, result)
history = st.session_state.df
if history.empty:
    st.session_state.df = new_row
else:
    st.session_state.df = pd.concat([history, new_row], ignore_index=True)
# Word cloud over every generated output stored in the session history.
st.write("### WordCloud based on previous outputs")
with st.spinner(MESSAGES["loading_text"][lang]):
    # Join with a space so the last word of one output is not fused with the
    # first word of the next. A dedicated loop name keeps the module-level
    # `text` (the prompt) from being overwritten.
    wordcloud_input = " ".join(
        simple_clean(past_output)
        for past_output in st.session_state.df.output.tolist()
    )
    wordcloud = WordCloud(
        width=800,
        height=800,
        background_color='white',
        min_font_size=10,
    ).generate(wordcloud_input)
    st.image(wordcloud.to_array())
# History table of the session's prompts and outputs, newest first.
st.write("### Түүх / History")
with st.spinner(MESSAGES["loading_text"][lang]):
    newest_first = st.session_state.df.sort_values(
        by="timestamp", ascending=False)
    st.table(newest_first)