BenkHel committed
Commit c2b8ea8 · verified · 1 Parent(s): 2d8021a

Update app.py

Files changed (1):
  1. app.py +66 -175
app.py CHANGED
@@ -1,42 +1,35 @@
 import subprocess
 import sys
 import os
-
 from transformers import TextIteratorStreamer
 import argparse
 import time
 import subprocess
 import spaces
 import cumo.serve.gradio_web_server as gws
-
 from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
-
 import datetime
 import json
-
 import gradio as gr
 import requests
 from PIL import Image
-
 from cumo.conversation import (default_conversation, conv_templates, SeparatorStyle)
 from cumo.constants import LOGDIR
 from cumo.model.language_model.llava_mistral import LlavaMistralForCausalLM
 from cumo.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
 import hashlib
-
 import torch
 import io
 from cumo.constants import WORKER_HEART_BEAT_INTERVAL
-from cumo.utils import (build_logger, server_error_msg,
-                        pretty_print_semaphore)
 from cumo.model.builder import load_pretrained_model
 from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
 from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
-from transformers import TextIteratorStreamer
 from threading import Thread

 headers = {"User-Agent": "CuMo"}
-
 no_change_btn = gr.Button()
 enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)
@@ -54,9 +47,10 @@ tokenizer, model, image_processor, context_len = load_pretrained_model(
 )
 model.config.training = False

-# FIXED PROMPT
 FIXED_PROMPT = "<image>\nWhat material is this item and how to dispose of it?"

 def clear_history():
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
@@ -64,27 +58,12 @@ def clear_history():
 def add_text(state, imagebox, textbox, image_process_mode):
     if state is None:
         state = conv_templates[conv_mode].copy()
-
     if imagebox is not None:
-        try:
-            image = Image.open(imagebox).convert('RGB')
-        except Exception as e:
-            print(f"Failed to load image: {e}")
-            yield (state, state.to_gradio_chatbot(), "⚠️ Could not load example image.", None) + (enable_btn,) * 5
-            return
-
-        textbox = DEFAULT_IMAGE_TOKEN + "\nWhat material is this item and how to dispose of it?"
-        textbox = (textbox, image, image_process_mode)
-
-    else:
-        yield (state, state.to_gradio_chatbot(), "⚠️ Please upload or select an image first.", None) + (enable_btn,) * 5
-        return
-
-    state.append_message(state.roles[0], textbox)
-    state.append_message(state.roles[1], None)
-
-    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-

 def delete_text(state, image_process_mode):
     state.messages[-1][-1] = None
@@ -93,63 +72,42 @@ def delete_text(state, image_process_mode):
         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
     yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

-def regenerate(state, image_process_mode):
-    state.messages[-1][-1] = None
-    prev_human_msg = state.messages[-2]
-    if type(prev_human_msg[1]) in (tuple, list):
-        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-    state.skip_next = False
-    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
 @spaces.GPU
 def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
-    prompt = FIXED_PROMPT  # <-- fixed here!
     images = state.get_images(return_pil=True)
-
     ori_prompt = prompt
     num_image_tokens = 0

-    if images is not None and len(images) > 0:
-        if len(images) > 0:
-            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                raise ValueError("Number of images does not match number of <image> tokens in prompt")
-            image_sizes = [image.size for image in images]
-            images = process_images(images, image_processor, model.config)
-
-            if type(images) is list:
-                images = [image.to(model.device, dtype=torch.float16) for image in images]
-            else:
-                images = images.to(model.device, dtype=torch.float16)
-
-            replace_token = DEFAULT_IMAGE_TOKEN
-            if getattr(model.config, 'mm_use_im_start_end', False):
-                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
-            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-            num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
         else:
-            images = None
-            image_sizes = None
         image_args = {"images": images, "image_sizes": image_sizes}
     else:
-        images = None
         image_args = {}

     max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
-    max_new_tokens = 512
-    do_sample = True if temperature > 0.001 else False
-    stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
-
     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-
-    max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
     if max_new_tokens < 1:
-        yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
         return

     thread = Thread(target=model.generate, kwargs=dict(
         inputs=input_ids,
-        do_sample=do_sample,
         temperature=temperature,
         top_p=top_p,
         max_new_tokens=max_new_tokens,
@@ -160,6 +118,8 @@ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
     ))
     thread.start()
     generated_text = ''
     for new_text in streamer:
         generated_text += new_text
         if generated_text.endswith(stop_str):
@@ -169,51 +129,23 @@ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
     yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
     torch.cuda.empty_cache()

-title_markdown = ("""
-# CuMo: Trained for waste management
-""")
-
-tos_markdown = ("""
-### Please "🗑️ Clear" the output before offering a new picture!
-### Source and Terms of use
-This demo is based on the original CuMo project by SHI-Labs ([GitHub](https://github.com/SHI-Labs/CuMo)).
-If you use this service or build upon this work, please cite the original publication:
-Li, Jiachen and Wang, Xinyao and Zhu, Sijie and Kuo, Chia-wen and Xu, Lu and Chen, Fan and Jain, Jitesh and Shi, Humphrey and Wen, Longyin.
-CuMo: Scaling Multimodal LLM with Co-Upcycled Mixture-of-Experts. arXiv preprint, 2024.
-[[arXiv](https://arxiv.org/abs/2405.05949)]
-
-By using this service, users are required to agree to the following terms:
-The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
-
-For an optimal experience, please use a desktop computer for this demo, as mobile devices may compromise its quality.
-""")
-
-learn_more_markdown = ("""
-### License
-The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
-""")
-
-block_css = """
-#buttons button {
-    min-width: min(120px,100%);
-}
-"""

 textbox = gr.Textbox(
     show_label=False,
-    placeholder="Prompt is fixed: What material is this item and how to dispose of it?",
     container=False,
     interactive=False
 )

-with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
     state = gr.State()

-    gr.Markdown(title_markdown)

     with gr.Row():
         with gr.Column(scale=3):
@@ -223,36 +155,27 @@ with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
                 value="Default",
                 label="Preprocess for non-square image", visible=False)

-
-            #cur_dir = os.path.dirname(os.path.abspath(__file__))
             cur_dir = './cumo/serve'
-            default_prompt = "<image>\nWhat material is this item and how to dispose of it?"
             gr.Examples(examples=[
-                [f"{cur_dir}/examples/0165 CB.jpg", default_prompt],
-                [f"{cur_dir}/examples/0225 PA.jpg", default_prompt],
-                [f"{cur_dir}/examples/0787 GM.jpg", default_prompt],
-                [f"{cur_dir}/examples/1396 A.jpg", default_prompt],
-                [f"{cur_dir}/examples/2001 P.jpg", default_prompt],
-                [f"{cur_dir}/examples/2658 PE.jpg", default_prompt],
-                [f"{cur_dir}/examples/3113 R.jpg", default_prompt],
-                [f"{cur_dir}/examples/3750 RPC.jpg", default_prompt],
-                [f"{cur_dir}/examples/5033 CC.jpg", default_prompt],
-                [f"{cur_dir}/examples/5307 B.jpg", default_prompt],
-            ], inputs=[imagebox, textbox], cache_examples=False)
-
-
-            with gr.Accordion("Parameters", open=False) as parameter_row:
-                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
-                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
-                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

         with gr.Column(scale=8):
-            chatbot = gr.Chatbot(
-                elem_id="chatbot",
-                label="CuMo Chatbot",
-                height=650,
-                layout="panel",
-            )
             with gr.Row():
                 with gr.Column(scale=8):
                     textbox.render()
@@ -263,50 +186,18 @@ with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
             regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
             clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

-
     gr.Markdown(tos_markdown)
     gr.Markdown(learn_more_markdown)
     url_params = gr.JSON(visible=False)

-    # Register listeners
     btn_list = [regenerate_btn, clear_btn]
-    clear_btn.click(
-        clear_history,
-        None,
-        [state, chatbot, textbox, imagebox] + btn_list,
-        queue=False
-    )
-
-    regenerate_btn.click(
-        delete_text,
-        [state, image_process_mode],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    ).then(
-        generate,
-        [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    )
-    textbox.submit(
-        add_text,
-        [state, imagebox, textbox, image_process_mode],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    ).then(
-        generate,
-        [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    )
-
-    submit_btn.click(
-        add_text,
-        [state, imagebox, textbox, image_process_mode],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    ).then(
-        generate,
-        [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
-        [state, chatbot, textbox, imagebox] + btn_list,
-    )
-
-demo.queue(
-    status_update_rate=10,
-    api_open=False
-).launch()
 
+# --- Imports remain unchanged ---
 import subprocess
 import sys
 import os
 from transformers import TextIteratorStreamer
 import argparse
 import time
 import subprocess
 import spaces
 import cumo.serve.gradio_web_server as gws
 from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
 import datetime
 import json
 import gradio as gr
 import requests
 from PIL import Image
 from cumo.conversation import (default_conversation, conv_templates, SeparatorStyle)
 from cumo.constants import LOGDIR
 from cumo.model.language_model.llava_mistral import LlavaMistralForCausalLM
 from cumo.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
 import hashlib
 import torch
 import io
 from cumo.constants import WORKER_HEART_BEAT_INTERVAL
+from cumo.utils import (build_logger, server_error_msg, pretty_print_semaphore)
 from cumo.model.builder import load_pretrained_model
 from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
 from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 from threading import Thread

+# --- Model Setup ---
 headers = {"User-Agent": "CuMo"}
 no_change_btn = gr.Button()
 enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)
 
 )
 model.config.training = False

+# --- Prompt ---
 FIXED_PROMPT = "<image>\nWhat material is this item and how to dispose of it?"

+# --- Functions ---
 def clear_history():
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

 def add_text(state, imagebox, textbox, image_process_mode):
     if state is None:
         state = conv_templates[conv_mode].copy()
     if imagebox is not None:
+        image = Image.open(imagebox).convert('RGB')
+        textbox = (FIXED_PROMPT, image, image_process_mode)
+        state.append_message(state.roles[0], textbox)
+        state.append_message(state.roles[1], None)
+        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

 def delete_text(state, image_process_mode):
     state.messages[-1][-1] = None

         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
     yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

 @spaces.GPU
 def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
+    prompt = FIXED_PROMPT
     images = state.get_images(return_pil=True)
     ori_prompt = prompt
     num_image_tokens = 0

+    if images and len(images) > 0:
+        if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+            raise ValueError("Number of images does not match number of <image> tokens in prompt")
+        image_sizes = [image.size for image in images]
+        images = process_images(images, image_processor, model.config)
+        if isinstance(images, list):
+            images = [image.to(model.device, dtype=torch.float16) for image in images]
         else:
+            images = images.to(model.device, dtype=torch.float16)
+        replace_token = DEFAULT_IMAGE_TOKEN
+        if getattr(model.config, 'mm_use_im_start_end', False):
+            replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+        num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
         image_args = {"images": images, "image_sizes": image_sizes}
     else:
         image_args = {}

     max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+    max_new_tokens = min(512, max_context_length - input_ids.shape[-1] - num_image_tokens)
     if max_new_tokens < 1:
+        yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation.", "error_code": 0}).encode() + b"\0"
         return

+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
     thread = Thread(target=model.generate, kwargs=dict(
         inputs=input_ids,
+        do_sample=(temperature > 0.001),
         temperature=temperature,
         top_p=top_p,
         max_new_tokens=max_new_tokens,

     ))
     thread.start()
     generated_text = ''
+    stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
+
     for new_text in streamer:
         generated_text += new_text
         if generated_text.endswith(stop_str):

     yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
     torch.cuda.empty_cache()

+# --- UI Setup ---
 textbox = gr.Textbox(
     show_label=False,
+    placeholder="Prompt is fixed: What material is this item and how to dispose of it?",
     container=False,
     interactive=False
 )

+with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css="""
+#buttons button {
+    min-width: min(120px,100%);
+}
+""") as demo:
     state = gr.State()

+    gr.Markdown("# CuMo: Trained for waste management")
+    gr.Markdown(f"**Prompt:** `{FIXED_PROMPT}`")

     with gr.Row():
         with gr.Column(scale=3):

                 value="Default",
                 label="Preprocess for non-square image", visible=False)

             cur_dir = './cumo/serve'
             gr.Examples(examples=[
+                [f"{cur_dir}/examples/0165 CB.jpg"],
+                [f"{cur_dir}/examples/0225 PA.jpg"],
+                [f"{cur_dir}/examples/0787 GM.jpg"],
+                [f"{cur_dir}/examples/1396 A.jpg"],
+                [f"{cur_dir}/examples/2001 P.jpg"],
+                [f"{cur_dir}/examples/2658 PE.jpg"],
+                [f"{cur_dir}/examples/3113 R.jpg"],
+                [f"{cur_dir}/examples/3750 RPC.jpg"],
+                [f"{cur_dir}/examples/5033 CC.jpg"],
+                [f"{cur_dir}/examples/5307 B.jpg"],
+            ], inputs=[imagebox], cache_examples=False)
+
+            with gr.Accordion("Parameters", open=False):
+                temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
+                top_p = gr.Slider(0.0, 1.0, value=0.7, step=0.1, interactive=True, label="Top P")
+                max_output_tokens = gr.Slider(0, 1024, value=512, step=64, interactive=True, label="Max output tokens")

         with gr.Column(scale=8):
+            chatbot = gr.Chatbot(elem_id="chatbot", label="CuMo Chatbot", height=650, layout="panel")
             with gr.Row():
                 with gr.Column(scale=8):
                     textbox.render()

             regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
             clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

     gr.Markdown(tos_markdown)
     gr.Markdown(learn_more_markdown)
     url_params = gr.JSON(visible=False)

+    # --- Event Bindings ---
     btn_list = [regenerate_btn, clear_btn]
+    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)
+    regenerate_btn.click(delete_text, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
+    ).then(generate, [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens], [state, chatbot, textbox, imagebox] + btn_list)
+    textbox.submit(add_text, [state, imagebox, textbox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
+    ).then(generate, [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens], [state, chatbot, textbox, imagebox] + btn_list)
+    submit_btn.click(add_text, [state, imagebox, textbox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
+    ).then(generate, [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens], [state, chatbot, textbox, imagebox] + btn_list)
+
+demo.queue(status_update_rate=10, api_open=False).launch()
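
Note on the new token budget in generate(): the reply length is now clamped to whatever still fits in the context window after the tokenized prompt and the expanded image patch tokens. A minimal standalone sketch of that arithmetic, with illustrative numbers only (the real values come from model.config and the vision tower; the 576-patch count below is an assumption, not read from the model):

    # Sketch of generate()'s budget clamp:
    # max_new_tokens = min(512, context window - prompt tokens - image patch tokens)
    def reply_budget(max_context_length: int, prompt_len: int,
                     num_image_tokens: int, cap: int = 512) -> int:
        # Mirrors: min(512, max_context_length - input_ids.shape[-1] - num_image_tokens)
        return min(cap, max_context_length - prompt_len - num_image_tokens)

    print(reply_budget(2048, 40, 576))    # 512  -> the cap wins
    print(reply_budget(2048, 1500, 576))  # -28  -> below 1: error message, early return

A budget below 1 triggers the "Exceeds max token length" message and returns before the streamer thread is ever started.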