merve (HF staff) committed
Commit 7b92c9b · verified · Parent: 919942b

Update app.py

Files changed (1): app.py (+103 -18)
app.py CHANGED
@@ -6,6 +6,7 @@ from threading import Thread
 import gradio as gr
 import spaces
 import torch
+import re
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

 model_id = "google/gemma-3-12b-it"
@@ -14,9 +15,69 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )

+import cv2
+from PIL import Image
+import numpy as np
+import tempfile
+
+def downsample_video(video_path):
+    vidcap = cv2.VideoCapture(video_path)
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    frame_interval = int(fps / 3)
+    frames = []
+
+    for i in range(0, total_frames, frame_interval):
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+
+    vidcap.release()
+    return frames
+

 def process_new_user_message(message: dict) -> list[dict]:
-    return [{"type": "text", "text": message["text"]}, *[{"type": "image", "url": path} for path in message["files"]]]
+    if message["files"]:
+        if "<image>" in message["text"]:
+            content = []
+            print("message[files]", message["files"])
+            parts = re.split(r'(<image>)', message["text"])
+            image_index = 0
+            print("parts", parts)
+            for part in parts:
+                print("part", part)
+                if part == "<image>":
+                    content.append({"type": "image", "url": message["files"][image_index]})
+                    print("file", message["files"][image_index])
+                    image_index += 1
+                elif part.strip():
+                    content.append({"type": "text", "text": part.strip()})
+                elif isinstance(part, str) and not part == "<image>":
+                    content.append({"type": "text", "text": part})
+            print(content)
+            return content
+        elif message["files"][0].endswith(".mp4"):
+            content = []
+            video = message["files"].pop(0)
+            frames = downsample_video(video)
+            for frame in frames:
+                pil_image, timestamp = frame
+                with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+                    pil_image.save(temp_file.name)
+                    content.append({"type": "text", "text": f"Frame {timestamp}:"})
+                    content.append({"type": "image", "url": temp_file.name})
+            print(content)
+            return content
+        else:
+            # non interleaved images
+            return [{"type": "text", "text": message["text"]}, *[{"type": "image", "url": path} for path in message["files"]]]
+    else:
+        return [{"type": "text", "text": message["text"]}]


 def process_history(history: list[dict]) -> list[dict]:
@@ -71,32 +132,26 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
 examples = [
     [
         {
-            "text": "caption this image",
-            "files": ["assets/sample-images/01.png"],
+            "text": "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations.",
+            "files": [],
         }
     ],
     [
         {
-            "text": "What's the sign says?",
-            "files": ["assets/sample-images/02.png"],
+            "text": "Write the matplotlib code to generate the same bar chart.",
+            "files": ["assets/sample-images/barchart.png"],
         }
     ],
     [
         {
-            "text": "Compare and contrast the two images.",
-            "files": ["assets/sample-images/03.png"],
+            "text": "What is odd about this video?",
+            "files": ["assets/sample-images/tmp.mp4"],
         }
     ],
     [
         {
-            "text": "List all the objects in the image and their colors.",
-            "files": ["assets/sample-images/04.png"],
-        }
-    ],
-    [
-        {
-            "text": "Describe the atmosphere of the scene.",
-            "files": ["assets/sample-images/05.png"],
+            "text": "I already have this supplement <image> and I want to buy this one <image>. Do they have known interactions?",
+            "files": ["assets/sample-images/pill1.png", "assets/sample-images/pill2.png"],
         }
     ],
     [
@@ -164,20 +219,50 @@ examples = [
             "files": ["assets/additional-examples/4.png"],
         }
     ],
+    [
+        {
+            "text": "caption this image",
+            "files": ["assets/sample-images/01.png"],
+        }
+    ],
+    [
+        {
+            "text": "What's the sign says?",
+            "files": ["assets/sample-images/02.png"],
+        }
+    ],
+    [
+        {
+            "text": "Compare and contrast the two images.",
+            "files": ["assets/sample-images/03.png"],
+        }
+    ],
+    [
+        {
+            "text": "List all the objects in the image and their colors.",
+            "files": ["assets/sample-images/04.png"],
+        }
+    ],
+    [
+        {
+            "text": "Describe the atmosphere of the scene.",
+            "files": ["assets/sample-images/05.png"],
+        }
+    ],
 ]

 demo = gr.ChatInterface(
     fn=run,
     type="messages",
-    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="multiple"),
+    textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple"),
     multimodal=True,
     additional_inputs=[
         gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=500),
+        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
     ],
     stop_btn=False,
     title="Gemma 3 12B it",
-    description="<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />",
+    description="<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' /><br>This is a demo of Gemma 3 12B it, a vision language model with outstanding performance on a wide range of tasks. You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input.",
     examples=examples,
     run_examples_on_click=False,
     cache_examples=False,
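For reference, the interleaving branch added to process_new_user_message() splits the prompt on "<image>" markers and emits one content entry per text segment or image. Below is a minimal standalone sketch of that behaviour (not part of the commit; the file paths are placeholders):

import re

# Hypothetical MultimodalTextbox payload, mirroring the message format the Space receives.
message = {
    "text": "I already have this supplement <image> and I want to buy this one <image>. Do they have known interactions?",
    "files": ["pill1.png", "pill2.png"],  # placeholder paths for illustration
}

content = []
image_index = 0
# The capture group keeps the "<image>" delimiters in the split result.
for part in re.split(r"(<image>)", message["text"]):
    if part == "<image>":
        content.append({"type": "image", "url": message["files"][image_index]})
        image_index += 1
    elif part.strip():
        content.append({"type": "text", "text": part.strip()})

print(content)
# [{'type': 'text', 'text': 'I already have this supplement'},
#  {'type': 'image', 'url': 'pill1.png'},
#  {'type': 'text', 'text': 'and I want to buy this one'},
#  {'type': 'image', 'url': 'pill2.png'},
#  {'type': 'text', 'text': '. Do they have known interactions?'}]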
 
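For video input, the new branch samples roughly three frames per second via downsample_video() and sends each frame to the model as a timestamped image. A standalone sanity check of that sampling stride (OpenCV required; the clip path is a placeholder, not shipped with the Space):

import cv2

video_path = "sample.mp4"  # placeholder path for illustration

vidcap = cv2.VideoCapture(video_path)
fps = vidcap.get(cv2.CAP_PROP_FPS)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
vidcap.release()

# Same stride as downsample_video(): int(fps / 3) keeps about 3 frames per second.
# The > 0 guard only protects this check against clips reporting less than 3 fps.
frame_interval = int(fps / 3)
sampled = len(range(0, total_frames, frame_interval)) if frame_interval > 0 else 0

print(f"fps={fps}, total_frames={total_frames}, frame_interval={frame_interval}, sampled={sampled}")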