Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
from peft import PeftModel, PeftConfig | |
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Adapter id (Hub repo or local path). Its PEFT config records the base
# checkpoint the adapter was trained on, which is loaded below.
peft_model_id = "kimmeoungjun/qlora-koalpaca"
config = PeftConfig.from_pretrained(peft_model_id)

# NOTE(review): chat_base removes the prompt from the generated output, which
# is how decoder-only (causal LM) models behave, yet the base model is loaded
# as a seq2seq model here. If config.base_model_name_or_path is decoder-only,
# this should be AutoModelForCausalLM — confirm against the base checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
# Wrap the base model with the adapter weights, then move it to `device`.
model = PeftModel.from_pretrained(model, peft_model_id)
model = model.to(device)

# The tokenizer comes from the base checkpoint, not from the adapter repo.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
def my_split(s, seps):
    """Split *s* on each separator in *seps*, applied one after another.

    Every separator is applied to all pieces produced by the previous ones,
    so the result is the same as splitting on any of the separators.
    Empty pieces are kept, just as ``str.split`` keeps them.

    Args:
        s: The string to split.
        seps: Separators to apply, in order.

    Returns:
        A new list of pieces. It is ``[s]`` when *seps* is empty.
    """
    pieces = [s]
    for separator in seps:
        pieces = [part for chunk in pieces for part in chunk.split(separator)]
    return pieces
def chat_base(input, max_new_tokens=128):
    """Generate one reply to *input* with the PEFT-adapted model.

    Args:
        input: The user's prompt text. The name shadows the ``input``
            builtin; it is kept so that existing keyword callers still work.
        max_new_tokens: Upper bound on the number of tokens sampled after
            the prompt. Without an explicit bound, ``generate`` falls back
            to the model's default ``max_length``, which can be shorter than
            the prompt itself.

    Returns:
        The continuation is split on both ``]`` and newlines. The result is
        the second piece (the text between the first and second delimiter),
        or the whole continuation when neither delimiter occurs.
    """
    prompt = input
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # All arguments are passed by keyword: some PeftModel.generate versions
    # accept **kwargs only. `early_stopping` is omitted because it only
    # affects beam search, and this call samples with a single beam.
    gen_tokens = model.generate(
        input_ids=input_ids,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        # NOTE(review): hard-coded EOS id; confirm it equals tokenizer.eos_token_id.
        eos_token_id=2,
    )

    output_ids = gen_tokens[0]
    prompt_len = input_ids.shape[-1]
    # Decoder-only models echo the prompt at the start of their output, so
    # strip it at the token level. Slicing decoded text by len(prompt) is
    # unreliable because decode(encode(p)) need not equal p. If the output
    # does not start with the prompt tokens, it is kept whole.
    if output_ids.shape[-1] >= prompt_len and torch.equal(
        output_ids[:prompt_len], input_ids[0]
    ):
        output_ids = output_ids[prompt_len:]
    # Drop BOS/EOS/pad markers so they do not leak into the reply text.
    result = tokenizer.decode(output_ids, skip_special_tokens=True)

    pieces = my_split(result, [']', '\n'])
    # Fall back to the whole continuation instead of raising IndexError
    # when the model emits neither delimiter.
    return pieces[1] if len(pieces) > 1 else pieces[0]
def chat(message):
    """Answer *message* and append the exchange to the Gradio session history.

    NOTE(review): ``gr.get_state`` / ``gr.set_state`` belong to the legacy
    Gradio 2.x API and are not available in later releases. Confirm the
    pinned Gradio version before wiring this function into an interface.
    The module-level interface in this file passes ``chat_base``, not ``chat``.

    Args:
        message: The user's text for this turn.

    Returns:
        The model's reply for this turn only; the stored history is not
        returned.
    """
    # Earlier (message, reply) pairs for this session, or an empty list.
    history = gr.get_state() or []
    response = chat_base(message)
    history.append((message, response))
    gr.set_state(history)
    return response
# Build a text-in / text-out interface around the single-turn generator
# (chat_base, not the stateful chat wrapper) and start serving it.
# NOTE(review): `gr.inputs.*` and `allow_screenshot` belong to the legacy
# Gradio 2.x API. They are deprecated or removed in later releases, which
# also expect a string for `allow_flagging`. Confirm the pinned Gradio version.
question_box = gr.inputs.Textbox(label="물어보세요")
iface = gr.Interface(
    chat_base,
    question_box,
    "text",
    allow_screenshot=False,
    allow_flagging=False,
)
iface.launch()