kimmeoungjun commited on
Commit
f357513
·
1 Parent(s): 06ee17b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -11
app.py CHANGED
@@ -1,15 +1,50 @@
 
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
 
5
 
6
- def predict(image):
7
- predictions = pipeline(image)
8
- return {p["label"]: p["score"] for p in predictions}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- gr.Interface(
11
- predict,
12
- inputs=gr.inputs.Image(label="Upload hot dog candidate", type="filepath"),
13
- outputs=gr.outputs.Label(num_top_classes=2),
14
- title="Hot Dog? Or Not?",
15
- ).launch()
 
1
+ import torch
2
  import gradio as gr
 
3
 
4
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
+ from peft import PeftModel, PeftConfig
6
 
7
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
+ peft_model_id = "kimmeoungjun/qlora-koalpaca"
9
+ config = PeftConfig.from_pretrained(peft_model_id)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
11
+ model = PeftModel.from_pretrained(model, peft_model_id).to(device)
12
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
13
+
14
+ def my_split(s, seps):
15
+ res = [s]
16
+ for sep in seps:
17
+ s, res = res, []
18
+ for seq in s:
19
+ res += seq.split(sep)
20
+ return res
21
+
22
+ def chat_base(input):
23
+ p = input
24
+ input_ids = tokenizer(p, return_tensors="pt").input_ids.to(device)
25
+ gen_tokens = model.generate(input_ids, do_sample=True, early_stopping=True, do_sample=True, eos_token_id=2,)
26
+ gen_text = tokenizer.batch_decode(gen_tokens)[0]
27
+ # print(gen_text)
28
+ result = gen_text[len(p):]
29
+ # print(">", result)
30
+ result = my_split(result, [']', '\n'])[1]
31
+ # print(">>", result)
32
+ # print(">>>", result)
33
+ return result
34
+
35
+ def chat(message):
36
+ history = gr.get_state() or []
37
+ print(history)
38
+ response = chat_base(message)
39
+ history.append((message, response))
40
+ gr.set_state(history)
41
+ html = "<div class='chatbot'>"
42
+ for user_msg, resp_msg in history:
43
+ html += f"<div class='user_msg'>{user_msg}</div>"
44
+ html += f"<div class='resp_msg'>{resp_msg}</div>"
45
+ html += "</div>"
46
+ return response
47
+
48
+ iface = gr.Interface(chat_base, gr.inputs.Textbox(label="물어보세요"), "text", allow_screenshot=False, allow_flagging=False)
49
+ iface.launch()
50