LukasHug committed on
Commit ae2dd52 · verified · 1 Parent(s): f43ad4b

Update app.py

Files changed (1)
  1. app.py +325 -219

app.py CHANGED
@@ -2,120 +2,299 @@ import argparse
 import datetime
 import hashlib
 import json
 import os
 import sys
 import time
-import warnings
 
 import gradio as gr
-import spaces
 import torch
-
-from builder import load_pretrained_model
-from llava.constants import IMAGE_TOKEN_INDEX
-from llava.constants import LOGDIR
-from llava.conversation import (default_conversation, conv_templates)
-from llava.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token
-from llava.utils import (build_logger, violates_moderation, moderation_msg)
-from taxonomy import wrap_taxonomy, default_taxonomy
-
-
 def clear_conv(conv):
     conv.messages = []
     return conv
 
 
-logger = build_logger("gradio_web_server", "gradio_web_server.log")
-
-headers = {"User-Agent": "LLaVA Client"}
-
 no_change_btn = gr.Button()
 enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)
 
-priority = {
-    "LlavaGuard-7B": "aaaaaaa",
-    "LlavaGuard-13B": "aaaaaab",
-    "LlavaGuard-34B": "aaaaaac",
-}
-
-
-@spaces.GPU
-def run_llava(prompt, pil_image, temperature, top_p, max_new_tokens):
-    image_size = pil_image.size
-    image_tensor = image_processor.preprocess(pil_image, return_tensors='pt')['pixel_values'].half().cuda()
-    # image_tensor = image_tensor.to(model.device, dtype=torch.float16)
-    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-    input_ids = input_ids.unsqueeze(0).cuda()
-    with torch.inference_mode():
-        output_ids = model.generate(
-            input_ids,
-            images=image_tensor,
-            image_sizes=[image_size],
-            do_sample=True,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=50,
-            num_beams=2,
-            max_new_tokens=max_new_tokens,
-            use_cache=True,
-            stopping_criteria=[KeywordsStoppingCriteria(['}'], tokenizer, input_ids)]
-        )
-    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-
-    return outputs[0].strip()
-
-
-def load_selected_model(model_path):
-    model_name = model_path.split("/")[-1]
-    global tokenizer, model, image_processor, context_len
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
-        for warning in w:
-            if "vision" not in str(warning.message).lower():
-                print(warning.message)
-    model.config.tokenizer_model_max_length = 2048 * 2
 
 
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
     return name
 
 
-def get_model_list():
-    models = [
-        'AIML-TUDA/LlavaGuard-7B',
-        'AIML-TUDA/LlavaGuard-v1.1-7B-hf',
-        'AIML-TUDA/LlavaGuard-13B',
-        'AIML-TUDA/LlavaGuard-v1.1-13B-hf']
-    return models
-
-
 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
     url_params = Object.fromEntries(params);
     console.log(url_params);
     return url_params;
-    }
 """
 
-
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-
     dropdown_update = gr.Dropdown(visible=True)
     if "model" in url_params:
         model = url_params["model"]
         if model in models:
             dropdown_update = gr.Dropdown(value=model, visible=True)
 
     state = default_conversation.copy()
     return state, dropdown_update
 
-
 def load_demo_refresh_model_list(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
     models = get_model_list()
@@ -126,7 +305,6 @@ def load_demo_refresh_model_list(request: gr.Request):
     )
     return state, dropdown_update
 
-
 def vote_last_response(state, vote_type, model_selector, request: gr.Request):
     with open(get_conv_log_filename(), "a") as fout:
         data = {
@@ -138,25 +316,21 @@ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
         }
         fout.write(json.dumps(data) + "\n")
 
-
 def upvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"upvote. ip: {request.client.host}")
     vote_last_response(state, "upvote", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
-
 def downvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"downvote. ip: {request.client.host}")
     vote_last_response(state, "downvote", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
-
 def flag_last_response(state, model_selector, request: gr.Request):
     logger.info(f"flag. ip: {request.client.host}")
     vote_last_response(state, "flag", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
-
 def regenerate(state, image_process_mode, request: gr.Request):
     logger.info(f"regenerate. ip: {request.client.host}")
     state.messages[-1][-1] = None
@@ -166,30 +340,20 @@ def regenerate(state, image_process_mode, request: gr.Request):
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
-
 def clear_history(request: gr.Request):
     logger.info(f"clear_history. ip: {request.client.host}")
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
-
 def add_text(state, text, image, image_process_mode, request: gr.Request):
     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0 or image is None:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
-    if args.moderate:
-        flagged = violates_moderation(text)
-        if flagged:
-            state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
-                no_change_btn,) * 5
 
     text = wrap_taxonomy(text)
     if image is not None:
-        text = text  # Hard cut-off for images
         if '<image>' not in text:
-            # text = '<Image><image></Image>' + text
             text = text + '\n<image>'
         text = (text, image, image_process_mode)
     state = default_conversation.copy()
@@ -199,83 +363,50 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
 
-
 def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
-    model_name = model_selector
-
     if state.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
-    if len(state.messages) == state.offset + 2:
-        # First round of conversation
-        if "llava" in model_name.lower():
-            if 'llama-2' in model_name.lower():
-                template_name = "llava_llama_2"
-            elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
-                if 'orca' in model_name.lower():
-                    template_name = "mistral_orca"
-                elif 'hermes' in model_name.lower():
-                    template_name = "chatml_direct"
-                else:
-                    template_name = "mistral_instruct"
-            elif 'llava-v1.6-34b' in model_name.lower():
-                template_name = "chatml_direct"
-            elif "v1" in model_name.lower():
-                if 'mmtag' in model_name.lower():
-                    template_name = "v1_mmtag"
-                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
-                    template_name = "v1_mmtag"
-                else:
-                    template_name = "llava_v1"
-            elif "mpt" in model_name.lower():
-                template_name = "mpt"
-            else:
-                if 'mmtag' in model_name.lower():
-                    template_name = "v0_mmtag"
-                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
-                    template_name = "v0_mmtag"
-                else:
-                    template_name = "llava_v0"
-        elif "mpt" in model_name:
-            template_name = "mpt_text"
-        elif "llama-2" in model_name:
-            template_name = "llama_2"
-        else:
-            template_name = "vicuna_v1"
-        new_state = conv_templates[template_name].copy()
-        new_state.append_message(new_state.roles[0], state.messages[-2][1])
-        new_state.append_message(new_state.roles[1], None)
-        state = new_state
-
-    # Construct prompt
     prompt = state.get_prompt()
-
     all_images = state.get_images(return_pil=True)
     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
-    for image, hash in zip(all_images, all_image_hash):
         t = datetime.datetime.now()
-        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
         if not os.path.isfile(filename):
             os.makedirs(os.path.dirname(filename), exist_ok=True)
             image.save(filename)
-
-    output = run_llava(prompt, all_images[0], temperature, top_p, max_new_tokens)
-
     state.messages[-1][-1] = output
 
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
     finish_tstamp = time.time()
-    logger.info(f"{output}")
 
     with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
-            "model": model_name,
             "start": round(start_tstamp, 4),
             "finish": round(finish_tstamp, 4),
             "state": state.dict(),
@@ -284,41 +415,38 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
         }
         fout.write(json.dumps(data) + "\n")
 
-
-title_markdown = ("""
 # LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
 [[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
 [[Code](https://github.com/ml-research/LlavaGuard)]
 [[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
 [[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
 [[LavaGuard](https://arxiv.org/abs/2406.05113)]
-""")
 
-tos_markdown = ("""
 ### Terms of use
 By using this service, users are required to agree to the following terms:
 The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
 Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
 For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
-""")
 
-learn_more_markdown = ("""
 ### License
 The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
-""")
 
 block_css = """
-
 #buttons button {
     min-width: min(120px,100%);
 }
-
 """
 
-taxonomies = ["Default", "Modified w/ O1 non-violating", "Default message 3"]
-
-
 def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
     with gr.Accordion("Safety Risk Taxonomy", open=False) as accordion:
         textbox = gr.Textbox(
             label="Safety Risk Taxonomy",
@@ -327,6 +455,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             container=True,
             value=default_taxonomy,
             lines=50)
     with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
@@ -351,16 +480,17 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
                 if cur_dir is None:
                     cur_dir = os.path.dirname(os.path.abspath(__file__))
                 gr.Examples(examples=[
-                    [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6)
                 ], inputs=imagebox)
 
                 with gr.Accordion("Parameters", open=False) as parameter_row:
                     temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
-                                            label="Temperature", )
-                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P", )
                     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
-                                                  label="Max output tokens", )
 
             with gr.Column(scale=8):
                 chatbot = gr.Chatbot(
@@ -378,7 +508,6 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
-                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
 
@@ -389,26 +518,30 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         # Register listeners
         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
         upvote_btn.click(
             upvote_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
         downvote_btn.click(
             downvote_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
         flag_btn.click(
             flag_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
-
-        # model_selector.change(
-        #     load_selected_model,
-        #     [model_selector],
-        # )
 
         regenerate_btn.click(
             regenerate,
@@ -451,22 +584,12 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             concurrency_limit=concurrency_count
         )
 
-        if args.model_list_mode == "once":
-            demo.load(
-                load_demo,
-                [url_params],
-                [state, model_selector],
-                js=get_window_url_params
-            )
-        elif args.model_list_mode == "reload":
-            demo.load(
-                load_demo_refresh_model_list,
-                None,
-                [state, model_selector],
-                queue=False
-            )
-        else:
-            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
 
     return demo
 
@@ -475,51 +598,37 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int)
-    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
     parser.add_argument("--concurrency-count", type=int, default=5)
-    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
     parser.add_argument("--moderate", action="store_true")
     parser.add_argument("--embed", action="store_true")
     args = parser.parse_args()
-    models = []
-
-    title_markdown += """
-
-    ONLY WORKS WITH GPU!
-
-    Set the environment variable `model` to change the model:
-    ['AIML-TUDA/LlavaGuard-7B'](https://huggingface.co/AIML-TUDA/LlavaGuard-7B),
-    ['AIML-TUDA/LlavaGuard-13B'](https://huggingface.co/AIML-TUDA/LlavaGuard-13B),
-    ['AIML-TUDA/LlavaGuard-34B'](https://huggingface.co/AIML-TUDA/LlavaGuard-34B),
-    """
-    print(f"args: {args}")
-    concurrency_count = int(os.getenv("concurrency_count", 5))
     api_key = os.getenv("token")
-
-    models = get_model_list()
-    bits = int(os.getenv("bits", 16))
-    model = os.getenv("model", models[0])
-    available_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
-    model_path, model_name = model, model.split("/")[0]
     if api_key:
-        cmd = f"huggingface-cli login --token {api_key} --add-to-git-credential"
-        os.system(cmd)
-    else:
-        if '/workspace' not in sys.path:
-            sys.path.append('/workspace')
-        from llavaguard.hf_utils import set_up_env_and_token
-        api_key = set_up_env_and_token(read=True, write=False)
-        model_path = '/common-repos/LlavaGuard/models/LlavaGuard-v1.1-7b-full/smid_and_crawled_v2_with_augmented_policies/json-v16/llava'
-
-    print(f"Loading model {model_path}")
-    load_selected_model(model_path)
-    model.config.tokenizer_model_max_length = 2048 * 2
-
-    exit_status = 0
     try:
-        demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
         demo.queue(
             status_update_rate=10,
             api_open=False
@@ -528,9 +637,6 @@ Set the environment variable `model` to change the model:
             server_port=args.port,
             share=args.share
         )
-
     except Exception as e:
-        print(e)
-        exit_status = 1
-    finally:
-        sys.exit(exit_status)

 import datetime
 import hashlib
 import json
+import logging
 import os
 import sys
 import time
 
 import gradio as gr
 import torch
+from PIL import Image
+from transformers import (
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+    Qwen2_5_VLForConditionalGeneration
+)
+
+from taxonomy import policy_v1
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler("gradio_web_server.log"),
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger("gradio_web_server")
+
+# Constants
+LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
+os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
+
+default_taxonomy = policy_v1
+
+class Conversation:
+    def __init__(self):
+        self.messages = []
+        self.roles = ["user", "assistant"]
+        self.offset = 0
+        self.skip_next = False
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for role, message in self.messages:
+            if message is None:
+                continue
+            if role == self.roles[0]:
+                if isinstance(message, tuple):
+                    ret.append([self.render_user_message(message[0]), None])
+                else:
+                    ret.append([self.render_user_message(message), None])
+            elif role == self.roles[1]:
+                if ret[-1][1] is None:
+                    ret[-1][1] = message
+                else:
+                    ret.append([None, message])
+            else:
+                raise ValueError(f"Invalid role: {role}")
+        return ret
+
+    def render_user_message(self, message):
+        if "<image>" in message:
+            return message.replace("<image>", "")
+        return message
+
+    def dict(self):
+        return {
+            "messages": self.messages,
+            "roles": self.roles,
+            "offset": self.offset,
+            "skip_next": self.skip_next,
+        }
+
+    def get_prompt(self):
+        prompt = ""
+        for role, message in self.messages:
+            if message is None:
+                continue
+            if isinstance(message, tuple):
+                message = message[0]
+            if role == self.roles[0]:
+                prompt += f"USER: {message}\n"
+            else:
+                prompt += f"ASSISTANT: {message}\n"
+        return prompt + "ASSISTANT: "
+
+    def get_images(self, return_pil=False):
+        images = []
+        for role, message in self.messages:
+            if isinstance(message, tuple) and len(message) > 1:
+                if isinstance(message[1], Image.Image):
+                    images.append(message[1] if return_pil else message[1])
+        return images
+
+    def copy(self):
+        new_conv = Conversation()
+        new_conv.messages = self.messages.copy()
+        new_conv.roles = self.roles.copy()
+        new_conv.offset = self.offset
+        new_conv.skip_next = self.skip_next
+        return new_conv
+
+default_conversation = Conversation()
+
+# Model and processor storage
+tokenizer = None
+model = None
+processor = None
+context_len = 2048
+
+# Helper functions
 def clear_conv(conv):
     conv.messages = []
     return conv
 
+def wrap_taxonomy(text):
+    """Wraps user input with taxonomy if not already present"""
+    if policy_v1 not in text:
+        return policy_v1 + "\n\n" + text
+    return text
 
+# UI component states
 no_change_btn = gr.Button()
 enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)
 
+ # Model loading function
135
+ def load_model(model_path):
136
+ global tokenizer, model, processor, context_len
137
+
138
+ logger.info(f"Loading model: {model_path}")
139
+
140
+ try:
141
+ # Check if it's a Qwen model
142
+ if "qwenguard" in model_path.lower():
143
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
144
+ model_path,
145
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
146
+ device_map="auto" if torch.cuda.is_available() else None
147
+ )
148
+ processor = AutoProcessor.from_pretrained(model_path)
149
+ tokenizer = processor.tokenizer
150
+
151
+ # Otherwise assume it's a LlavaGuard model
152
+ else:
153
+ model = AutoModelForCausalLM.from_pretrained(
154
+ model_path,
155
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
156
+ device_map="auto" if torch.cuda.is_available() else None,
157
+ trust_remote_code=True
158
+ )
159
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
160
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
161
+
162
+ context_len = getattr(model.config, "max_position_embeddings", 2048)
163
+ logger.info(f"Model {model_path} loaded successfully")
164
+ return True
165
+
166
+ except Exception as e:
167
+ logger.error(f"Error loading model {model_path}: {str(e)}")
168
+ return False
 
 
 
 
 
 
 
 
169
 
170
+ def get_model_list():
171
+ models = [
172
+ 'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
173
+ 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
174
+ 'AIML-TUDA/QwenGuard-v1.2-7B',
175
+ 'AIML-TUDA/QwenGuard-v1.2-3B'
176
+ ]
177
+ return models
178
 
179
  def get_conv_log_filename():
180
  t = datetime.datetime.now()
181
  name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
182
+ os.makedirs(os.path.dirname(name), exist_ok=True)
183
  return name
184
 
185
+ # Inference function
186
+ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
187
+ global model, tokenizer, processor
188
+
189
+ if model is None or processor is None:
190
+ return "Model not loaded. Please select a model first."
191
+
192
+ try:
193
+ # Check if it's a Qwen model
194
+ if isinstance(model, Qwen2_5_VLForConditionalGeneration):
195
+ # Format for Qwen models
196
+ messages = [
197
+ {
198
+ "role": "user",
199
+ "content": [
200
+ {"type": "image", "image": image},
201
+ {"type": "text", "text": prompt}
202
+ ]
203
+ }
204
+ ]
205
+
206
+ # Process input
207
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
208
+ inputs = processor(
209
+ text=[text],
210
+ images=[image],
211
+ padding=True,
212
+ return_tensors="pt"
213
+ )
214
+
215
+ # Move to GPU if available
216
+ if torch.cuda.is_available():
217
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
218
+
219
+ # Generate
220
+ with torch.no_grad():
221
+ generated_ids = model.generate(
222
+ **inputs,
223
+ do_sample=temperature > 0,
224
+ temperature=temperature,
225
+ top_p=top_p,
226
+ max_new_tokens=max_tokens,
227
+ )
228
+
229
+ # Decode
230
+ generated_ids_trimmed = [
231
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
232
+ ]
233
+ response = processor.batch_decode(
234
+ generated_ids_trimmed,
235
+ skip_special_tokens=True,
236
+ clean_up_tokenization_spaces=False
237
+ )[0]
238
+
239
+ # Otherwise assume it's a LlavaGuard model
240
+ else:
241
+ # Process input for LlavaGuard models
242
+ inputs = processor(
243
+ prompt,
244
+ images=image,
245
+ return_tensors="pt"
246
+ )
247
+
248
+ # Move to GPU if available
249
+ if torch.cuda.is_available():
250
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
251
+
252
+ # Generate
253
+ with torch.no_grad():
254
+ generated_ids = model.generate(
255
+ **inputs,
256
+ do_sample=temperature > 0,
257
+ temperature=temperature,
258
+ top_p=top_p,
259
+ max_new_tokens=max_tokens,
260
+ )
261
+
262
+ # Decode
263
+ response = tokenizer.batch_decode(
264
+ generated_ids[:, inputs.input_ids.shape[1]:],
265
+ skip_special_tokens=True
266
+ )[0]
267
+
268
+ return response.strip()
269
+
270
+ except Exception as e:
271
+ logger.error(f"Error during inference: {str(e)}")
272
+ return f"Error during inference: {str(e)}"
273
 
+# Gradio UI functions
 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
     url_params = Object.fromEntries(params);
     console.log(url_params);
     return url_params;
+    }
 """
 
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+    models = get_model_list()
+
     dropdown_update = gr.Dropdown(visible=True)
     if "model" in url_params:
         model = url_params["model"]
         if model in models:
             dropdown_update = gr.Dropdown(value=model, visible=True)
+            load_model(model)
 
     state = default_conversation.copy()
     return state, dropdown_update
 
 def load_demo_refresh_model_list(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
     models = get_model_list()
 
     )
     return state, dropdown_update
 
 def vote_last_response(state, vote_type, model_selector, request: gr.Request):
     with open(get_conv_log_filename(), "a") as fout:
         data = {
 
         }
         fout.write(json.dumps(data) + "\n")
 
 def upvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"upvote. ip: {request.client.host}")
     vote_last_response(state, "upvote", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
 def downvote_last_response(state, model_selector, request: gr.Request):
     logger.info(f"downvote. ip: {request.client.host}")
     vote_last_response(state, "downvote", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
 def flag_last_response(state, model_selector, request: gr.Request):
     logger.info(f"flag. ip: {request.client.host}")
     vote_last_response(state, "flag", model_selector, request)
     return ("",) + (disable_btn,) * 3
 
 def regenerate(state, image_process_mode, request: gr.Request):
     logger.info(f"regenerate. ip: {request.client.host}")
     state.messages[-1][-1] = None
 
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
 def clear_history(request: gr.Request):
     logger.info(f"clear_history. ip: {request.client.host}")
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
 def add_text(state, text, image, image_process_mode, request: gr.Request):
     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0 or image is None:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
 
     text = wrap_taxonomy(text)
     if image is not None:
         if '<image>' not in text:
             text = text + '\n<image>'
         text = (text, image, image_process_mode)
     state = default_conversation.copy()
 
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
 
 def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
+
     if state.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
+    # Get the prompt and images
     prompt = state.get_prompt()
     all_images = state.get_images(return_pil=True)
+
+    if not all_images:
+        state.messages[-1][-1] = "Error: No image provided"
+        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+        return
+
+    # Save image for logging
     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+    for image, hash_val in zip(all_images, all_image_hash):
         t = datetime.datetime.now()
+        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash_val}.jpg")
         if not os.path.isfile(filename):
             os.makedirs(os.path.dirname(filename), exist_ok=True)
             image.save(filename)
+
+    # Load model if needed
+    if model is None or model_selector != getattr(model, "_name_or_path", ""):
+        load_model(model_selector)
+
+    # Run inference
+    output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
     state.messages[-1][-1] = output
 
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
     finish_tstamp = time.time()
+    logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")
 
     with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
+            "model": model_selector,
             "start": round(start_tstamp, 4),
             "finish": round(finish_tstamp, 4),
             "state": state.dict(),
 
         }
         fout.write(json.dumps(data) + "\n")
 
+# UI Components
+title_markdown = """
 # LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
 [[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
 [[Code](https://github.com/ml-research/LlavaGuard)]
 [[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
 [[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
 [[LavaGuard](https://arxiv.org/abs/2406.05113)]
+"""
 
+tos_markdown = """
 ### Terms of use
 By using this service, users are required to agree to the following terms:
 The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
 Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
 For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+"""
 
+learn_more_markdown = """
 ### License
 The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+"""
 
 block_css = """
 #buttons button {
     min-width: min(120px,100%);
 }
 """
 
 def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
+    models = get_model_list()
+
     with gr.Accordion("Safety Risk Taxonomy", open=False) as accordion:
         textbox = gr.Textbox(
             label="Safety Risk Taxonomy",
 
             container=True,
             value=default_taxonomy,
             lines=50)
+
     with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
 
                 if cur_dir is None:
                     cur_dir = os.path.dirname(os.path.abspath(__file__))
+
                 gr.Examples(examples=[
+                    [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if os.path.exists(f"{cur_dir}/examples/image{i}.png")
                 ], inputs=imagebox)
 
                 with gr.Accordion("Parameters", open=False) as parameter_row:
                     temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
+                                            label="Temperature")
+                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P")
                     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
+                                                  label="Max output tokens")
 
             with gr.Column(scale=8):
                 chatbot = gr.Chatbot(
 
                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
 
 
         # Register listeners
         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+
         upvote_btn.click(
             upvote_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
+
         downvote_btn.click(
             downvote_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
+
         flag_btn.click(
             flag_last_response,
             [state, model_selector],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
+
+        model_selector.change(
+            load_model,
+            [model_selector],
+            None
+        )
 
         regenerate_btn.click(
             regenerate,
 
             concurrency_limit=concurrency_count
         )
 
+        demo.load(
+            load_demo_refresh_model_list,
+            None,
+            [state, model_selector],
+            queue=False
+        )
 
     return demo
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int)
     parser.add_argument("--concurrency-count", type=int, default=5)
     parser.add_argument("--share", action="store_true")
     parser.add_argument("--moderate", action="store_true")
     parser.add_argument("--embed", action="store_true")
     args = parser.parse_args()
 
+    # Create log directory if it doesn't exist
+    os.makedirs(LOGDIR, exist_ok=True)
+
+    # GPU Check
+    if torch.cuda.is_available():
+        logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
+    else:
+        logger.warning("CUDA not available! Models will run on CPU which may be very slow.")
+
+    # Hugging Face token handling
     api_key = os.getenv("token")
     if api_key:
+        from huggingface_hub import login
+        login(token=api_key)
+        logger.info("Logged in to Hugging Face Hub")
+
+    # Load initial model
+    models = get_model_list()
+    model_path = os.getenv("model", models[0])
+    logger.info(f"Initial model selected: {model_path}")
+    load_model(model_path)
+
+    # Launch Gradio app
     try:
+        demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
         demo.queue(
             status_update_rate=10,
             api_open=False
 
             server_port=args.port,
             share=args.share
         )
     except Exception as e:
+        logger.error(f"Error launching demo: {e}")
+        sys.exit(1)
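
For context, the snippet below is a minimal standalone sketch of the new QwenGuard inference path this commit introduces (mirroring load_model() and the Qwen branch of run_inference()). It assumes a transformers release with Qwen2.5-VL support, a CUDA device, and a local test image; the model id and the policy_v1 prompt come from the code above, while the image path is only illustrative.

# Standalone sketch, not part of the committed app.py.
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from taxonomy import policy_v1  # safety policy prompt shipped with this Space

model_id = "AIML-TUDA/QwenGuard-v1.2-7B"  # one of the entries returned by get_model_list()
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("examples/image1.png")  # illustrative path, matching the demo examples
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": policy_v1},  # taxonomy-wrapped prompt, as wrap_taxonomy() produces
    ],
}]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    generated = model.generate(**inputs, do_sample=True, temperature=0.2, top_p=0.95, max_new_tokens=512)

# Decode only the newly generated tokens (the safety assessment)
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated)]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0].strip())

The app's run_inference() follows the same steps inside the Gradio callback; the LlavaGuard checkpoints instead go through AutoModelForCausalLM, AutoTokenizer, and AutoProcessor as shown in load_model().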