Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -2,120 +2,299 @@ import argparse
|
|
2 |
import datetime
|
3 |
import hashlib
|
4 |
import json
|
|
|
5 |
import os
|
6 |
import sys
|
7 |
import time
|
8 |
-
import warnings
|
9 |
|
10 |
import gradio as gr
|
11 |
-
import spaces
|
12 |
import torch
|
13 |
-
|
14 |
-
from
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def clear_conv(conv):
|
24 |
conv.messages = []
|
25 |
return conv
|
26 |
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
headers = {"User-Agent": "LLaVA Client"}
|
31 |
-
|
32 |
no_change_btn = gr.Button()
|
33 |
enable_btn = gr.Button(interactive=True)
|
34 |
disable_btn = gr.Button(interactive=False)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
}
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
global tokenizer, model, image_processor, context_len
|
72 |
-
with warnings.catch_warnings(record=True) as w:
|
73 |
-
warnings.simplefilter("always")
|
74 |
-
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
|
75 |
-
for warning in w:
|
76 |
-
if "vision" not in str(warning.message).lower():
|
77 |
-
print(warning.message)
|
78 |
-
model.config.tokenizer_model_max_length = 2048 * 2
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
def get_conv_log_filename():
|
82 |
t = datetime.datetime.now()
|
83 |
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
|
|
84 |
return name
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
-
|
88 |
-
models = [
|
89 |
-
'AIML-TUDA/LlavaGuard-7B',
|
90 |
-
'AIML-TUDA/LlavaGuard-v1.1-7B-hf',
|
91 |
-
'AIML-TUDA/LlavaGuard-13B',
|
92 |
-
'AIML-TUDA/LlavaGuard-v1.1-13B-hf']
|
93 |
-
return models
|
94 |
-
|
95 |
-
|
96 |
get_window_url_params = """
|
97 |
function() {
|
98 |
const params = new URLSearchParams(window.location.search);
|
99 |
url_params = Object.fromEntries(params);
|
100 |
console.log(url_params);
|
101 |
return url_params;
|
102 |
-
|
103 |
"""
|
104 |
|
105 |
-
|
106 |
def load_demo(url_params, request: gr.Request):
|
107 |
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
108 |
-
|
|
|
109 |
dropdown_update = gr.Dropdown(visible=True)
|
110 |
if "model" in url_params:
|
111 |
model = url_params["model"]
|
112 |
if model in models:
|
113 |
dropdown_update = gr.Dropdown(value=model, visible=True)
|
|
|
114 |
|
115 |
state = default_conversation.copy()
|
116 |
return state, dropdown_update
|
117 |
|
118 |
-
|
119 |
def load_demo_refresh_model_list(request: gr.Request):
|
120 |
logger.info(f"load_demo. ip: {request.client.host}")
|
121 |
models = get_model_list()
|
@@ -126,7 +305,6 @@ def load_demo_refresh_model_list(request: gr.Request):
|
|
126 |
)
|
127 |
return state, dropdown_update
|
128 |
|
129 |
-
|
130 |
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
131 |
with open(get_conv_log_filename(), "a") as fout:
|
132 |
data = {
|
@@ -138,25 +316,21 @@ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
|
138 |
}
|
139 |
fout.write(json.dumps(data) + "\n")
|
140 |
|
141 |
-
|
142 |
def upvote_last_response(state, model_selector, request: gr.Request):
|
143 |
logger.info(f"upvote. ip: {request.client.host}")
|
144 |
vote_last_response(state, "upvote", model_selector, request)
|
145 |
return ("",) + (disable_btn,) * 3
|
146 |
|
147 |
-
|
148 |
def downvote_last_response(state, model_selector, request: gr.Request):
|
149 |
logger.info(f"downvote. ip: {request.client.host}")
|
150 |
vote_last_response(state, "downvote", model_selector, request)
|
151 |
return ("",) + (disable_btn,) * 3
|
152 |
|
153 |
-
|
154 |
def flag_last_response(state, model_selector, request: gr.Request):
|
155 |
logger.info(f"flag. ip: {request.client.host}")
|
156 |
vote_last_response(state, "flag", model_selector, request)
|
157 |
return ("",) + (disable_btn,) * 3
|
158 |
|
159 |
-
|
160 |
def regenerate(state, image_process_mode, request: gr.Request):
|
161 |
logger.info(f"regenerate. ip: {request.client.host}")
|
162 |
state.messages[-1][-1] = None
|
@@ -166,30 +340,20 @@ def regenerate(state, image_process_mode, request: gr.Request):
|
|
166 |
state.skip_next = False
|
167 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
168 |
|
169 |
-
|
170 |
def clear_history(request: gr.Request):
|
171 |
logger.info(f"clear_history. ip: {request.client.host}")
|
172 |
state = default_conversation.copy()
|
173 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
174 |
|
175 |
-
|
176 |
def add_text(state, text, image, image_process_mode, request: gr.Request):
|
177 |
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
178 |
if len(text) <= 0 or image is None:
|
179 |
state.skip_next = True
|
180 |
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
181 |
-
if args.moderate:
|
182 |
-
flagged = violates_moderation(text)
|
183 |
-
if flagged:
|
184 |
-
state.skip_next = True
|
185 |
-
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
|
186 |
-
no_change_btn,) * 5
|
187 |
|
188 |
text = wrap_taxonomy(text)
|
189 |
if image is not None:
|
190 |
-
text = text # Hard cut-off for images
|
191 |
if '<image>' not in text:
|
192 |
-
# text = '<Image><image></Image>' + text
|
193 |
text = text + '\n<image>'
|
194 |
text = (text, image, image_process_mode)
|
195 |
state = default_conversation.copy()
|
@@ -199,83 +363,50 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
|
|
199 |
state.skip_next = False
|
200 |
return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
|
201 |
|
202 |
-
|
203 |
def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
204 |
start_tstamp = time.time()
|
205 |
-
|
206 |
-
|
207 |
if state.skip_next:
|
208 |
# This generate call is skipped due to invalid inputs
|
209 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
210 |
return
|
211 |
|
212 |
-
|
213 |
-
# First round of conversation
|
214 |
-
if "llava" in model_name.lower():
|
215 |
-
if 'llama-2' in model_name.lower():
|
216 |
-
template_name = "llava_llama_2"
|
217 |
-
elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
|
218 |
-
if 'orca' in model_name.lower():
|
219 |
-
template_name = "mistral_orca"
|
220 |
-
elif 'hermes' in model_name.lower():
|
221 |
-
template_name = "chatml_direct"
|
222 |
-
else:
|
223 |
-
template_name = "mistral_instruct"
|
224 |
-
elif 'llava-v1.6-34b' in model_name.lower():
|
225 |
-
template_name = "chatml_direct"
|
226 |
-
elif "v1" in model_name.lower():
|
227 |
-
if 'mmtag' in model_name.lower():
|
228 |
-
template_name = "v1_mmtag"
|
229 |
-
elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
|
230 |
-
template_name = "v1_mmtag"
|
231 |
-
else:
|
232 |
-
template_name = "llava_v1"
|
233 |
-
elif "mpt" in model_name.lower():
|
234 |
-
template_name = "mpt"
|
235 |
-
else:
|
236 |
-
if 'mmtag' in model_name.lower():
|
237 |
-
template_name = "v0_mmtag"
|
238 |
-
elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
|
239 |
-
template_name = "v0_mmtag"
|
240 |
-
else:
|
241 |
-
template_name = "llava_v0"
|
242 |
-
elif "mpt" in model_name:
|
243 |
-
template_name = "mpt_text"
|
244 |
-
elif "llama-2" in model_name:
|
245 |
-
template_name = "llama_2"
|
246 |
-
else:
|
247 |
-
template_name = "vicuna_v1"
|
248 |
-
new_state = conv_templates[template_name].copy()
|
249 |
-
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
250 |
-
new_state.append_message(new_state.roles[1], None)
|
251 |
-
state = new_state
|
252 |
-
|
253 |
-
# Construct prompt
|
254 |
prompt = state.get_prompt()
|
255 |
-
|
256 |
all_images = state.get_images(return_pil=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
258 |
-
for image,
|
259 |
t = datetime.datetime.now()
|
260 |
-
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{
|
261 |
if not os.path.isfile(filename):
|
262 |
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
263 |
image.save(filename)
|
264 |
-
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
267 |
state.messages[-1][-1] = output
|
268 |
|
269 |
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
270 |
|
271 |
finish_tstamp = time.time()
|
272 |
-
logger.info(f"{
|
273 |
|
274 |
with open(get_conv_log_filename(), "a") as fout:
|
275 |
data = {
|
276 |
"tstamp": round(finish_tstamp, 4),
|
277 |
"type": "chat",
|
278 |
-
"model":
|
279 |
"start": round(start_tstamp, 4),
|
280 |
"finish": round(finish_tstamp, 4),
|
281 |
"state": state.dict(),
|
@@ -284,41 +415,38 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
|
|
284 |
}
|
285 |
fout.write(json.dumps(data) + "\n")
|
286 |
|
287 |
-
|
288 |
-
title_markdown =
|
289 |
# LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
|
290 |
[[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
|
291 |
[[Code](https://github.com/ml-research/LlavaGuard)]
|
292 |
[[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
|
293 |
[[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
|
294 |
[[LavaGuard](https://arxiv.org/abs/2406.05113)]
|
295 |
-
"""
|
296 |
|
297 |
-
tos_markdown =
|
298 |
### Terms of use
|
299 |
By using this service, users are required to agree to the following terms:
|
300 |
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
301 |
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
302 |
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
303 |
-
"""
|
304 |
|
305 |
-
learn_more_markdown =
|
306 |
### License
|
307 |
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
308 |
-
"""
|
309 |
|
310 |
block_css = """
|
311 |
-
|
312 |
#buttons button {
|
313 |
min-width: min(120px,100%);
|
314 |
}
|
315 |
-
|
316 |
"""
|
317 |
|
318 |
-
taxonomies = ["Default", "Modified w/ O1 non-violating", "Default message 3"]
|
319 |
-
|
320 |
-
|
321 |
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
|
|
|
|
322 |
with gr.Accordion("Safety Risk Taxonomy", open=False) as accordion:
|
323 |
textbox = gr.Textbox(
|
324 |
label="Safety Risk Taxonomy",
|
@@ -327,6 +455,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
|
327 |
container=True,
|
328 |
value=default_taxonomy,
|
329 |
lines=50)
|
|
|
330 |
with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
|
331 |
state = gr.State()
|
332 |
|
@@ -351,16 +480,17 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
|
351 |
|
352 |
if cur_dir is None:
|
353 |
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
354 |
gr.Examples(examples=[
|
355 |
-
[f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6)
|
356 |
], inputs=imagebox)
|
357 |
|
358 |
with gr.Accordion("Parameters", open=False) as parameter_row:
|
359 |
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
|
360 |
-
|
361 |
-
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P"
|
362 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
|
363 |
-
|
364 |
|
365 |
with gr.Column(scale=8):
|
366 |
chatbot = gr.Chatbot(
|
@@ -378,7 +508,6 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
|
378 |
upvote_btn = gr.Button(value="π Upvote", interactive=False)
|
379 |
downvote_btn = gr.Button(value="π Downvote", interactive=False)
|
380 |
flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
|
381 |
-
# stop_btn = gr.Button(value="βΉοΈ Stop Generation", interactive=False)
|
382 |
regenerate_btn = gr.Button(value="π Regenerate", interactive=False)
|
383 |
clear_btn = gr.Button(value="ποΈ Clear", interactive=False)
|
384 |
|
@@ -389,26 +518,30 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
|
389 |
|
390 |
# Register listeners
|
391 |
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
|
|
392 |
upvote_btn.click(
|
393 |
upvote_last_response,
|
394 |
[state, model_selector],
|
395 |
[textbox, upvote_btn, downvote_btn, flag_btn]
|
396 |
)
|
|
|
397 |
downvote_btn.click(
|
398 |
downvote_last_response,
|
399 |
[state, model_selector],
|
400 |
[textbox, upvote_btn, downvote_btn, flag_btn]
|
401 |
)
|
|
|
402 |
flag_btn.click(
|
403 |
flag_last_response,
|
404 |
[state, model_selector],
|
405 |
[textbox, upvote_btn, downvote_btn, flag_btn]
|
406 |
)
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
|
|
412 |
|
413 |
regenerate_btn.click(
|
414 |
regenerate,
|
@@ -451,22 +584,12 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
|
451 |
concurrency_limit=concurrency_count
|
452 |
)
|
453 |
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
)
|
461 |
-
elif args.model_list_mode == "reload":
|
462 |
-
demo.load(
|
463 |
-
load_demo_refresh_model_list,
|
464 |
-
None,
|
465 |
-
[state, model_selector],
|
466 |
-
queue=False
|
467 |
-
)
|
468 |
-
else:
|
469 |
-
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
470 |
|
471 |
return demo
|
472 |
|
@@ -475,51 +598,37 @@ if __name__ == "__main__":
|
|
475 |
parser = argparse.ArgumentParser()
|
476 |
parser.add_argument("--host", type=str, default="0.0.0.0")
|
477 |
parser.add_argument("--port", type=int)
|
478 |
-
parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
|
479 |
parser.add_argument("--concurrency-count", type=int, default=5)
|
480 |
-
parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
|
481 |
parser.add_argument("--share", action="store_true")
|
482 |
parser.add_argument("--moderate", action="store_true")
|
483 |
parser.add_argument("--embed", action="store_true")
|
484 |
args = parser.parse_args()
|
485 |
-
models = []
|
486 |
-
|
487 |
-
title_markdown += """
|
488 |
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
|
|
498 |
api_key = os.getenv("token")
|
499 |
-
|
500 |
-
|
501 |
-
models = get_model_list()
|
502 |
-
bits = int(os.getenv("bits", 16))
|
503 |
-
model = os.getenv("model", models[0])
|
504 |
-
available_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
|
505 |
-
model_path, model_name = model, model.split("/")[0]
|
506 |
if api_key:
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
model.config.tokenizer_model_max_length = 2048 * 2
|
519 |
-
|
520 |
-
exit_status = 0
|
521 |
try:
|
522 |
-
demo = build_demo(embed_mode=
|
523 |
demo.queue(
|
524 |
status_update_rate=10,
|
525 |
api_open=False
|
@@ -528,9 +637,6 @@ Set the environment variable `model` to change the model:
|
|
528 |
server_port=args.port,
|
529 |
share=args.share
|
530 |
)
|
531 |
-
|
532 |
except Exception as e:
|
533 |
-
|
534 |
-
|
535 |
-
finally:
|
536 |
-
sys.exit(exit_status)
|
|
|
2 |
import datetime
|
3 |
import hashlib
|
4 |
import json
|
5 |
+
import logging
|
6 |
import os
|
7 |
import sys
|
8 |
import time
|
|
|
9 |
|
10 |
import gradio as gr
|
|
|
11 |
import torch
|
12 |
+
from PIL import Image
|
13 |
+
from transformers import (
|
14 |
+
AutoModelForCausalLM,
|
15 |
+
AutoProcessor,
|
16 |
+
AutoTokenizer,
|
17 |
+
Qwen2_5_VLForConditionalGeneration
|
18 |
+
)
|
19 |
+
|
20 |
+
from taxonomy import policy_v1
|
21 |
+
|
22 |
+
# Set up logging
|
23 |
+
logging.basicConfig(
|
24 |
+
level=logging.INFO,
|
25 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
26 |
+
handlers=[
|
27 |
+
logging.FileHandler("gradio_web_server.log"),
|
28 |
+
logging.StreamHandler()
|
29 |
+
]
|
30 |
+
)
|
31 |
+
logger = logging.getLogger("gradio_web_server")
|
32 |
+
|
33 |
+
# Constants
|
34 |
+
LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
|
35 |
+
os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
|
36 |
+
|
37 |
+
default_taxonomy = policy_v1
|
38 |
+
|
39 |
+
class Conversation:
|
40 |
+
def __init__(self):
|
41 |
+
self.messages = []
|
42 |
+
self.roles = ["user", "assistant"]
|
43 |
+
self.offset = 0
|
44 |
+
self.skip_next = False
|
45 |
+
|
46 |
+
def append_message(self, role, message):
|
47 |
+
self.messages.append([role, message])
|
48 |
+
|
49 |
+
def to_gradio_chatbot(self):
|
50 |
+
ret = []
|
51 |
+
for role, message in self.messages:
|
52 |
+
if message is None:
|
53 |
+
continue
|
54 |
+
if role == self.roles[0]:
|
55 |
+
if isinstance(message, tuple):
|
56 |
+
ret.append([self.render_user_message(message[0]), None])
|
57 |
+
else:
|
58 |
+
ret.append([self.render_user_message(message), None])
|
59 |
+
elif role == self.roles[1]:
|
60 |
+
if ret[-1][1] is None:
|
61 |
+
ret[-1][1] = message
|
62 |
+
else:
|
63 |
+
ret.append([None, message])
|
64 |
+
else:
|
65 |
+
raise ValueError(f"Invalid role: {role}")
|
66 |
+
return ret
|
67 |
+
|
68 |
+
def render_user_message(self, message):
|
69 |
+
if "<image>" in message:
|
70 |
+
return message.replace("<image>", "")
|
71 |
+
return message
|
72 |
+
|
73 |
+
def dict(self):
|
74 |
+
return {
|
75 |
+
"messages": self.messages,
|
76 |
+
"roles": self.roles,
|
77 |
+
"offset": self.offset,
|
78 |
+
"skip_next": self.skip_next,
|
79 |
+
}
|
80 |
+
|
81 |
+
def get_prompt(self):
|
82 |
+
prompt = ""
|
83 |
+
for role, message in self.messages:
|
84 |
+
if message is None:
|
85 |
+
continue
|
86 |
+
if isinstance(message, tuple):
|
87 |
+
message = message[0]
|
88 |
+
if role == self.roles[0]:
|
89 |
+
prompt += f"USER: {message}\n"
|
90 |
+
else:
|
91 |
+
prompt += f"ASSISTANT: {message}\n"
|
92 |
+
return prompt + "ASSISTANT: "
|
93 |
+
|
94 |
+
def get_images(self, return_pil=False):
|
95 |
+
images = []
|
96 |
+
for role, message in self.messages:
|
97 |
+
if isinstance(message, tuple) and len(message) > 1:
|
98 |
+
if isinstance(message[1], Image.Image):
|
99 |
+
images.append(message[1] if return_pil else message[1])
|
100 |
+
return images
|
101 |
+
|
102 |
+
def copy(self):
|
103 |
+
new_conv = Conversation()
|
104 |
+
new_conv.messages = self.messages.copy()
|
105 |
+
new_conv.roles = self.roles.copy()
|
106 |
+
new_conv.offset = self.offset
|
107 |
+
new_conv.skip_next = self.skip_next
|
108 |
+
return new_conv
|
109 |
+
|
110 |
+
default_conversation = Conversation()
|
111 |
+
|
112 |
+
# Model and processor storage
|
113 |
+
tokenizer = None
|
114 |
+
model = None
|
115 |
+
processor = None
|
116 |
+
context_len = 2048
|
117 |
+
|
118 |
+
# Helper functions
|
119 |
def clear_conv(conv):
|
120 |
conv.messages = []
|
121 |
return conv
|
122 |
|
123 |
+
def wrap_taxonomy(text):
|
124 |
+
"""Wraps user input with taxonomy if not already present"""
|
125 |
+
if policy_v1 not in text:
|
126 |
+
return policy_v1 + "\n\n" + text
|
127 |
+
return text
|
128 |
|
129 |
+
# UI component states
|
|
|
|
|
|
|
130 |
no_change_btn = gr.Button()
|
131 |
enable_btn = gr.Button(interactive=True)
|
132 |
disable_btn = gr.Button(interactive=False)
|
133 |
|
134 |
+
# Model loading function
|
135 |
+
def load_model(model_path):
|
136 |
+
global tokenizer, model, processor, context_len
|
137 |
+
|
138 |
+
logger.info(f"Loading model: {model_path}")
|
139 |
+
|
140 |
+
try:
|
141 |
+
# Check if it's a Qwen model
|
142 |
+
if "qwenguard" in model_path.lower():
|
143 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
144 |
+
model_path,
|
145 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
146 |
+
device_map="auto" if torch.cuda.is_available() else None
|
147 |
+
)
|
148 |
+
processor = AutoProcessor.from_pretrained(model_path)
|
149 |
+
tokenizer = processor.tokenizer
|
150 |
+
|
151 |
+
# Otherwise assume it's a LlavaGuard model
|
152 |
+
else:
|
153 |
+
model = AutoModelForCausalLM.from_pretrained(
|
154 |
+
model_path,
|
155 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
156 |
+
device_map="auto" if torch.cuda.is_available() else None,
|
157 |
+
trust_remote_code=True
|
158 |
+
)
|
159 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
160 |
+
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
161 |
+
|
162 |
+
context_len = getattr(model.config, "max_position_embeddings", 2048)
|
163 |
+
logger.info(f"Model {model_path} loaded successfully")
|
164 |
+
return True
|
165 |
+
|
166 |
+
except Exception as e:
|
167 |
+
logger.error(f"Error loading model {model_path}: {str(e)}")
|
168 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
def get_model_list():
|
171 |
+
models = [
|
172 |
+
'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
|
173 |
+
'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
|
174 |
+
'AIML-TUDA/QwenGuard-v1.2-7B',
|
175 |
+
'AIML-TUDA/QwenGuard-v1.2-3B'
|
176 |
+
]
|
177 |
+
return models
|
178 |
|
179 |
def get_conv_log_filename():
|
180 |
t = datetime.datetime.now()
|
181 |
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
182 |
+
os.makedirs(os.path.dirname(name), exist_ok=True)
|
183 |
return name
|
184 |
|
185 |
+
# Inference function
|
186 |
+
def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
|
187 |
+
global model, tokenizer, processor
|
188 |
+
|
189 |
+
if model is None or processor is None:
|
190 |
+
return "Model not loaded. Please select a model first."
|
191 |
+
|
192 |
+
try:
|
193 |
+
# Check if it's a Qwen model
|
194 |
+
if isinstance(model, Qwen2_5_VLForConditionalGeneration):
|
195 |
+
# Format for Qwen models
|
196 |
+
messages = [
|
197 |
+
{
|
198 |
+
"role": "user",
|
199 |
+
"content": [
|
200 |
+
{"type": "image", "image": image},
|
201 |
+
{"type": "text", "text": prompt}
|
202 |
+
]
|
203 |
+
}
|
204 |
+
]
|
205 |
+
|
206 |
+
# Process input
|
207 |
+
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
208 |
+
inputs = processor(
|
209 |
+
text=[text],
|
210 |
+
images=[image],
|
211 |
+
padding=True,
|
212 |
+
return_tensors="pt"
|
213 |
+
)
|
214 |
+
|
215 |
+
# Move to GPU if available
|
216 |
+
if torch.cuda.is_available():
|
217 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
218 |
+
|
219 |
+
# Generate
|
220 |
+
with torch.no_grad():
|
221 |
+
generated_ids = model.generate(
|
222 |
+
**inputs,
|
223 |
+
do_sample=temperature > 0,
|
224 |
+
temperature=temperature,
|
225 |
+
top_p=top_p,
|
226 |
+
max_new_tokens=max_tokens,
|
227 |
+
)
|
228 |
+
|
229 |
+
# Decode
|
230 |
+
generated_ids_trimmed = [
|
231 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
232 |
+
]
|
233 |
+
response = processor.batch_decode(
|
234 |
+
generated_ids_trimmed,
|
235 |
+
skip_special_tokens=True,
|
236 |
+
clean_up_tokenization_spaces=False
|
237 |
+
)[0]
|
238 |
+
|
239 |
+
# Otherwise assume it's a LlavaGuard model
|
240 |
+
else:
|
241 |
+
# Process input for LlavaGuard models
|
242 |
+
inputs = processor(
|
243 |
+
prompt,
|
244 |
+
images=image,
|
245 |
+
return_tensors="pt"
|
246 |
+
)
|
247 |
+
|
248 |
+
# Move to GPU if available
|
249 |
+
if torch.cuda.is_available():
|
250 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
251 |
+
|
252 |
+
# Generate
|
253 |
+
with torch.no_grad():
|
254 |
+
generated_ids = model.generate(
|
255 |
+
**inputs,
|
256 |
+
do_sample=temperature > 0,
|
257 |
+
temperature=temperature,
|
258 |
+
top_p=top_p,
|
259 |
+
max_new_tokens=max_tokens,
|
260 |
+
)
|
261 |
+
|
262 |
+
# Decode
|
263 |
+
response = tokenizer.batch_decode(
|
264 |
+
generated_ids[:, inputs.input_ids.shape[1]:],
|
265 |
+
skip_special_tokens=True
|
266 |
+
)[0]
|
267 |
+
|
268 |
+
return response.strip()
|
269 |
+
|
270 |
+
except Exception as e:
|
271 |
+
logger.error(f"Error during inference: {str(e)}")
|
272 |
+
return f"Error during inference: {str(e)}"
|
273 |
|
274 |
+
# Gradio UI functions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
get_window_url_params = """
|
276 |
function() {
|
277 |
const params = new URLSearchParams(window.location.search);
|
278 |
url_params = Object.fromEntries(params);
|
279 |
console.log(url_params);
|
280 |
return url_params;
|
281 |
+
}
|
282 |
"""
|
283 |
|
|
|
284 |
def load_demo(url_params, request: gr.Request):
|
285 |
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
286 |
+
models = get_model_list()
|
287 |
+
|
288 |
dropdown_update = gr.Dropdown(visible=True)
|
289 |
if "model" in url_params:
|
290 |
model = url_params["model"]
|
291 |
if model in models:
|
292 |
dropdown_update = gr.Dropdown(value=model, visible=True)
|
293 |
+
load_model(model)
|
294 |
|
295 |
state = default_conversation.copy()
|
296 |
return state, dropdown_update
|
297 |
|
|
|
298 |
def load_demo_refresh_model_list(request: gr.Request):
|
299 |
logger.info(f"load_demo. ip: {request.client.host}")
|
300 |
models = get_model_list()
|
|
|
305 |
)
|
306 |
return state, dropdown_update
|
307 |
|
|
|
308 |
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
309 |
with open(get_conv_log_filename(), "a") as fout:
|
310 |
data = {
|
|
|
316 |
}
|
317 |
fout.write(json.dumps(data) + "\n")
|
318 |
|
|
|
319 |
def upvote_last_response(state, model_selector, request: gr.Request):
|
320 |
logger.info(f"upvote. ip: {request.client.host}")
|
321 |
vote_last_response(state, "upvote", model_selector, request)
|
322 |
return ("",) + (disable_btn,) * 3
|
323 |
|
|
|
324 |
def downvote_last_response(state, model_selector, request: gr.Request):
|
325 |
logger.info(f"downvote. ip: {request.client.host}")
|
326 |
vote_last_response(state, "downvote", model_selector, request)
|
327 |
return ("",) + (disable_btn,) * 3
|
328 |
|
|
|
329 |
def flag_last_response(state, model_selector, request: gr.Request):
|
330 |
logger.info(f"flag. ip: {request.client.host}")
|
331 |
vote_last_response(state, "flag", model_selector, request)
|
332 |
return ("",) + (disable_btn,) * 3
|
333 |
|
|
|
334 |
def regenerate(state, image_process_mode, request: gr.Request):
|
335 |
logger.info(f"regenerate. ip: {request.client.host}")
|
336 |
state.messages[-1][-1] = None
|
|
|
340 |
state.skip_next = False
|
341 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
342 |
|
|
|
343 |
def clear_history(request: gr.Request):
|
344 |
logger.info(f"clear_history. ip: {request.client.host}")
|
345 |
state = default_conversation.copy()
|
346 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
347 |
|
|
|
348 |
def add_text(state, text, image, image_process_mode, request: gr.Request):
|
349 |
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
350 |
if len(text) <= 0 or image is None:
|
351 |
state.skip_next = True
|
352 |
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
|
354 |
text = wrap_taxonomy(text)
|
355 |
if image is not None:
|
|
|
356 |
if '<image>' not in text:
|
|
|
357 |
text = text + '\n<image>'
|
358 |
text = (text, image, image_process_mode)
|
359 |
state = default_conversation.copy()
|
|
|
363 |
state.skip_next = False
|
364 |
return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
|
365 |
|
|
|
366 |
def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
367 |
start_tstamp = time.time()
|
368 |
+
|
|
|
369 |
if state.skip_next:
|
370 |
# This generate call is skipped due to invalid inputs
|
371 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
372 |
return
|
373 |
|
374 |
+
# Get the prompt and images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
prompt = state.get_prompt()
|
|
|
376 |
all_images = state.get_images(return_pil=True)
|
377 |
+
|
378 |
+
if not all_images:
|
379 |
+
state.messages[-1][-1] = "Error: No image provided"
|
380 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
381 |
+
return
|
382 |
+
|
383 |
+
# Save image for logging
|
384 |
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
385 |
+
for image, hash_val in zip(all_images, all_image_hash):
|
386 |
t = datetime.datetime.now()
|
387 |
+
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash_val}.jpg")
|
388 |
if not os.path.isfile(filename):
|
389 |
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
390 |
image.save(filename)
|
391 |
+
|
392 |
+
# Load model if needed
|
393 |
+
if model is None or model_selector != getattr(model, "_name_or_path", ""):
|
394 |
+
load_model(model_selector)
|
395 |
+
|
396 |
+
# Run inference
|
397 |
+
output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
|
398 |
state.messages[-1][-1] = output
|
399 |
|
400 |
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
401 |
|
402 |
finish_tstamp = time.time()
|
403 |
+
logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")
|
404 |
|
405 |
with open(get_conv_log_filename(), "a") as fout:
|
406 |
data = {
|
407 |
"tstamp": round(finish_tstamp, 4),
|
408 |
"type": "chat",
|
409 |
+
"model": model_selector,
|
410 |
"start": round(start_tstamp, 4),
|
411 |
"finish": round(finish_tstamp, 4),
|
412 |
"state": state.dict(),
|
|
|
415 |
}
|
416 |
fout.write(json.dumps(data) + "\n")
|
417 |
|
418 |
+
# UI Components
|
419 |
+
title_markdown = """
|
420 |
# LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
|
421 |
[[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
|
422 |
[[Code](https://github.com/ml-research/LlavaGuard)]
|
423 |
[[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
|
424 |
[[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
|
425 |
[[LavaGuard](https://arxiv.org/abs/2406.05113)]
|
426 |
+
"""
|
427 |
|
428 |
+
tos_markdown = """
|
429 |
### Terms of use
|
430 |
By using this service, users are required to agree to the following terms:
|
431 |
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
432 |
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
433 |
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
434 |
+
"""
|
435 |
|
436 |
+
learn_more_markdown = """
|
437 |
### License
|
438 |
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
439 |
+
"""
|
440 |
|
441 |
block_css = """
|
|
|
442 |
#buttons button {
|
443 |
min-width: min(120px,100%);
|
444 |
}
|
|
|
445 |
"""
|
446 |
|
|
|
|
|
|
|
447 |
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
448 |
+
models = get_model_list()
|
449 |
+
|
450 |
with gr.Accordion("Safety Risk Taxonomy", open=False) as accordion:
|
451 |
textbox = gr.Textbox(
|
452 |
label="Safety Risk Taxonomy",
|
|
|
455 |
container=True,
|
456 |
value=default_taxonomy,
|
457 |
lines=50)
|
458 |
+
|
459 |
with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
|
460 |
state = gr.State()
|
461 |
|
|
|
480 |
|
481 |
if cur_dir is None:
|
482 |
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
483 |
+
|
484 |
gr.Examples(examples=[
|
485 |
+
[f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if os.path.exists(f"{cur_dir}/examples/image{i}.png")
|
486 |
], inputs=imagebox)
|
487 |
|
488 |
with gr.Accordion("Parameters", open=False) as parameter_row:
|
489 |
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
|
490 |
+
label="Temperature")
|
491 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P")
|
492 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
|
493 |
+
label="Max output tokens")
|
494 |
|
495 |
with gr.Column(scale=8):
|
496 |
chatbot = gr.Chatbot(
|
|
|
508 |
upvote_btn = gr.Button(value="π Upvote", interactive=False)
|
509 |
downvote_btn = gr.Button(value="π Downvote", interactive=False)
|
510 |
flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
|
|
|
511 |
regenerate_btn = gr.Button(value="π Regenerate", interactive=False)
|
512 |
clear_btn = gr.Button(value="ποΈ Clear", interactive=False)
|
513 |
|
|
|
518 |
|
519 |
# Register listeners
|
520 |
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
521 |
+
|
522 |
upvote_btn.click(
|
523 |
upvote_last_response,
|
524 |
[state, model_selector],
|
525 |
[textbox, upvote_btn, downvote_btn, flag_btn]
|
526 |
)
|
527 |
+
|
528 |
downvote_btn.click(
|
529 |
downvote_last_response,
|
530 |
[state, model_selector],
|
531 |
[textbox, upvote_btn, downvote_btn, flag_btn]
|
532 |
)
|
533 |
+
|
534 |
flag_btn.click(
|
535 |
flag_last_response,
|
536 |
[state, model_selector],
|
537 |
[textbox, upvote_btn, downvote_btn, flag_btn]
|
538 |
)
|
539 |
+
|
540 |
+
model_selector.change(
|
541 |
+
load_model,
|
542 |
+
[model_selector],
|
543 |
+
None
|
544 |
+
)
|
545 |
|
546 |
regenerate_btn.click(
|
547 |
regenerate,
|
|
|
584 |
concurrency_limit=concurrency_count
|
585 |
)
|
586 |
|
587 |
+
demo.load(
|
588 |
+
load_demo_refresh_model_list,
|
589 |
+
None,
|
590 |
+
[state, model_selector],
|
591 |
+
queue=False
|
592 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
|
594 |
return demo
|
595 |
|
|
|
598 |
parser = argparse.ArgumentParser()
|
599 |
parser.add_argument("--host", type=str, default="0.0.0.0")
|
600 |
parser.add_argument("--port", type=int)
|
|
|
601 |
parser.add_argument("--concurrency-count", type=int, default=5)
|
|
|
602 |
parser.add_argument("--share", action="store_true")
|
603 |
parser.add_argument("--moderate", action="store_true")
|
604 |
parser.add_argument("--embed", action="store_true")
|
605 |
args = parser.parse_args()
|
|
|
|
|
|
|
606 |
|
607 |
+
# Create log directory if it doesn't exist
|
608 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
609 |
+
|
610 |
+
# GPU Check
|
611 |
+
if torch.cuda.is_available():
|
612 |
+
logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
|
613 |
+
else:
|
614 |
+
logger.warning("CUDA not available! Models will run on CPU which may be very slow.")
|
615 |
+
|
616 |
+
# Hugging Face token handling
|
617 |
api_key = os.getenv("token")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
618 |
if api_key:
|
619 |
+
from huggingface_hub import login
|
620 |
+
login(token=api_key)
|
621 |
+
logger.info("Logged in to Hugging Face Hub")
|
622 |
+
|
623 |
+
# Load initial model
|
624 |
+
models = get_model_list()
|
625 |
+
model_path = os.getenv("model", models[0])
|
626 |
+
logger.info(f"Initial model selected: {model_path}")
|
627 |
+
load_model(model_path)
|
628 |
+
|
629 |
+
# Launch Gradio app
|
|
|
|
|
|
|
630 |
try:
|
631 |
+
demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
|
632 |
demo.queue(
|
633 |
status_update_rate=10,
|
634 |
api_open=False
|
|
|
637 |
server_port=args.port,
|
638 |
share=args.share
|
639 |
)
|
|
|
640 |
except Exception as e:
|
641 |
+
logger.error(f"Error launching demo: {e}")
|
642 |
+
sys.exit(1)
|
|
|
|