File size: 5,691 Bytes
21d2052 dc8d63b 21d2052 dc8d63b 21d2052 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# -*- coding: utf-8 -*-
import numpy as np
import streamlit as st
from transformers import (
    AutoModelForCausalLM,
    AutoModelWithLMHead,
    PreTrainedTokenizerFast,
)
# Pretrained KoGPT checkpoint fine-tuned for emotion-conditional generation.
model_dir = "snoop2head/kogpt-conditional-2"

# Special tokens must match the vocabulary used during fine-tuning.
_SPECIAL_TOKENS = {
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
    "mask_token": "<mask>",
}
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir, **_SPECIAL_TOKENS)
@st.cache
def load_model(model_name):
    """Load and cache the KoGPT checkpoint.

    Uses ``AutoModelForCausalLM`` instead of the deprecated
    ``AutoModelWithLMHead`` (removed in transformers v5); for GPT-style
    checkpoints such as KoGPT both resolve to the same causal-LM class.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        The loaded model; cached by Streamlit so reruns skip the download.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model


model = load_model(model_dir)
print("loaded model completed")
def find_nth(haystack, needle, n):
    """Return the index of the n-th non-overlapping occurrence of *needle*.

    Counts occurrences left to right, skipping ``len(needle)`` after each
    hit (so matches never overlap). Returns -1 when there are fewer than
    ``n`` occurrences; for ``n <= 1`` the first occurrence is returned.
    """
    pos = haystack.find(needle)
    for _ in range(n - 1):
        if pos < 0:
            break
        pos = haystack.find(needle, pos + len(needle))
    return pos
def infer(input_ids, max_length, temperature, top_k, top_p):
    """Sample one sequence from the module-level ``model``.

    Thin wrapper around ``model.generate`` with stochastic sampling enabled
    and a single returned sequence.
    """
    generation_kwargs = dict(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
    )
    return model.generate(**generation_kwargs)
# prompts
# NOTE(review): the Korean UI strings below are mojibake in this copy of the
# file (UTF-8 bytes decoded with the wrong codec), and some literals are even
# split across physical lines — restore them from the original encoding before
# editing; do not change the bytes here.
st.title("์ผํ์์ ๋ฌ์ธ KoGPT์
๋๋ค ๐ฆ")
st.write("ํ
์คํธ๋ฅผ ์
๋ ฅํ๊ณ  CTRL+Enter(CMD+Enter)์ ๋๋ฅด์ธ์ ๐ค")
# text and sidebars
# Seed text for the acrostic: up to 4 characters, one generated line per char.
default_value = "๋ฐ์๋ฏผ"
sent = st.text_area("Text", default_value, max_chars=4, height=275)
# Generation hyper-parameters exposed as sidebar sliders.
max_length = st.sidebar.slider("์์ฑ ๋ฌธ์ฅ ๊ธธ์ด๋ฅผ ์ ํํด์ฃผ์ธ์!", min_value=42, max_value=64)
temperature = st.sidebar.slider(
    "Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05
)
# NOTE(review): this top_k slider is shadowed by infer_sentence's top_k=2
# default and appears to have no effect — verify before relying on it.
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
print("slider sidebars rendering completed")
# make input sentence
# Seven emotion labels (mojibake Korean — restore original encoding; one
# literal is split across physical lines in this copy).
emotion_list = ["ํ๋ณต", "์ค๋ฆฝ", "๋ถ๋
ธ", "ํ์ค", "๋๋", "์ฌํ", "๊ณตํฌ"]
main_emotion = st.sidebar.radio("์ฃผ์ ๊ฐ์ ์ ์ ํํ์ธ์", emotion_list)
sub_emotion = st.sidebar.radio("๋ ๋ฒ์งธ ๊ฐ์ ์ ์ ํํ์ธ์", emotion_list)
print("radio sidebars rendering completed")
# create condition sentence
# Sample emotion-intensity logits around fixed means (main ~N(3.368, 1.015),
# sub ~N(1.333, 0.790)) so each rerun conditions the model with a slightly
# different strength; rounded to one decimal for the prompt text.
random_main_logit = np.random.normal(loc=3.368, scale=1.015, size=1)[0].round(1)
random_sub_logit = np.random.normal(loc=1.333, scale=0.790, size=1)[0].round(1)
# Prefix of the form "<x>๋งํผ <emotion>..." prepended to the user's seed text.
condition_sentence = f"{random_main_logit}๋งํผ {main_emotion}๊ฐ์ ์ธ ๋ฌธ์ฅ์ด๋ค. {random_sub_logit}๋งํผ {sub_emotion}๊ฐ์ ์ธ ๋ฌธ์ฅ์ด๋ค. "
condition_plus_input = condition_sentence + sent
print(condition_plus_input)
def infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=2
):
    """Generate one sentence conditioned on the emotion prefix and decode it.

    NOTE(review): the defaults bind to module-level globals at definition
    time, and `max_length`, `temperature` and `top_p` are read from module
    scope at call time — confirm this is intended under Streamlit reruns.
    Also note callers pass only (text, tokenizer), so `top_k` is always 2
    and the sidebar's Top-k slider is ignored here.
    """
    encoded_prompt = tokenizer.encode(
        condition_plus_input, add_special_tokens=False, return_tensors="pt"
    )
    # An empty prompt encodes to shape (1, 0); pass None so generate() starts
    # from scratch instead of an empty tensor.
    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt
    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
    print(output_sequences)
    # exclude item that contains "unk"
    # NOTE(review): each `output_sequence` here is a tensor of token ids, so
    # `"unk" not in output_sequence` compares a string against a tensor — it
    # likely never filters anything (or may raise); decoding first would be
    # required for this check to work as intended. Verify against runtime.
    output_sequences = [
        output_sequence
        for output_sequence in output_sequences
        if "unk" not in output_sequence
    ]
    # choose item that length is longer than 1
    output_sequences = [
        output_sequence
        for output_sequence in output_sequences
        if len(output_sequence) > 1
    ]
    # num_return_sequences=1, so after filtering we take the single candidate.
    generated_sequence = output_sequences[0]
    print(generated_sequence)
    # Decode text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    print(text)
    # Remove all text after the stop token
    stop_token = tokenizer.pad_token
    print(stop_token)
    text = text[: text.find(stop_token) if stop_token else None]
    print(text)
    # Strip the conditioning prefix: it contains the marker phrase twice, so
    # cut everything up to (and including) its second occurrence.
    # NOTE(review): the marker literal is mojibake in this copy; the +5 offset
    # presumably skips the 4-char marker plus the trailing period — confirm
    # against the properly-encoded source.
    condition_index = find_nth(text, "๋ฌธ์ฅ์ด๋ค", 2)
    text = text[condition_index + 5 :]
    text = text.strip()
    return text
def make_residual_conditional_samhaengshi(input_letter, condition_sentence):
    """Build an acrostic ("samhaengshi"): one generated sentence per letter.

    Each character of `input_letter` seeds one line. The previously generated
    line is fed back (residually) together with the next letter so consecutive
    lines stay coherent; the echoed previous line is then stripped from each
    new generation.

    Args:
        input_letter: seed string; one output line per character.
        condition_sentence: emotion-conditioning prefix prepended to every
            generation call.

    Returns:
        A list with one generated sentence per input character. Returns an
        empty list for empty input (the previous version fell off the loop
        and returned None, which the caller then rendered).
    """
    list_samhaengshi = []
    residual_text = ""
    for index, letter_item in enumerate(input_letter):
        # First line is seeded by the bare letter; later lines use the
        # "previous sentence + next letter" residual built at loop end.
        if index == 0:
            residual_text = letter_item
        conditional_input = f"{condition_sentence} {residual_text}"
        inferred_sentence = infer_sentence(conditional_input, tokenizer)
        if index != 0:
            # Remove the echoed previous sentence so only the new line remains.
            print("inferred_sentence:", inferred_sentence)
            inferred_sentence = inferred_sentence.replace(
                list_samhaengshi[index - 1], ""
            ).strip()
        list_samhaengshi.append(inferred_sentence)
        # Carry this line plus the next letter into the next iteration.
        if index < len(input_letter) - 1:
            residual_text = f"{list_samhaengshi[index]} {input_letter[index + 1]}"
            print("residual_text", residual_text)
    return list_samhaengshi
# Run the acrostic generator on the user's seed text and render the result.
generated_lines = make_residual_conditional_samhaengshi(
    input_letter=sent, condition_sentence=condition_sentence
)
print(generated_lines)
st.write(generated_lines)
|