Update app.py
app.py CHANGED
@@ -34,7 +34,15 @@ examples=[
 ]


-def predict(message, chatbot):
+# Stream text
+def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
+
+    if system_prompt != "":
+        system_message = system_prompt
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)

     input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
     for interaction in chatbot:
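Before building the prompt, both handlers coerce and clamp the sampling arguments. Restated as a stand-alone sketch for clarity (the helper name and the assert are illustrative, not part of app.py):

```python
# Illustrative restatement of the argument sanitization added above;
# the helper name is hypothetical and not part of app.py.
def sanitize_sampling_args(temperature, top_p):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # keep temperature strictly positive
    top_p = float(top_p)
    return temperature, top_p

# Values may arrive as strings, hence the float() coercion; 0.0 is clamped to 0.01.
assert sanitize_sampling_args(0.0, "0.6") == (0.01, 0.6)
```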
@@ -44,12 +52,14 @@ def predict(message, chatbot):

     data = {
         "inputs": input_prompt,
-        "parameters": {
-
-
-
-
-
+        "parameters": {
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "repetition_penalty": repetition_penalty,
+            "do_sample": True,
+        },
+    }
     response = requests.post(api_url, headers=headers, data=json.dumps(data), auth=('hf', hf_token), stream=True)

     partial_message = ""
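This hunk only changes the request body; the loop that fills `partial_message` from the streamed response sits outside the diff. As a rough sketch of what that consumption typically looks like when `api_url` is a text-generation-inference `generate_stream` endpoint emitting server-sent events (the endpoint type and payload shape are assumptions, not shown in this commit):

```python
import json

# Hypothetical consumer for the streaming response opened above with stream=True,
# assuming text-generation-inference style SSE lines of the form data:{"token": {...}}.
def stream_partial_messages(response):
    partial_message = ""
    for line in response.iter_lines():
        if not line:
            continue  # skip keep-alive blanks
        decoded = line.decode("utf-8")
        if not decoded.startswith("data:"):
            continue  # ignore non-data SSE fields
        payload = json.loads(decoded[len("data:"):])
        token = payload.get("token", {})
        if token.get("special"):
            continue  # drop special tokens such as </s>
        partial_message += token.get("text", "")
        yield partial_message  # Gradio renders each partial string as it arrives
```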
@@ -84,8 +94,16 @@ def predict(message, chatbot):
             continue


-def predict_batch(message, chatbot):
+# No Stream
+def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):

+    if system_prompt != "":
+        system_message = system_prompt
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
     input_prompt = f"[INST]<<SYS>>\n{system_message}\n<</SYS>>\n\n "
     for interaction in chatbot:
         input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
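Both handlers serialize the Gradio chat history into a Llama-2 style prompt, exactly as in the context lines above. The same construction as an isolated, runnable helper (the function name, the trailing append of the new message, and the sample history are illustrative):

```python
# Mirrors the prompt construction in predict()/predict_batch(); names are illustrative.
def build_llama2_prompt(message, history, system_message):
    prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
    for user_turn, assistant_turn in history:
        # each past exchange is closed with [/INST] ... </s><s> [INST]
        prompt += str(user_turn) + " [/INST] " + str(assistant_turn) + " </s><s> [INST] "
    prompt += str(message) + " [/INST] "  # assumed handling of the new user message
    return prompt

# One prior exchange plus a new question:
print(build_llama2_prompt(
    "And in Fahrenheit?",
    [("How hot is the sun's surface?", "Roughly 5,500 degrees Celsius.")],
    "You are a concise assistant.",
))
```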
@@ -94,7 +112,13 @@ def predict_batch(message, chatbot):

     data = {
         "inputs": input_prompt,
-        "parameters": {
+        "parameters": {
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "repetition_penalty": repetition_penalty,
+            "do_sample": True,
+        },
     }

     response = requests.post(api_url_nostream, headers=headers, data=json.dumps(data), auth=('hf', hf_token))
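The non-streaming handler later branches on `response.status_code` (see the context in the next hunk). A hedged sketch of how the JSON body might be unpacked, assuming the endpoint returns text-generation-inference style output (`generated_text`, possibly wrapped in a list):

```python
# Hypothetical unpacking of the non-streaming reply used by predict_batch(),
# assuming a text-generation-inference style JSON body.
def parse_batch_response(response):
    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}")
        return ""
    body = response.json()
    if isinstance(body, list):  # some deployments wrap the result in a list
        body = body[0]
    return body.get("generated_text", "")
```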
@@ -114,13 +138,55 @@ def predict_batch(message, chatbot):
        print(f"Request failed with status code {response.status_code}")


+
+additional_inputs=[
+    gr.Textbox("", label="Optional system prompt"),
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=4096,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.6,
+        minimum=0.0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    )
+]
+
+
 # Gradio Demo
 with gr.Blocks() as demo:

     with gr.Tab("Streaming"):
-        gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True)
+        gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)

     with gr.Tab("Batch"):
-        gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True)
+        gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)

 demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
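`gr.ChatInterface` passes the current values of `additional_inputs` positionally after `(message, history)`, so the component order above must line up with the extra parameters of `predict`/`predict_batch`. A minimal self-contained sketch of that wiring (the `echo` handler is a stand-in, not the Space's code):

```python
import gradio as gr

# Stand-in handler with the same signature shape as predict()/predict_batch().
def echo(message, history, system_prompt="", temperature=0.9, max_new_tokens=256,
         top_p=0.6, repetition_penalty=1.0):
    return f"temp={temperature}, top_p={top_p}, max_new_tokens={max_new_tokens}"

chat_demo = gr.ChatInterface(
    echo,
    additional_inputs=[
        gr.Textbox("", label="Optional system prompt"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Temperature"),
        gr.Slider(0, 4096, value=256, step=64, label="Max new tokens"),
        gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Top-p (nucleus sampling)"),
        gr.Slider(1.0, 2.0, value=1.2, step=0.05, label="Repetition penalty"),
    ],
)

if __name__ == "__main__":
    chat_demo.launch()
```

Note that the repetition penalty slider defaults to 1.2 while the function signatures default to 1.0; the value the interface actually sends is the slider's.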