Update app.py

app.py CHANGED
@@ -16,6 +16,7 @@ from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     LlavaOnevisionForConditionalGeneration
 )
+from qwen_vl_utils import process_vision_info
 
 from taxonomy import policy_v1
 
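Note on the new import: qwen_vl_utils.process_vision_info walks a Qwen-style chat message list and returns the image and video inputs separately from the templated text, which is how the Qwen branch below builds its processor inputs. A minimal sketch of that pattern (the image path is illustrative; the model id is one of the Space's own checkpoints):

    # Sketch only: how process_vision_info pairs with the chat template and processor.
    from transformers import AutoProcessor
    from qwen_vl_utils import process_vision_info

    processor = AutoProcessor.from_pretrained("AIML-TUDA/QwenGuard-v1.2-3B")
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": "example.jpg"},  # illustrative path
            {"type": "text", "text": "Assess this image against the safety policy."},
        ],
    }]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs,
                       padding=True, return_tensors="pt")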
@@ -36,94 +37,77 @@ os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
 
 default_taxonomy = policy_v1
 
-
+
+class SimpleConversation:
     def __init__(self):
-        self.
-        self.
-        self.
+        self.current_prompt = ""
+        self.current_image = None
+        self.current_response = None
         self.skip_next = False
+        self.messages = []  # Add messages list to store conversation history
+
+    def set_prompt(self, prompt, image=None):
+        self.current_prompt = prompt
+        self.current_image = image
+        self.current_response = None
+        # Update messages when setting a new prompt
+        self.messages = [[prompt, None]]
 
-    def
-        self.
+    def set_response(self, response):
+        self.current_response = response
+        # Update the last message's response when setting a response
+        if self.messages and len(self.messages) > 0:
+            self.messages[-1][-1] = response
+
+    def get_prompt(self):
+        if isinstance(self.current_prompt, tuple):
+            return self.current_prompt[0]
+        return self.current_prompt
+
+    def get_image(self, return_pil=False):
+        if self.current_image:
+            return [self.current_image]
+        if isinstance(self.current_prompt, tuple) and len(self.current_prompt) > 1:
+            if isinstance(self.current_prompt[1], Image.Image):
+                return [self.current_prompt[1]]
+        return None
 
     def to_gradio_chatbot(self):
+        if not self.messages:
+            return []
+
         ret = []
-        for
-
-
-
-
-
-
-
-
-            if ret[-1][1] is None:
-                ret[-1][1] = message
-            else:
-                ret.append([None, message])
-            else:
-                raise ValueError(f"Invalid role: {role}")
+        for msg in self.messages:
+            prompt = msg[0]
+            if isinstance(prompt, tuple) and len(prompt) > 0:
+                prompt = prompt[0]
+
+            if prompt and isinstance(prompt, str) and "<image>" in prompt:
+                prompt = prompt.replace("<image>", "")
+
+            ret.append([prompt, msg[1]])
         return ret
-
-    def render_user_message(self, message):
-        if "<image>" in message:
-            return message.replace("<image>", "")
-        return message
 
     def dict(self):
-        #
-
-        for role, message in self.messages:
-            if isinstance(message, tuple) and len(message) > 1:
-                # If the message contains an image (tuple format)
-                if isinstance(message[1], Image.Image):
-                    # Just keep the text part and ignore the image
-                    serialized_message = (message[0], "[IMAGE_IGNORED]")
-                else:
-                    # For non-image tuples, keep as is
-                    serialized_message = message
-            else:
-                # For non-tuple messages, keep as is
-                serialized_message = message
-            serialized_messages.append([role, serialized_message])
-
+        # Simplified serialization for logging
+        image_info = "[WITH_IMAGE]" if self.current_image is not None else "[NO_IMAGE]"
         return {
-            "
-            "
-            "
-            "
+            "prompt": self.get_prompt(),
+            "image": image_info,
+            "response": self.current_response,
+            "messages": [[m[0], "[RESPONSE]" if m[1] else None] for m in self.messages]
         }
-
-    def get_prompt(self):
-        prompt = ""
-        for role, message in self.messages:
-            if message is None:
-                continue
-            if isinstance(message, tuple):
-                message = message[0]
-            if role == self.roles[0]:
-                prompt += f"USER: {message}\n"
-            else:
-                prompt += f"ASSISTANT: {message}\n"
-        return prompt + "ASSISTANT: "
-
-    def get_images(self, return_pil=False):
-        images = []
-        for role, message in self.messages:
-            if isinstance(message, tuple) and len(message) > 1:
-                if isinstance(message[1], Image.Image):
-                    images.append(message[1] if return_pil else message[1])
-        return images
-
+
     def copy(self):
-        new_conv =
-        new_conv.
-        new_conv.
-        new_conv.
+        new_conv = SimpleConversation()
+        new_conv.current_prompt = self.current_prompt
+        new_conv.current_image = self.current_image
+        new_conv.current_response = self.current_response
         new_conv.skip_next = self.skip_next
+        new_conv.messages = self.messages.copy() if self.messages else []
         return new_conv
 
-default_conversation =
+default_conversation = SimpleConversation()
 
 # Model and processor storage
 tokenizer = None
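For orientation, the new SimpleConversation replaces the removed multi-turn conversation object with a single prompt/image/response holder plus a small messages list that feeds the Gradio chatbot and the log serializer. A minimal usage sketch against the class as added above (the prompt and response strings are illustrative):

    # Sketch only: the lifecycle the UI callbacks drive.
    from PIL import Image

    conv = SimpleConversation()
    img = Image.new("RGB", (64, 64))                      # placeholder image
    conv.set_prompt(("Is this image safe? <image>", img, "Default"), image=img)
    conv.set_response("rating: Safe")                     # illustrative model output
    print(conv.to_gradio_chatbot())   # [["Is this image safe? ", "rating: Safe"]]
    print(conv.dict()["image"])       # "[WITH_IMAGE]"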
@@ -131,11 +115,6 @@ model = None
 processor = None
 context_len = 8048
 
-# Helper functions
-def clear_conv(conv):
-    conv.messages = []
-    return conv
-
 def wrap_taxonomy(text):
     """Wraps user input with taxonomy if not already present"""
     if policy_v1 not in text:
@@ -158,7 +137,8 @@ def load_model(model_path):
     if "qwenguard" in model_path.lower():
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             model_path,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            torch_dtype="auto",
             device_map="auto" if torch.cuda.is_available() else None
         )
         processor = AutoProcessor.from_pretrained(model_path)
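Switching from the explicit float16/float32 choice to torch_dtype="auto" lets from_pretrained use the dtype recorded in the checkpoint's config instead of forcing a cast; combined with device_map="auto" the weights land on the GPU in their native precision. A standalone sketch of the equivalent load (model id taken from the Space's model list):

    # Sketch only: native-dtype load with automatic device placement.
    import torch
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    model_id = "AIML-TUDA/QwenGuard-v1.2-3B"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype="auto",                                        # dtype from config.json
        device_map="auto" if torch.cuda.is_available() else None,
    )
    processor = AutoProcessor.from_pretrained(model_id)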
@@ -168,7 +148,8 @@ def load_model(model_path):
     else:
         model = LlavaOnevisionForConditionalGeneration.from_pretrained(
             model_path,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            torch_dtype="auto",
             device_map="auto" if torch.cuda.is_available() else None,
             trust_remote_code=True
         )
@@ -185,10 +166,10 @@ def load_model(model_path):
 
 def get_model_list():
     models = [
-        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
-        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
         'AIML-TUDA/QwenGuard-v1.2-3B',
         'AIML-TUDA/QwenGuard-v1.2-7B',
+        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
+        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
     ]
     return models
 
@@ -204,7 +185,6 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
 
     if model is None or processor is None:
         return "Model not loaded. Please select a model first."
-
     try:
         # Check if it's a Qwen model
         if isinstance(model, Qwen2_5_VLForConditionalGeneration):
@@ -218,39 +198,18 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
                     ]
                 }
             ]
-
             # Process input
-
+            text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            image_inputs, video_inputs = process_vision_info(messages)
             inputs = processor(
-                text=[
-                images=
-
-
+                text=[text_prompt],
+                images=image_inputs,
+                videos=video_inputs,
+                padding=True,
+                return_tensors="pt",
             )
-
-
-            if torch.cuda.is_available():
-                inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-            # Generate
-            with torch.no_grad():
-                generated_ids = model.generate(
-                    **inputs,
-                    do_sample=temperature > 0,
-                    temperature=temperature,
-                    top_p=top_p,
-                    max_new_tokens=max_tokens,
-                )
-
-            # Decode
-            generated_ids_trimmed = [
-                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-            ]
-            response = processor.batch_decode(
-                generated_ids_trimmed,
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=False
-            )[0]
+            inputs = inputs.to("cuda")
+
 
         # Otherwise assume it's a LlavaGuard model
         else:
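The rewritten Qwen branch now stops after moving the inputs to the GPU; judging by the removed per-branch generate/decode code, the generate-and-decode block added in the next hunk appears to serve both branches. A sketch of that downstream step, assuming inputs, model, processor, and the sampling arguments of run_inference are in scope:

    # Sketch only: generate, then strip the prompt tokens before decoding.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_tokens,
        )
    new_tokens = generated_ids[0, inputs["input_ids"].shape[1]:]
    response = processor.decode(new_tokens, skip_special_tokens=True)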
@@ -263,39 +222,37 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
                 ],
             },
         ]
-
             text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+            inputs = processor(text=text_prompt, images=image, return_tensors="pt")
 
+            inputs = {k: v.to('cuda') for k, v in inputs.items()}
 
-
-
 
-
-            # Move to GPU if available
-            if torch.cuda.is_available():
-                inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-            # Generate
-            with torch.no_grad():
-                generated_ids = model.generate(
-                    **inputs,
-                    do_sample=temperature > 0,
-                    temperature=temperature,
-                    top_p=top_p,
-                    max_new_tokens=max_tokens,
-                )
-
-            # Decode
-            response = tokenizer.batch_decode(
-                generated_ids[:, inputs.input_ids.shape[1]:],
-                skip_special_tokens=True
-            )[0]
-
+        with torch.no_grad():
+            generated_ids = model.generate(
+                **inputs,
+                do_sample=temperature > 0,
+                temperature=temperature,
+                top_p=top_p,
+                max_new_tokens=max_tokens,
+            )
+
+        # Decode
+        generated_ids_trimmed = generated_ids[0, inputs["input_ids"].shape[1]:]
+        response = processor.decode(
+            generated_ids_trimmed,
+            skip_special_tokens=True,
+            # clean_up_tokenization_spaces=False
+        )
+        print(response)
 
         return response.strip()
+
     except Exception as e:
-
-
+        import traceback
+        error_msg = f"Error during inference: {str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        logger.error(error_msg)
+        return f"Error processing image. Please try again."
 
 # Gradio UI functions
 get_window_url_params = """
@@ -359,10 +316,17 @@ def flag_last_response(state, model_selector, request: gr.Request):
 
 def regenerate(state, image_process_mode, request: gr.Request):
     logger.info(f"regenerate. ip: {request.client.host}")
-    state.messages
-
-
-
+    if state.messages and len(state.messages) > 0:
+        state.messages[-1][-1] = None
+        if len(state.messages) > 1:
+            prev_human_msg = state.messages[-2]
+            if isinstance(prev_human_msg[0], tuple) and len(prev_human_msg[0]) >= 2:
+                # Handle image process mode for previous message if it's a tuple with image
+                new_msg = list(prev_human_msg)
+                if len(prev_human_msg[0]) >= 3:
+                    new_msg[0] = (prev_human_msg[0][0], prev_human_msg[0][1], image_process_mode)
+                state.messages[-2] = new_msg
+
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
@@ -378,15 +342,19 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
 
     text = wrap_taxonomy(text)
+
+    # Reset conversation for new image-based query
     if image is not None:
-        if '<image>' not in text:
-            text = text + '\n<image>'
-        text = (text, image, image_process_mode)
         state = default_conversation.copy()
-
-
-
+
+    # Set new prompt with image
+    prompt = text
+    if image is not None:
+        prompt = (text, image, image_process_mode)
+
+    state.set_prompt(prompt=prompt, image=image)
     state.skip_next = False
+
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
 
 def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
@@ -399,47 +367,50 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
 
     # Get the prompt and images
     prompt = state.get_prompt()
-    all_images = state.
+    all_images = state.get_image(return_pil=True)
 
     if not all_images:
-        state.messages
+        if not state.messages:
+            state.messages = [["Error: No image provided", None]]
+        else:
+            state.messages[-1][-1] = "Error: No image provided"
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         return
 
-    # Save image for logging
-    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
-    for image, hash_val in zip(all_images, all_image_hash):
-        t = datetime.datetime.now()
-        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash_val}.jpg")
-        if not os.path.isfile(filename):
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            image.save(filename)
-
     # Load model if needed
     if model is None or model_selector != getattr(model, "_name_or_path", ""):
         load_model(model_selector)
 
     # Run inference
     output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
-
+
+    # Update the response in the conversation state
+    if not state.messages:
+        state.messages = [[prompt, output]]
+    else:
+        state.messages[-1][-1] = output
+    state.current_response = output
 
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
     finish_tstamp = time.time()
    logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")
 
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        with open(get_conv_log_filename(), "a") as fout:
+            data = {
+                "tstamp": round(finish_tstamp, 4),
+                "type": "chat",
+                "model": model_selector,
+                "start": round(start_tstamp, 4),
+                "finish": round(finish_tstamp, 4),
+                "state": state.dict(),
+                "images": ['image'],
+                "ip": request.client.host,
+            }
+            fout.write(json.dumps(data) + "\n")
+    except Exception as e:
+        logger.error(f"Error writing log: {str(e)}")
 
 # UI Components
 title_markdown = """
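The new logging block appends one JSON object per completed request to the file returned by get_conv_log_filename(). A minimal sketch of reading that JSON-lines log back (the path is illustrative; the real name comes from get_conv_log_filename() at runtime):

    # Sketch only: each log line is an independent JSON record.
    import json

    with open("logs/conv.json") as f:            # illustrative path
        for line in f:
            record = json.loads(line)
            print(record["model"],
                  round(record["finish"] - record["start"], 2), "s",
                  record["state"]["response"])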
@@ -666,8 +637,9 @@ if __name__ == "__main__":
         ).launch(
             server_name=args.host,
             server_port=args.port,
-            share=
+            share=args.share
         )
     except Exception as e:
         logger.error(f"Error launching demo: {e}")
         sys.exit(1)
+