MS-YUN committed on
Commit 9b7601c · 1 Parent(s): b1895c1

Add application file

Files changed (1)
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ # Load the base model and attach the DPO adapter
+ import torch
+ from peft import PeftConfig, PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ base_model_name = "facebook/opt-350m"
+ adapter_model_name = 'msy127/opt-350m-aihubqa-130-dpo-adapter'
+
+ model = AutoModelForCausalLM.from_pretrained(base_model_name)
+ model = PeftModel.from_pretrained(model, adapter_model_name).to(device)
+ tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)
+
+ # Conversation-accumulation function: the history takes the place of the prompt.
+ # (DialoGPT encoded the history before feeding it to the model; the OpenAI API does not.)
+ def predict(user_input, history):
+     history.append({"role": "user", "content": user_input})
+
+     # Plain model: wrap the input in the instruction template used at training time.
+     prompt = f"An AI tool that looks at the context and question separated by triple backquotes, finds the answer corresponding to the question in the context, and answers clearly.\n### Input: ```{user_input}```\n ### Output: "
+     inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
+     outputs = model.generate(input_ids=inputs, max_length=256)
+     generated_text = tokenizer.decode(outputs[0])
+
+     # The decoded text is '</s>' + prompt + completion (OPT's tokenizer prepends
+     # a '</s>' token), so skip past both to reach the completion.
+     start_idx = len(prompt) + len('</s>')
+     stop_first_idx = generated_text.find("### Input:")  # find the first "### Input:" (the template's own)
+     stop_idx = generated_text.find("### Input:", stop_first_idx + 1)  # find "### Input:" again after the first occurrence
+     if stop_idx != -1:
+         response = generated_text[start_idx:stop_idx]  # keep only the newly generated text after the prompt, up to the repeated "### Input:"
+     else:
+         response = generated_text[start_idx:]  # no repeated marker, so keep everything after the prompt
+
+     # Accumulate the reply, then pair up (user, assistant) turns for the Chatbot,
+     # starting at index 1 to skip the system message.
+     history.append({"role": "assistant", "content": response})
+     messages = [(history[i]["content"], history[i + 1]["content"]) for i in range(1, len(history) - 1, 2)]
+
+     return messages, history
+
+
+ # Set up the Gradio interface
+ import gradio as gr
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot(label="ChatBot")
+
+     state = gr.State([
+         {"role": "system", "content": "You are a friendly AI chatbot. Please answer each input briefly, concisely, and kindly."}])
+
+     with gr.Row():
+         txt = gr.Textbox(show_label=False, placeholder="Ask the chatbot anything").style(container=False)
+
+     txt.submit(predict, [txt, state], [chatbot, state])
+
+ # demo.launch(debug=True, share=True)
+ demo.launch()
+
+
+ # Leftover template for a separate image-classification demo (classify_image is not defined in this file):
+ # from PIL import Image
+ # import gradio as gr
+ # interface = gr.Interface(
+ #     fn=classify_image,
+ #     inputs=gr.components.Image(type="pil", label="Upload an Image"),
+ #     outputs="text",
+ #     live=True
+ # )
+ # interface.launch()
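
Side note on the loading step: if the adapter is a mergeable type such as LoRA (an assumption here; the commit does not say what kind of PEFT adapter this is), its weights can be folded into the base model once at startup so inference runs without the PeftModel wrapper. A minimal sketch:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    device = "cuda" if torch.cuda.is_available() else "cpu"

    base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    model = PeftModel.from_pretrained(base, "msy127/opt-350m-aihubqa-130-dpo-adapter")
    # merge_and_unload() folds the adapter deltas into the base weights;
    # it is only valid for mergeable adapter types such as LoRA.
    model = model.merge_and_unload().to(device)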
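The string-slicing logic in predict() can be checked without loading a model. A minimal sketch with made-up prompt and completion strings (the example text is illustrative only):

    # The decoded output has the shape '</s>' + prompt + completion; a second
    # "### Input:" appears only if the model starts generating a next turn.
    prompt = "...instruction...\n### Input: ```What is OPT?```\n ### Output: "
    generated_text = "</s>" + prompt + "A decoder-only language model.\n### Input: ```..."

    start_idx = len(prompt) + len('</s>')
    stop_first_idx = generated_text.find("### Input:")                # the template's own marker
    stop_idx = generated_text.find("### Input:", stop_first_idx + 1)  # the repeated marker, if any
    response = generated_text[start_idx:stop_idx] if stop_idx != -1 else generated_text[start_idx:]
    print(repr(response))  # 'A decoder-only language model.\n'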
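Likewise for the (user, assistant) pairing that feeds gr.Chatbot, with a made-up history for illustration: starting the range at index 1 skips the system message, and the step of 2 pairs each user turn with the assistant reply that follows it.

    history = [
        {"role": "system", "content": "You are a friendly AI chatbot."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "What is DPO?"},
        {"role": "assistant", "content": "Direct Preference Optimization."},
    ]
    messages = [(history[i]["content"], history[i + 1]["content"])
                for i in range(1, len(history) - 1, 2)]
    print(messages)  # [('Hi', 'Hello!'), ('What is DPO?', 'Direct Preference Optimization.')]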