test-nlp-study / app.py
kimmeoungjun's picture
Update app.py
abd383d
raw
history blame
1.68 kB
import torch
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub id of the PEFT adapter; its config names the base checkpoint to load.
peft_model_id = "kimmeoungjun/qlora-koalpaca"
config = PeftConfig.from_pretrained(peft_model_id)

# The tokenizer is taken from the base checkpoint named in the adapter config.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the base model, wrap it with the adapter weights, and move the result
# to the selected device.
# NOTE(review): chat_base() strips len(prompt) characters from the decoded
# output, which presumes the output echoes the prompt -- confirm that a
# seq2seq model class is the right choice for this base checkpoint.
model = PeftModel.from_pretrained(
    AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path),
    peft_model_id,
).to(device)
def my_split(s, seps):
    """Split ``s`` on every separator in ``seps`` and return the flat pieces.

    Separators are applied one after another: each piece produced so far is
    split again on the next separator. Empty pieces are kept, and with no
    separators the result is ``[s]``.

    Args:
        s: The string to split.
        seps: Iterable of separator strings, applied in order.

    Returns:
        A list of substrings in their original left-to-right order.
    """
    pieces = [s]
    for delimiter in seps:
        # Re-split every existing piece on this delimiter and flatten.
        pieces = [part for chunk in pieces for part in chunk.split(delimiter)]
    return pieces
def chat_base(input):
    """Generate a sampled reply to ``input`` and return it as plain text.

    The prompt is tokenized and passed to ``model.generate`` with sampling
    enabled. The first decoded sequence is then post-processed: its leading
    ``len(input)`` characters are dropped and the remainder is split on
    ``]`` and newline characters.

    Args:
        input: Prompt text. The name shadows the builtin but is kept so that
            existing keyword callers keep working.

    Returns:
        The second piece of the split continuation when at least one
        separator is present; otherwise the whole continuation (previously
        this case raised ``IndexError``).
    """
    prompt = input
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Pass the ids by keyword: PEFT wrappers that declare ``generate(**kwargs)``
    # reject a positional tensor with a TypeError.
    gen_tokens = model.generate(
        input_ids=input_ids,
        do_sample=True,
        early_stopping=True,
        eos_token_id=2,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    # NOTE(review): assumes the decoded text begins with the prompt verbatim,
    # so slicing by len(prompt) leaves only the continuation -- confirm.
    continuation = gen_text[len(prompt):]
    pieces = my_split(continuation, [']', '\n'])
    # my_split always yields at least one piece; fall back to it when the
    # continuation contains no separator instead of indexing out of range.
    if len(pieces) > 1:
        return pieces[1]
    return pieces[0]
def chat(message):
    """Answer ``message`` and record the exchange in the Gradio session state.

    The stored history (a list of ``(message, response)`` tuples) is read
    with ``gr.get_state()``, printed, extended with the new exchange, and
    written back with ``gr.set_state()``.

    An earlier version also assembled an HTML transcript of the history that
    was never returned or used; that dead code (which interpolated user text
    into markup unescaped) has been removed.

    Args:
        message: The user's prompt text.

    Returns:
        The reply produced by ``chat_base(message)``.
    """
    # NOTE(review): gr.get_state / gr.set_state look like a legacy Gradio
    # state API -- confirm they exist in the pinned gradio version.
    history = gr.get_state() or []
    print(history)
    response = chat_base(message)
    history.append((message, response))
    gr.set_state(history)
    return response
# Single-textbox UI: the prompt goes to chat_base and the reply is shown as
# text. The textbox label is Korean for "Ask a question". Screenshot and
# flagging controls are switched off.
iface = gr.Interface(
    fn=chat_base,
    inputs=gr.inputs.Textbox(label="물어보세요"),
    outputs="text",
    allow_screenshot=False,
    allow_flagging=False,
)
iface.launch()