Spaces:
Runtime error
Runtime error
File size: 1,693 Bytes
f357513 4297655 f357513 4297655 f357513 06ee17b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub id of the PEFT adapter. Its config names the base checkpoint, which is
# loaded first and then wrapped with the adapter weights.
peft_model_id = "kimmeoungjun/qlora-koalpaca"
config = PeftConfig.from_pretrained(peft_model_id)
base_model_name = config.base_model_name_or_path

# NOTE(review): the base checkpoint is loaded as an encoder-decoder
# (AutoModelForSeq2SeqLM); confirm that matches the adapter's base model —
# a decoder-only base would need AutoModelForCausalLM instead.
model = PeftModel.from_pretrained(
    AutoModelForSeq2SeqLM.from_pretrained(base_model_name),
    peft_model_id,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
def my_split(s, seps):
    """Split ``s`` on every separator in ``seps``, applied one after another.

    Each separator is applied with ``str.split`` to every piece produced by
    the previous separators, preserving order. Empty strings between adjacent
    separators are kept. With no separators the result is ``[s]``.

    >>> my_split("a,b;c", [",", ";"])
    ['a', 'b', 'c']
    """
    pieces = [s]
    for separator in seps:
        pieces = [part for chunk in pieces for part in chunk.split(separator)]
    return pieces
def chat_base(input):
    """Generate a model reply for a single prompt string.

    Args:
        input: The user's prompt text, passed to the tokenizer unchanged.
            (The name shadows the ``input`` builtin but is kept so existing
            callers are unaffected.)

    Returns:
        The decoded continuation, with the prompt removed, cut at ``]`` and
        newline separators. The second piece is returned. If the continuation
        contains neither separator, the whole continuation is returned rather
        than raising IndexError.
    """
    prompt = input
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Each generation keyword may appear only once: a repeated keyword
    # argument is a SyntaxError that stops the whole module from loading.
    # NOTE(review): eos_token_id=2 is presumably the base tokenizer's EOS id;
    # confirm against tokenizer.eos_token_id.
    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        early_stopping=True,
        eos_token_id=2,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    # Strip the echoed prompt by character count. NOTE(review): this assumes
    # the decoded text starts with the prompt verbatim (true for decoder-only
    # generation, not for encoder-decoder output) — confirm.
    result = gen_text[len(prompt):]
    pieces = my_split(result, [']', '\n'])
    # The answer is taken to be the text after the first separator; when the
    # model emitted no separator, fall back to the full continuation.
    return pieces[1] if len(pieces) > 1 else pieces[0]
def chat(message):
    """Answer ``message`` and record the exchange in per-session state.

    The history list is read with ``gr.get_state()`` (an empty list when
    nothing is stored yet). It is printed, extended with the
    ``(message, response)`` pair and written back with ``gr.set_state``.
    NOTE(review): get_state/set_state belong to the legacy Gradio session-state
    API — confirm the installed Gradio version still provides them.

    Args:
        message: The user's prompt text.

    Returns:
        The plain-text model response produced by ``chat_base``.
    """
    history = gr.get_state() or []
    print(history)
    response = chat_base(message)
    history.append((message, response))
    gr.set_state(history)
    return response
# Text-in / text-out web UI around chat_base. The textbox label "물어보세요"
# is Korean for "Ask me"; screenshot and flagging controls are disabled.
question_box = gr.inputs.Textbox(label="물어보세요")
iface = gr.Interface(
    fn=chat_base,
    inputs=question_box,
    outputs="text",
    allow_screenshot=False,
    allow_flagging=False,
)
iface.launch()
|