# Load the base model and attach the DPO LoRA adapter
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

base_model_name = "facebook/opt-350m"
adapter_model_name = "msy127/opt-350m-aihubqa-130-dpo-adapter"

model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(model, adapter_model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)
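# Optional (a sketch, not part of the original Space): the LoRA adapter weights
# can be folded into the base model to drop the adapter indirection at
# inference time. merge_and_unload() is part of peft's PeftModel API.
# model = model.merge_and_unload()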
# Conversation-accumulating chat function: the running `history` rides along in
# the Gradio state. (With DialoGPT the history had to be token-encoded before
# going into the model, whereas the OpenAI chat API takes the raw message list.)
def predict(input, history):
    history.append({"role": "user", "content": input})

    # Plain instruction-tuned model: build a single-turn prompt around the input.
    prompt = f"An AI tool that looks at the context and question separated by triple backquotes, finds the answer corresponding to the question in the context, and answers clearly.\n### Input: ```{input}```\n ### Output: "
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
    outputs = model.generate(input_ids=inputs, max_length=256)
    generated_text = tokenizer.decode(outputs[0])

    # The OPT tokenizer prepends a "</s>" token when decoding, so skip past it
    # along with the echoed prompt.
    start_idx = len(prompt) + len('</s>')
    stop_first_idx = generated_text.find("### Input:")                # find the first "### Input:"
    stop_idx = generated_text.find("### Input:", stop_first_idx + 1)  # find the next "### Input:" after the first
    if stop_idx != -1:
        # Keep only the newly generated text after the prompt, up to the next "### Input:".
        response = generated_text[start_idx:stop_idx]
    else:
        # No second marker was generated; keep everything after the prompt.
        response = generated_text[start_idx:]

    # Accumulate the reply, then pair up (user, assistant) turns for the Chatbot
    # widget. Index 0 of history is the system message, so pairing starts at 1.
    history.append({"role": "assistant", "content": response})
    messages = [(history[i]["content"], history[i + 1]["content"]) for i in range(1, len(history), 2)]
    return messages, history
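# Example direct call (a sketch, assuming the same state layout used below: the
# system message at index 0, then alternating user/assistant turns):
# state = [{"role": "system", "content": "You are a friendly AI chatbot."}]
# messages, state = predict("What is the capital of France?", state)
# print(messages[-1])  # -> (user question, model answer)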
# Gradio interface setup
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="ChatBot")
    state = gr.State([
        {"role": "system", "content": "You are a friendly AI chatbot. Answer each input briefly, concisely, and kindly."}
    ])
    with gr.Row():
        # .style(container=False) is the Gradio 3.x API; it was removed in Gradio 4.
        txt = gr.Textbox(show_label=False, placeholder="Ask the chatbot anything").style(container=False)
    txt.submit(predict, [txt, state], [chatbot, state])

# demo.launch(debug=True, share=True)
demo.launch()
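# Note (a sketch, not part of the original Space): if several users hit the demo
# at once, generation requests can be serialized through Gradio's built-in queue:
# demo.queue().launch()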