import streamlit as st

mirror_url = "https://news-generator.ai-research.id/"

print("Streamlit Version: ", st.__version__)

# This deployment only runs on Streamlit 1.9.0. On any other version, point the
# user at the mirror and halt the script before the remaining imports execute.
if st.__version__ != "1.9.0":
    st.warning(f"We move to: {mirror_url}")
    st.stop()

import SessionState
from mtranslate import translate
from prompts import PROMPT_LIST
import random
import time
import psutil
import os
import requests

if "MIRROR_URL" in os.environ:
    mirror_url = os.environ["MIRROR_URL"]

hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
news_api_auth_token = os.getenv("NEWS_API_AUTH_TOKEN", False)
# Endpoint of the remote generator; override with NEWS_API_URL
# (e.g. http://localhost:8000/api/text_generator/v1) for local development.
NEWS_API_URL = os.getenv("NEWS_API_URL", "https://news-api.uncool.ai/api/text_generator/v1")

MODELS = {
    "Indonesian Newspaper - Indonesian GPT-2 Medium": {
        "group": "Indonesian Newspaper",
        "name": "ai-research-id/gpt2-medium-newspaper",
        "description": "Newspaper Generator using Indonesian GPT-2 Medium.",
        "text_generator": None,
        "tokenizer": None
    },
}

# NOTE: this markdown block currently renders no visible content.
st.sidebar.markdown("""

""", unsafe_allow_html=True)

st.sidebar.markdown(f""" ___

This is a collection of applications that generates sentences using Indonesian GPT-2 models!

Created by Indonesian NLP team @2021
GitHub | Project Report
A mirror of the application is available [here]({mirror_url})

""", unsafe_allow_html=True)

st.sidebar.markdown(""" ___ """, unsafe_allow_html=True)

model_type = st.sidebar.selectbox('Model', (MODELS.keys()))

# Seconds allowed to establish the connection to the generation API.
_API_CONNECT_TIMEOUT = 10
# Extra seconds on top of the generation budget (max_time) before giving up
# on the response.
_API_READ_TIMEOUT_MARGIN = 30


# st.cache is intentionally not used for this function because of an issue
# with newer Streamlit versions, so every call issues a fresh API request.
def process(title: str, keywords: str, text: str,
            max_length: int = 250, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0,
            penalty_alpha=0.6):
    """Request a generated news article from the remote text-generation API.

    All arguments are sent as form fields of a POST request to NEWS_API_URL.
    When NEWS_API_AUTH_TOKEN is configured, it is sent as a bearer token.

    Args:
        title: Article title used as conditioning input.
        keywords: Keywords used as conditioning input.
        text: Prompt text the article continues from.
        max_length, do_sample, top_k, top_p, temperature, max_time, seed,
        repetition_penalty, penalty_alpha: Decoding options forwarded to the API.
            max_time also bounds how long this function waits for the reply
            (plus a fixed margin).

    Returns:
        The decoded JSON body when the API answers HTTP 200 with valid JSON.
        Otherwise a string starting with "Error: " that describes the failure:
        the response body, or the network exception text.
    """
    # Only send the header when a token is configured. Concatenating the
    # default value (False) with a str would raise TypeError.
    headers = {"Authorization": "Bearer " + news_api_auth_token} if news_api_auth_token else {}
    data = {
        "title": title,
        "keywords": keywords,
        "text": text,
        "max_length": max_length,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "max_time": max_time,
        "seed": seed,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha
    }
    try:
        # Bound the wait so a stalled API cannot hang the app indefinitely.
        r = requests.post(
            NEWS_API_URL,
            headers=headers,
            data=data,
            timeout=(_API_CONNECT_TIMEOUT, max_time + _API_READ_TIMEOUT_MARGIN),
        )
    except requests.exceptions.RequestException as err:
        return f"Error: {err}"
    if r.status_code != 200:
        return "Error: " + r.text
    try:
        return r.json()
    except ValueError:
        # A 200 response whose body is not valid JSON.
        return "Error: " + r.text


st.title("Indonesian GPT-2 Applications")
prompt_group_name = MODELS[model_type]["group"]
st.header(prompt_group_name)
description = (
    "This is a news generator using Indonesian GPT-2 Medium. We finetuned the pre-trained model with 1.4M "
    "articles of the Indonesian online newspaper dataset."
)
st.markdown(description)
model_name = f"Model name: [{MODELS[model_type]['name']}](https://huggingface.co/{MODELS[model_type]['name']})"
st.markdown(model_name)

if prompt_group_name in ["Indonesian Newspaper"]:
    # Per-user state that survives Streamlit reruns (SessionState is a project helper).
    session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
    ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]
    prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

    # Switching to a different prompt category discards the cached example,
    # so a new one is drawn below.
    if session_state.prompt is not None and prompt != session_state.prompt:
        session_state.prompt_box = None
    session_state.prompt = prompt

    # "Custom" starts from empty fields. Any other category fills the fields
    # with a random example from PROMPT_LIST, but only when none is cached yet.
    if session_state.prompt == "Custom":
        session_state.prompt_box = ""
        session_state.title = ""
        session_state.keywords = ""
    elif session_state.prompt_box is None:
        choice = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
        session_state.title = choice["title"]
        session_state.keywords = choice["keywords"]
        session_state.prompt_box = choice["text"]

    session_state.title = st.text_input("Title", session_state.title)
    session_state.keywords = st.text_input("Keywords", session_state.keywords)
    session_state.text = st.text_area("Prompt", session_state.prompt_box)

    max_length = st.sidebar.number_input(
        "Maximum length",
        value=250,
        min_value=1,
        max_value=512,
        help="The maximum length of the sequence to be generated."
    )
    decoding_methods = st.sidebar.radio(
        "Set the decoding methods:",
        key="decoding",
        options=["Beam Search", "Sampling", "Contrastive Search"],
        index=2
    )
    temperature = st.sidebar.slider(
        "Temperature",
        value=0.4,
        min_value=0.0,
        max_value=2.0
    )

    # Defaults shared by every decoding method; the branches below override them.
    top_k = 30
    top_p = 0.95
    # NOTE(review): 0.0 is what gets sent while "Automatic Repetition Penalty"
    # is ticked. Presumably the API treats it as "choose automatically"; confirm
    # server-side.
    repetition_penalty = 0.0
    penalty_alpha = None
    if decoding_methods == "Beam Search":
        # NOTE(review): no beam count is sent from here; the API presumably decides it.
        do_sample = False
    elif decoding_methods == "Sampling":
        do_sample = True
        top_k = st.sidebar.number_input(
            "Top k",
            value=top_k,
            min_value=0,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
        )
        top_p = st.sidebar.number_input(
            "Top p",
            value=top_p,
            min_value=0.0,
            max_value=1.0,
            help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
                 "are kept for generation."
        )
    else:
        # Contrastive search: deterministic decoding with a degeneration penalty.
        do_sample = False
        repetition_penalty = 1.1
        penalty_alpha = st.sidebar.number_input(
            "Penalty alpha",
            value=0.6,
            min_value=0.0,
            max_value=1.0,
            help="The penalty alpha for contrastive search."
        )
        top_k = st.sidebar.number_input(
            "Top k",
            value=4,
            min_value=1,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
        )

    seed = st.sidebar.number_input(
        "Random Seed",
        value=25,
        help="The number used to initialize a pseudorandom number generator"
    )

    if decoding_methods != "Contrastive Search":
        automatic_repetition_penalty = st.sidebar.checkbox(
            "Automatic Repetition Penalty",
            value=True
        )
        if not automatic_repetition_penalty:
            repetition_penalty = st.sidebar.slider(
                "Repetition Penalty",
                value=1.0,
                min_value=1.0,
                max_value=2.0
            )

    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            memory = psutil.virtual_memory()
            time_start = time.time()
            result = process(title=session_state.title, keywords=session_state.keywords,
                             text=session_state.text, max_length=int(max_length),
                             temperature=temperature, do_sample=do_sample, penalty_alpha=penalty_alpha,
                             top_k=int(top_k), top_p=float(top_p), seed=seed,
                             repetition_penalty=repetition_penalty)
            time_diff = time.time() - time_start

            if isinstance(result, str):
                # process() reports failures as an "Error: ..." string instead of
                # a result dict. Show it and keep the session state so the user
                # can retry with the same input.
                st.error(result)
            else:
                st.markdown(f"### {session_state.title}")

                summary = (result.get("description") or "").strip()
                if summary:
                    st.markdown(f"*{summary}*")

                # Two trailing spaces make a markdown hard line break; a single
                # space would merge the lines into one paragraph.
                st.markdown(result["generated_text"].replace("\n", "  \n"))

                caption = (result.get("caption") or "").strip()
                if caption:
                    st.markdown(f"*Photo Caption: {caption}*".replace("\n", "  \n"))

                st.markdown("**Translation**")
                try:
                    translation = translate(result["generated_text"], "en", "id")
                except OSError as err:
                    # The translation needs network access; a failure there
                    # should not hide the generated article.
                    st.warning(f"Translation failed: {err}")
                else:
                    st.write(translation.replace("\n", "  \n"))

                gib = 1024 ** 3
                info = (
                    f"*Memory: {memory.total / gib:.2f}GB, used: {memory.percent}%, "
                    f"available: {memory.available / gib:.2f}GB*  \n"
                    f"*Text generated in {time_diff:.5} seconds*"
                )
                st.write(info)

                # Clear the cached prompt so the next rerun draws a fresh example.
                session_state.prompt = None
                session_state.prompt_box = None