Update app.py
Browse files
app.py
CHANGED
@@ -1,28 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
#
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
)
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import subprocess
|
4 |
+
from threading import Thread
|
5 |
+
|
6 |
+
import torch
|
7 |
import spaces
|
8 |
import gradio as gr
|
9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
|
10 |
+
|
11 |
+
# Install flash-attn at startup (Spaces containers lack a CUDA build
# environment, so skip the CUDA build and rely on the prebuilt wheel).
# NOTE: merge with os.environ — passing a bare env dict would wipe PATH
# and friends, making the pip subprocess fail.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

MODEL_ID = "infly/OpenCoder-8B-Instruct"
CHAT_TEMPLATE = "ChatML"              # prompt format consumed by predict()
MODEL_NAME = MODEL_ID.split("/")[-1]  # e.g. "OpenCoder-8B-Instruct"
CONTEXT_LENGTH = 1300                 # max prompt tokens fed to the model

# UI strings from the environment, with safe fallbacks. The previous
# commented-out assignments left EMOJI/DESCRIPTION undefined, which made
# gr.ChatInterface(title=..., description=...) crash with a NameError.
EMOJI = os.environ.get("EMOJI", "💻")
DESCRIPTION = os.environ.get("DESCRIPTION", f"Chat with {MODEL_NAME}")
|
19 |
+
|
20 |
+
|
21 |
+
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream a chat completion for *message* given the running *history*.

    Formats the conversation with the template selected by the module-level
    CHAT_TEMPLATE, launches ``model.generate`` on a background thread, and
    yields the accumulated response text as tokens arrive from the streamer.

    Args:
        message: Latest user message.
        history: List of (user, assistant) string pairs from prior turns.
        system_prompt: System instruction prepended to the conversation.
        temperature, max_new_tokens, top_k, repetition_penalty, top_p:
            Sampling parameters forwarded to ``model.generate``.

    Yields:
        The partial assistant response, re-emitted as it grows.

    Raises:
        Exception: If CHAT_TEMPLATE is not a supported template name.
    """
    # Format history with the selected chat template.
    if CHAT_TEMPLATE == "ChatML":
        # NOTE(review): the bare word "assistant" as a stop token will cut
        # the stream short whenever it appears in normal output — confirm
        # this is intentional.
        stop_tokens = ["<|endoftext|>", "<|im_end|>", "<|end_of_text|>", "<|eot_id|>", "assistant"]
        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
        for human, assistant in history:
            instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
        instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = '<s>[INST] ' + system_prompt
        for human, assistant in history:
            instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
        instruction += ' ' + message + ' [/INST]'
    else:
        raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, timeout=90.0, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True, max_length=CONTEXT_LENGTH)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Defensive re-truncation. The tokenizer's truncation should already
    # enforce the limit, but if this ever fires, slice the attention mask
    # too so the tensors stay shape-aligned (the original sliced only
    # input_ids, which would make generate() fail on mismatched shapes).
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    # Run generation on a worker thread so this generator can stream
    # partial output while the model produces tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)
|
65 |
+
|
66 |
+
|
67 |
+
# Load tokenizer and model once at module import.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # flash_attention_2 supports only fp16/bf16 weights; loading with the
    # default fp32 dtype makes transformers raise at init, so pin bfloat16.
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
|
76 |
+
|
77 |
+
css = """
|
78 |
+
.message-row {
|
79 |
+
justify-content: space-evenly !important;
|
80 |
+
}
|
81 |
+
.message-bubble-border {
|
82 |
+
border-radius: 6px !important;
|
83 |
+
}
|
84 |
+
.message-buttons-bot, .message-buttons-user {
|
85 |
+
right: 10px !important;
|
86 |
+
left: auto !important;
|
87 |
+
bottom: 2px !important;
|
88 |
+
}
|
89 |
+
.dark.message-bubble-border {
|
90 |
+
border-color: #15172c !important;
|
91 |
+
}
|
92 |
+
.dark.user {
|
93 |
+
background: #10132c !important;
|
94 |
+
}
|
95 |
+
.dark.assistant.dark, .dark.pending.dark {
|
96 |
+
background: #020417 !important;
|
97 |
+
}
|
98 |
+
"""
|
99 |
+
|
100 |
+
# Create Gradio interface.
# Title/description are read from the environment here with fallbacks:
# the module-level EMOJI/DESCRIPTION assignments were commented out
# upstream, which made these references raise a NameError at import.
# NOTE(review): retry_btn/undo_btn/clear_btn were removed in Gradio 5,
# while gr.themes.Ocean exists only in Gradio 5 — confirm the pinned
# gradio version accepts this combination.
gr.ChatInterface(
    predict,
    title=os.environ.get("EMOJI", "💻") + " " + MODEL_NAME,
    description=os.environ.get("DESCRIPTION", f"Chat with {MODEL_NAME}"),
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 512, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Ocean(
        secondary_hue="emerald"
    ),
    css=css,
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    submit_btn="Send",
    chatbot=gr.Chatbot(
        scale=1,
        show_copy_button=True
    )
).queue().launch()
|