Spaces:
Runtime error
Runtime error
File size: 2,453 Bytes
250f909 cc7fe5e c91acd0 250f909 cc7fe5e 250f909 cc7fe5e 250f909 cc7fe5e 250f909 cc7fe5e 250f909 cc7fe5e 250f909 cc7fe5e 250f909 cc7fe5e 250f909 cc7fe5e 250f909 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import streamlit as st
from streamlit_chat import message
@st.cache(allow_output_mutation=True)
def get_pipe():
    """Load and cache the dialog model and its tokenizer.

    Returns:
        tuple: (model, tokenizer) for "heegyu/ajoublue-gpt2-medium-dialog".

    NOTE(review): ``st.cache`` is deprecated in modern Streamlit in favor of
    ``st.cache_resource``; kept as-is because this file targets the older API
    (it also uses ``st.experimental_rerun``).
    """
    # Imported inside the function so the heavy dependency is only pulled in
    # on the first (cached) call.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_name = "heegyu/ajoublue-gpt2-medium-dialog"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
def get_response(tokenizer, model, history, max_context: int = 7, bot_id: str = '1'):
    """Generate the bot's next utterance for the given chat history.

    Args:
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: Hugging Face causal LM used for generation.
        history: alternating utterances; even indices are speaker ``0`` (user),
            odd indices speaker ``1`` (bot) — matches the UI's display logic.
        max_context: number of most-recent turns kept in the prompt window.
        bot_id: speaker tag appended for the turn being generated.

    Returns:
        str: the generated reply with ``</s>``/newlines stripped, truncated at
        the first ``<s>`` (start of a hallucinated next turn).
    """
    # Label every turn with its parity in the FULL history, then trim, so
    # speaker ids stay consistent even after the window slides.
    turns = [f"{i % 2}: {text}</s>" for i, text in enumerate(history)]
    prompt = "".join(turns[-max_context:]) + f"{bot_id}: "

    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    generation_args = dict(
        max_new_tokens=128,
        # Require at least ~5 tokens beyond the prompt.
        min_length=prompt_len + 5,
        eos_token_id=2,
        do_sample=True,
        top_p=0.95,
        temperature=1.35,
        early_stopping=True,
    )
    outputs = model.generate(**inputs, **generation_args)

    # BUGFIX: decode only the newly generated tokens instead of decoding the
    # whole sequence and chopping off len(prompt) CHARACTERS — a decode of the
    # prompt tokens is not guaranteed to reproduce the prompt string
    # char-for-char, so the old slice could clip the reply or leak context.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)
    response = response.replace("</s>", "").replace("\n", "")
    response = response.split("<s>")[0]
    return response
st.title("ajoublue-gpt2-medium νκ΅μ΄ λν λͺ¨λΈ demo")

with st.spinner("loading model..."):
    model, tokenizer = get_pipe()

# Persist the conversation across Streamlit reruns.
if 'message_history' not in st.session_state:
    st.session_state.message_history = []
history = st.session_state.message_history

# Even indices are user turns, odd indices are bot turns.
for i, message_ in enumerate(st.session_state.message_history):
    message(message_, is_user=i % 2 == 0, key=i)  # display all previous messages

input_ = st.text_input("μ무 λ§μ΄λ ν΄λ³΄μΈμ", value="")
if input_ is not None and len(input_) > 0:
    # Guard against re-generating for the same input: Streamlit reruns the
    # whole script on every interaction, and history[-2] is the last user turn.
    if len(history) <= 1 or history[-2] != input_:
        # NOTE(review): this string literal was split across two lines in the
        # retrieved copy (invalid Python); rejoined with no inserted character.
        with st.spinner("λλ΅μ μμ±μ€μλλ€..."):
            st.session_state.message_history.append(input_)
            # `history` aliases session_state.message_history, so the prompt
            # already includes the just-appended user input.
            response = get_response(tokenizer, model, history)
            st.session_state.message_history.append(response)
            st.experimental_rerun()