LukasHug committed · verified
Commit 5881559 · 1 Parent(s): 03ed7fb

Update app.py

Files changed (1): app.py (+147, -175)
app.py CHANGED
@@ -16,6 +16,7 @@ from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     LlavaOnevisionForConditionalGeneration
 )
+from qwen_vl_utils import process_vision_info

 from taxonomy import policy_v1

@@ -36,94 +37,77 @@ os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)

 default_taxonomy = policy_v1

-class Conversation:
+
+class SimpleConversation:
     def __init__(self):
-        self.messages = []
-        self.roles = ["user", "assistant"]
-        self.offset = 0
+        self.current_prompt = ""
+        self.current_image = None
+        self.current_response = None
         self.skip_next = False
+        self.messages = []  # Add messages list to store conversation history
+
+    def set_prompt(self, prompt, image=None):
+        self.current_prompt = prompt
+        self.current_image = image
+        self.current_response = None
+        # Update messages when setting a new prompt
+        self.messages = [[prompt, None]]

-    def append_message(self, role, message):
-        self.messages.append([role, message])
+    def set_response(self, response):
+        self.current_response = response
+        # Update the last message's response when setting a response
+        if self.messages and len(self.messages) > 0:
+            self.messages[-1][-1] = response
+
+    def get_prompt(self):
+        if isinstance(self.current_prompt, tuple):
+            return self.current_prompt[0]
+        return self.current_prompt
+
+    def get_image(self, return_pil=False):
+        if self.current_image:
+            return [self.current_image]
+        if isinstance(self.current_prompt, tuple) and len(self.current_prompt) > 1:
+            if isinstance(self.current_prompt[1], Image.Image):
+                return [self.current_prompt[1]]
+        return None

     def to_gradio_chatbot(self):
+        if not self.messages:
+            return []
+
         ret = []
-        for role, message in self.messages:
-            if message is None:
-                continue
-            if role == self.roles[0]:
-                if isinstance(message, tuple):
-                    ret.append([self.render_user_message(message[0]), None])
-                else:
-                    ret.append([self.render_user_message(message), None])
-            elif role == self.roles[1]:
-                if ret[-1][1] is None:
-                    ret[-1][1] = message
-                else:
-                    ret.append([None, message])
-            else:
-                raise ValueError(f"Invalid role: {role}")
+        for msg in self.messages:
+            prompt = msg[0]
+            if isinstance(prompt, tuple) and len(prompt) > 0:
+                prompt = prompt[0]
+
+            if prompt and isinstance(prompt, str) and "<image>" in prompt:
+                prompt = prompt.replace("<image>", "")
+
+            ret.append([prompt, msg[1]])
         return ret
-
-    def render_user_message(self, message):
-        if "<image>" in message:
-            return message.replace("<image>", "")
-        return message

     def dict(self):
-        # Create a serializable version of messages
-        serialized_messages = []
-        for role, message in self.messages:
-            if isinstance(message, tuple) and len(message) > 1:
-                # If the message contains an image (tuple format)
-                if isinstance(message[1], Image.Image):
-                    # Just keep the text part and ignore the image
-                    serialized_message = (message[0], "[IMAGE_IGNORED]")
-                else:
-                    # For non-image tuples, keep as is
-                    serialized_message = message
-            else:
-                # For non-tuple messages, keep as is
-                serialized_message = message
-            serialized_messages.append([role, serialized_message])
-
+        # Simplified serialization for logging
+        image_info = "[WITH_IMAGE]" if self.current_image is not None else "[NO_IMAGE]"
         return {
-            "messages": serialized_messages,
-            "roles": self.roles,
-            "offset": self.offset,
-            "skip_next": self.skip_next,
+            "prompt": self.get_prompt(),
+            "image": image_info,
+            "response": self.current_response,
+            "messages": [[m[0], "[RESPONSE]" if m[1] else None] for m in self.messages]
         }
-
-    def get_prompt(self):
-        prompt = ""
-        for role, message in self.messages:
-            if message is None:
-                continue
-            if isinstance(message, tuple):
-                message = message[0]
-            if role == self.roles[0]:
-                prompt += f"USER: {message}\n"
-            else:
-                prompt += f"ASSISTANT: {message}\n"
-        return prompt + "ASSISTANT: "
-
-    def get_images(self, return_pil=False):
-        images = []
-        for role, message in self.messages:
-            if isinstance(message, tuple) and len(message) > 1:
-                if isinstance(message[1], Image.Image):
-                    images.append(message[1] if return_pil else message[1])
-        return images
-
+
     def copy(self):
-        new_conv = Conversation()
-        new_conv.messages = self.messages.copy()
-        new_conv.roles = self.roles.copy()
-        new_conv.offset = self.offset
+        new_conv = SimpleConversation()
+        new_conv.current_prompt = self.current_prompt
+        new_conv.current_image = self.current_image
+        new_conv.current_response = self.current_response
         new_conv.skip_next = self.skip_next
+        new_conv.messages = self.messages.copy() if self.messages else []
         return new_conv

-default_conversation = Conversation()
+default_conversation = SimpleConversation()

 # Model and processor storage
 tokenizer = None
@@ -131,11 +115,6 @@ model = None
 processor = None
 context_len = 8048

-# Helper functions
-def clear_conv(conv):
-    conv.messages = []
-    return conv
-
 def wrap_taxonomy(text):
     """Wraps user input with taxonomy if not already present"""
     if policy_v1 not in text:
@@ -158,7 +137,8 @@ def load_model(model_path):
     if "qwenguard" in model_path.lower():
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             model_path,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            torch_dtype="auto",
             device_map="auto" if torch.cuda.is_available() else None
         )
         processor = AutoProcessor.from_pretrained(model_path)
@@ -168,7 +148,8 @@ def load_model(model_path):
     else:
         model = LlavaOnevisionForConditionalGeneration.from_pretrained(
             model_path,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            torch_dtype="auto",
             device_map="auto" if torch.cuda.is_available() else None,
             trust_remote_code=True
         )
@@ -185,10 +166,10 @@ def load_model(model_path):

 def get_model_list():
     models = [
-        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
-        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
         'AIML-TUDA/QwenGuard-v1.2-3B',
         'AIML-TUDA/QwenGuard-v1.2-7B',
+        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
+        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
     ]
     return models

@@ -204,7 +185,6 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):

     if model is None or processor is None:
         return "Model not loaded. Please select a model first."
-
     try:
         # Check if it's a Qwen model
         if isinstance(model, Qwen2_5_VLForConditionalGeneration):
@@ -218,39 +198,18 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
                 ]
             }
         ]
-
             # Process input
-            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            image_inputs, video_inputs = process_vision_info(messages)
             inputs = processor(
-                text=[text],
-                images=[image],
-                padding=True,
-                return_tensors="pt"
+                text=[text_prompt],
+                images=image_inputs,
+                videos=video_inputs,
+                padding=True,
+                return_tensors="pt",
             )
-
-            # Move to GPU if available
-            if torch.cuda.is_available():
-                inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-            # Generate
-            with torch.no_grad():
-                generated_ids = model.generate(
-                    **inputs,
-                    do_sample=temperature > 0,
-                    temperature=temperature,
-                    top_p=top_p,
-                    max_new_tokens=max_tokens,
-                )
-
-            # Decode
-            generated_ids_trimmed = [
-                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-            ]
-            response = processor.batch_decode(
-                generated_ids_trimmed,
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=False
-            )[0]
+            inputs = inputs.to("cuda")
+

         # Otherwise assume it's a LlavaGuard model
         else:
@@ -263,39 +222,37 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
                 ],
             },
         ]
-
             text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+            inputs = processor(text=text_prompt, images=image, return_tensors="pt")

-            # Process input for LlavaGuard models
-            inputs = processor(text=text_prompt, images=image, return_tensors="pt")
-
-            # Move to GPU if available
-            if torch.cuda.is_available():
-                inputs = {k: v.to("cuda") for k, v in inputs.items()}
+            inputs = {k: v.to('cuda') for k, v in inputs.items()}

-            # Generate
-            with torch.no_grad():
-                generated_ids = model.generate(
-                    **inputs,
-                    do_sample=temperature > 0,
-                    temperature=temperature,
-                    top_p=top_p,
-                    max_new_tokens=max_tokens,
-                )
-
-            # Decode
-            response = tokenizer.batch_decode(
-                generated_ids[:, inputs.input_ids.shape[1]:],
-                skip_special_tokens=True
-            )[0]
-
+        with torch.no_grad():
+            generated_ids = model.generate(
+                **inputs,
+                do_sample=temperature > 0,
+                temperature=temperature,
+                top_p=top_p,
+                max_new_tokens=max_tokens,
+            )
+
+        # Decode
+        generated_ids_trimmed = generated_ids[0, inputs["input_ids"].shape[1]:]
+        response = processor.decode(
+            generated_ids_trimmed,
+            skip_special_tokens=True,
+            # clean_up_tokenization_spaces=False
+        )
+        print(response)

         return response.strip()
-
+
     except Exception as e:
-        logger.error(f"Error during inference: {e}")
-        return f"Error during inference: {e}"
+        import traceback
+        error_msg = f"Error during inference: {str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        logger.error(error_msg)
+        return f"Error processing image. Please try again."

 # Gradio UI functions
 get_window_url_params = """
@@ -359,10 +316,17 @@ def flag_last_response(state, model_selector, request: gr.Request):

 def regenerate(state, image_process_mode, request: gr.Request):
     logger.info(f"regenerate. ip: {request.client.host}")
-    state.messages[-1][-1] = None
-    prev_human_msg = state.messages[-2]
-    if type(prev_human_msg[1]) in (tuple, list):
-        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+    if state.messages and len(state.messages) > 0:
+        state.messages[-1][-1] = None
+        if len(state.messages) > 1:
+            prev_human_msg = state.messages[-2]
+            if isinstance(prev_human_msg[0], tuple) and len(prev_human_msg[0]) >= 2:
+                # Handle image process mode for previous message if it's a tuple with image
+                new_msg = list(prev_human_msg)
+                if len(prev_human_msg[0]) >= 3:
+                    new_msg[0] = (prev_human_msg[0][0], prev_human_msg[0][1], image_process_mode)
+                state.messages[-2] = new_msg
+
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

@@ -378,15 +342,19 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

     text = wrap_taxonomy(text)
+
+    # Reset conversation for new image-based query
     if image is not None:
-        if '<image>' not in text:
-            text = text + '\n<image>'
-        text = (text, image, image_process_mode)
         state = default_conversation.copy()
-    state = clear_conv(state)
-    state.append_message(state.roles[0], text)
-    state.append_message(state.roles[1], None)
+
+    # Set new prompt with image
+    prompt = text
+    if image is not None:
+        prompt = (text, image, image_process_mode)
+
+    state.set_prompt(prompt=prompt, image=image)
     state.skip_next = False
+
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5

 def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
@@ -399,47 +367,50 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):

     # Get the prompt and images
     prompt = state.get_prompt()
-    all_images = state.get_images(return_pil=True)
+    all_images = state.get_image(return_pil=True)

     if not all_images:
-        state.messages[-1][-1] = "Error: No image provided"
+        if not state.messages:
+            state.messages = [["Error: No image provided", None]]
+        else:
+            state.messages[-1][-1] = "Error: No image provided"
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         return

-    # Save image for logging
-    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
-    for image, hash_val in zip(all_images, all_image_hash):
-        t = datetime.datetime.now()
-        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash_val}.jpg")
-        if not os.path.isfile(filename):
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            image.save(filename)
-
     # Load model if needed
     if model is None or model_selector != getattr(model, "_name_or_path", ""):
         load_model(model_selector)

     # Run inference
     output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
-    state.messages[-1][-1] = output
+
+    # Update the response in the conversation state
+    if not state.messages:
+        state.messages = [[prompt, output]]
+    else:
+        state.messages[-1][-1] = output
+    state.current_response = output

     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

     finish_tstamp = time.time()
     logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")

-    with open(get_conv_log_filename(), "a") as fout:
-        data = {
-            "tstamp": round(finish_tstamp, 4),
-            "type": "chat",
-            "model": model_selector,
-            "start": round(start_tstamp, 4),
-            "finish": round(finish_tstamp, 4),
-            "state": state.dict(),
-            "images": all_image_hash,
-            "ip": request.client.host,
-        }
-        fout.write(json.dumps(data) + "\n")
+    try:
+        with open(get_conv_log_filename(), "a") as fout:
+            data = {
+                "tstamp": round(finish_tstamp, 4),
+                "type": "chat",
+                "model": model_selector,
+                "start": round(start_tstamp, 4),
+                "finish": round(finish_tstamp, 4),
+                "state": state.dict(),
+                "images": ['image'],
+                "ip": request.client.host,
+            }
+            fout.write(json.dumps(data) + "\n")
+    except Exception as e:
+        logger.error(f"Error writing log: {str(e)}")

 # UI Components
 title_markdown = """
@@ -666,8 +637,9 @@ if __name__ == "__main__":
         ).launch(
             server_name=args.host,
             server_port=args.port,
-            share=True
+            share=args.share
         )
     except Exception as e:
         logger.error(f"Error launching demo: {e}")
         sys.exit(1)
+
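
A minimal standalone sketch of the Qwen2.5-VL path that run_inference switches to in this commit (chat-template text plus process_vision_info inputs, then generate and decode). It assumes transformers with Qwen2.5-VL support and qwen-vl-utils are installed and a CUDA device is available; the image path and the question text are placeholders for the taxonomy-wrapped prompt the app builds with wrap_taxonomy, and the model id is one entry from get_model_list().

# Sketch only: mirrors the preprocessing/generation steps added in this commit.
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

model_id = "AIML-TUDA/QwenGuard-v1.2-3B"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("demo.jpg")  # placeholder image path
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Assess the image against the safety policy."},  # placeholder prompt
        ],
    }
]

# Same input pipeline as the updated run_inference
text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text_prompt],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    generated_ids = model.generate(
        **inputs, do_sample=True, temperature=0.2, top_p=0.95, max_new_tokens=512
    )

# Drop the prompt tokens and decode only the generated continuation
response = processor.decode(
    generated_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(response.strip())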
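
And a short sketch of how the new SimpleConversation state is driven (the same calls add_text and llava_bot make). It assumes app.py is importable as a module and that a PIL image is available; the image path, question text, and the "Default" image-process mode are placeholders.

# Sketch only: round trip through the SimpleConversation helper introduced here.
from PIL import Image
from app import SimpleConversation, wrap_taxonomy

img = Image.open("demo.jpg")  # placeholder image path
prompt = (wrap_taxonomy("Is this image safe?"), img, "Default")  # placeholder text/mode

conv = SimpleConversation()
conv.set_prompt(prompt=prompt, image=img)   # stores prompt/image, resets messages
print(conv.get_prompt())                    # taxonomy-wrapped text only
print(conv.get_image())                     # [img]

conv.set_response("Final assessment: safe.")  # fills messages[-1][-1]
print(conv.to_gradio_chatbot())               # [[prompt text, response]]
print(conv.dict())                            # log-friendly dict; image replaced by "[WITH_IMAGE]"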