rabbit19731 committed on
Commit c1582fd · verified · 1 parent: 8a5357d

Update app.py

Files changed (1): app.py (+80, −20)
app.py CHANGED
@@ -11,10 +11,18 @@ from threading import Thread
 import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, TextIteratorStreamer
+from moviepy.editor import VideoFileClip
+from PIL import Image
 
 model_name = 'AIDC-AI/Ovis2-16B'
+
 use_thread = False
 
+IMAGE_MAX_PARTITION = 16
+
+VIDEO_FRAME_NUMS = 32
+VIDEO_MAX_PARTITION = 1
+
 # load model
 model = AutoModelForCausalLM.from_pretrained(model_name,
                                              torch_dtype=torch.bfloat16,
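Reviewer note on the new imports: `from moviepy.editor import VideoFileClip` resolves only on moviepy 1.x; the `editor` namespace was removed in moviepy 2.0, so an unpinned environment can break the Space at startup. A version-tolerant sketch (a suggestion for the deployment environment, not part of this commit):

```python
# Hedged import: moviepy 2.x dropped the `moviepy.editor` module.
# Either pin `moviepy<2.0` in requirements.txt or fall back like this.
try:
    from moviepy.editor import VideoFileClip  # moviepy 1.x
except ImportError:
    from moviepy import VideoFileClip  # moviepy >= 2.0
```

The new constants encode the trade-off between the two modes: a single image may be split into up to `IMAGE_MAX_PARTITION = 16` partitions by the model's preprocessing, while a video contributes `VIDEO_FRAME_NUMS = 32` frames at `VIDEO_MAX_PARTITION = 1` partition each, presumably to keep the visual token budget bounded.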
@@ -49,8 +57,8 @@ def submit_chat(chatbot, text_input):
     return chatbot ,''
 
 @spaces.GPU
-def ovis_chat(chatbot: List[List[str]], image_input: Any):
-    conversations, model_inputs = prepare_inputs(chatbot, image_input)
+def ovis_chat(chatbot: List[List[str]], image_input: Any, video_input: Any):
+    conversations, model_inputs = prepare_inputs(chatbot, image_input, video_input)
     gen_kwargs = initialize_gen_kwargs()
 
     with torch.inference_mode():
@@ -74,7 +82,7 @@ def ovis_chat(chatbot: List[List[str]], image_input: Any):
     log_conversation(chatbot)
 
 
-def prepare_inputs(chatbot: List[List[str]], image_input: Any):
+def prepare_inputs(chatbot: List[List[str]], image_input: Any, video_input: Any):
     # conversations = [{
     #     "from": "system",
     #     "value": "You are a helpful assistant, and your task is to provide reliable and structured responses to users."
@@ -95,10 +103,30 @@ def prepare_inputs(chatbot: List[List[str]], image_input: Any):
             if conv["from"] == "human":
                 conv["value"] = f'{image_placeholder}\n{conv["value"]}'
                 break
+        max_partition = IMAGE_MAX_PARTITION
+        image_input = [image_input]
+
+    if video_input is not None:
+        for conv in conversations:
+            if conv["from"] == "human":
+                conv["value"] = f'{image_placeholder}\n' * VIDEO_FRAME_NUMS + f'{conv["value"]}'
+                break
+        # extract video frames here
+        with VideoFileClip(video_input) as clip:
+            total_frames = int(clip.fps * clip.duration)
+            if total_frames <= VIDEO_FRAME_NUMS:
+                sampled_indices = range(total_frames)
+            else:
+                stride = total_frames / VIDEO_FRAME_NUMS
+                sampled_indices = [min(total_frames - 1, int((stride * i + stride * (i + 1)) / 2)) for i in range(VIDEO_FRAME_NUMS)]
+            frames = [clip.get_frame(index / clip.fps) for index in sampled_indices]
+            frames = [Image.fromarray(frame, mode='RGB') for frame in frames]
+        image_input = frames
+        max_partition = VIDEO_MAX_PARTITION
 
     logger.info(conversations)
 
-    prompt, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input], max_partition=16)
+    prompt, input_ids, pixel_values = model.preprocess_inputs(conversations, image_input, max_partition=max_partition)
     attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
 
     model_inputs = {
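The sampling logic above takes the midpoint frame of each of `VIDEO_FRAME_NUMS` equal segments: `int((stride * i + stride * (i + 1)) / 2)` is the center of segment `i`, clamped to the last valid index, and `index / clip.fps` converts that frame index back into the timestamp `clip.get_frame` expects. A standalone sketch of just the index arithmetic (hypothetical helper name; same math as the diff):

```python
# Midpoint frame sampling as in prepare_inputs (sketch, not the committed code).
def sample_frame_indices(total_frames: int, num_frames: int = 32) -> list:
    if total_frames <= num_frames:
        # Short clips: keep every frame.
        return list(range(total_frames))
    # Otherwise: one frame from the middle of each of `num_frames` equal segments.
    stride = total_frames / num_frames
    return [min(total_frames - 1, int((stride * i + stride * (i + 1)) / 2))
            for i in range(num_frames)]

print(sample_frame_indices(96))  # [1, 4, 7, ..., 94]: one frame per 3-frame segment
print(sample_frame_indices(10))  # [0, 1, ..., 9]: short clip passes through whole
```

One caveat worth checking: `Image.fromarray(frame, mode='RGB')` assumes `get_frame` returns an H×W×3 uint8 array, which holds for ordinary RGB video but not for clips carrying an alpha channel.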
@@ -115,7 +143,7 @@ def log_conversation(chatbot):
     logger.info("[OVIS_CONV_END]")
 
 def clear_chat():
-    return [], None, ""
+    return [], None, "", None
 
 with open(f"{cur_dir}/resource/logo.svg", "r", encoding="utf-8") as svg_file:
    svg_content = svg_file.read()
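Cross-check: the four cleared values map positionally onto `outputs=[chatbot, image_input, text_input, video_input]` in the `clear_btn.click` wiring below, so the new trailing `None` is what resets the video component.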
@@ -164,26 +192,58 @@ with gr.Blocks(title=model_name.split('/')[-1], theme=gr.themes.Ocean()) as demo
     gr.HTML(html)
     with gr.Row():
         with gr.Column(scale=3):
-            image_input = gr.Image(label="image", height=350, type="pil")
-            gr.Examples(
-                examples=[
-                    [f"{cur_dir}/examples/ovis2_math0.jpg", "Each face of the polyhedron shown is either a triangle or a square. Each square borders 4 triangles, and each triangle borders 3 squares. The polyhedron has 6 squares. How many triangles does it have?\n\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."],
-                    [f"{cur_dir}/examples/ovis2_math1.jpg", "A large square touches another two squares, as shown in the picture. The numbers inside the smaller squares indicate their areas. What is the area of the largest square?\n\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."],
-                    [f"{cur_dir}/examples/ovis2_figure0.png", "Explain this model."],
-                    [f"{cur_dir}/examples/ovis2_figure1.png", "Organize the notes about GRPO in the figure."],
-                    [f"{cur_dir}/examples/ovis2_multi0.jpg", "Posso avere un frappuccino e un caffè americano di taglia M? Quanto costa in totale?"],
-                ],
-                inputs=[image_input, text_input]
-            )
+            input_type = gr.Radio(choices=["image + prompt", "video + prompt"], label="Select input type:", value="image + prompt", elem_classes="my_radio")
+
+            image_input = gr.Image(label="image", height=350, type="pil", visible=True)
+            video_input = gr.Video(label="video", height=350, format='mp4', visible=False)
+            with gr.Column(visible=True) as image_examples_col:
+                image_examples = gr.Examples(
+                    examples=[
+                        [f"{cur_dir}/examples/ovis2_math0.jpg", "Each face of the polyhedron shown is either a triangle or a square. Each square borders 4 triangles, and each triangle borders 3 squares. The polyhedron has 6 squares. How many triangles does it have?\n\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."],
+                        [f"{cur_dir}/examples/ovis2_math1.jpg", "A large square touches another two squares, as shown in the picture. The numbers inside the smaller squares indicate their areas. What is the area of the largest square?\n\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."],
+                        [f"{cur_dir}/examples/ovis2_figure0.png", "Explain this model."],
+                        [f"{cur_dir}/examples/ovis2_figure1.png", "Organize the notes about GRPO in the figure."],
+                        [f"{cur_dir}/examples/ovis2_multi0.jpg", "Posso avere un frappuccino e un caffè americano di taglia M? Quanto costa in totale?"],
+                    ],
+                    inputs=[image_input, text_input]
+                )
+
+            def update_visibility_on_example(video_input, text_input):
+                return (gr.update(visible=True), text_input)
+
+            with gr.Column(visible=False) as video_examples_col:
+                video_examples = gr.Examples(
+                    examples=[
+                        [f"{cur_dir}/examples/video_demo_1.mp4", "Describe the video."]
+                    ],
+                    inputs=[video_input, text_input],
+                    fn = update_visibility_on_example,
+                    run_on_click = True,
+                    outputs=[video_input, text_input]
+                )
+
         with gr.Column(scale=7):
             chatbot = gr.Chatbot(label="Ovis", layout="panel", height=600, show_copy_button=True, latex_delimiters=latex_delimiters_set)
             text_input.render()
         with gr.Row():
             send_btn = gr.Button("Send", variant="primary")
             clear_btn = gr.Button("Clear", variant="secondary")
-
-    send_click_event = send_btn.click(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(ovis_chat,[chatbot, image_input],chatbot)
-    submit_event = text_input.submit(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(ovis_chat,[chatbot, image_input],chatbot)
-    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])
+
+    def update_input_and_clear(selected):
+        if selected == "image + prompt":
+            visibility_updates = (gr.update(visible=True), gr.update(visible=False),
+                                  gr.update(visible=True), gr.update(visible=False))
+        else:
+            visibility_updates = (gr.update(visible=False), gr.update(visible=True),
+                                  gr.update(visible=False), gr.update(visible=True))
+        clear_chat_outputs = clear_chat()
+        return visibility_updates + clear_chat_outputs
+
+    input_type.change(fn=update_input_and_clear, inputs=input_type,
+                      outputs=[image_input, video_input, image_examples_col, video_examples_col, chatbot, image_input, text_input, video_input])
+
+    send_click_event = send_btn.click(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(ovis_chat,[chatbot, image_input, video_input],chatbot)
+    submit_event = text_input.submit(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(ovis_chat,[chatbot, image_input, video_input],chatbot)
+    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input, video_input])
 
 demo.launch()
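Wiring note: `update_input_and_clear` returns eight values, so `image_input` and `video_input` each appear twice in the `outputs` list of `input_type.change`, once to receive a visibility update and once to receive the cleared value from `clear_chat()`; Gradio maps return values to outputs positionally, so the duplication is deliberate. The same pattern in isolation (hypothetical component names, not this app):

```python
# Minimal sketch of a radio-driven visibility toggle in Gradio.
import gradio as gr

with gr.Blocks() as demo:
    mode = gr.Radio(["image + prompt", "video + prompt"], value="image + prompt")
    img = gr.Image(visible=True)
    vid = gr.Video(visible=False)

    def toggle(selected):
        show_image = (selected == "image + prompt")
        # One gr.update per output component, applied in order.
        return gr.update(visible=show_image), gr.update(visible=not show_image)

    mode.change(fn=toggle, inputs=mode, outputs=[img, vid])

demo.launch()
```

Likewise, `update_visibility_on_example` runs when a cached example is clicked (`run_on_click=True`), forcing the `video_input` component visible before its value is set.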
 