# -*- coding: utf-8 -*-
import streamlit as st
# code from https://huggingface.co/kakaobrain/kogpt
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(
    'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b', cache_dir='./model_dir/',
    bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[PAD]', mask_token='[MASK]'
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModelForCausalLM.from_pretrained(
    'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b', cache_dir='./model_dir/',
    pad_token_id=tokenizer.eos_token_id,
    torch_dtype=torch.float16, low_cpu_mem_usage=True
).to(device=device, non_blocking=True)
_ = model.eval()
print("Model loading done!")
def gpt(prompt):
    # Encode the prompt, sample a continuation, and return the decoded text
    # (the returned string includes the prompt itself).
    with torch.no_grad():
        tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
        gen_tokens = model.generate(tokens, do_sample=True, temperature=0.8, max_length=256)
        generated = tokenizer.batch_decode(gen_tokens)[0]
    return generated
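# transformers' generate() also supports nucleus sampling via top_p; a variant
# like the call below (the 0.9 value is illustrative, not from the original)
# can cut down on off-topic continuations:
#
# gen_tokens = model.generate(tokens, do_sample=True, temperature=0.8,
#                             top_p=0.9, max_length=256)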
# prompts (the example prompts stay in Korean, since KoGPT is a Korean language model)
st.title("Generates your sentences. 🤖")
st.markdown("Uses Kakao GPT.")
st.subheader("A few examples:")
example_1_str = "오늘은 날씨는 너무 눈부시다. 내일은 "  # "Today the weather is so dazzling. Tomorrow "
example_2_str = "우리는 행복을 언제나 갈망하지만 항상 "  # "We always long for happiness, but always "
example_1 = st.button(example_1_str)
example_2 = st.button(example_2_str)
textbox = st.text_area('오늘은 아름다움을 향해 달리고', '', height=100, max_chars=500)  # label: "Today, running toward beauty and "
button = st.button('Generate')
# output
st.subheader("κ²°κ³Όκ°: ")
if example_1:
    with st.spinner('In progress...'):
        output_text = gpt(example_1_str)
    st.markdown("\n" + output_text)
if example_2:
    with st.spinner('In progress...'):
        output_text = gpt(example_2_str)
    st.markdown("\n" + output_text)
if button:
    with st.spinner('In progress...'):
        if textbox:
            output_text = gpt(textbox)
        else:
            output_text = " "  # nothing to generate when the textbox is empty
    st.markdown("\n" + output_text)