prithivMLmods committed on
Commit 91d2c01 · verified · 1 Parent(s): 1f9408f

Update app.py

Files changed (1)
  1. app.py +334 -224
app.py CHANGED
@@ -1,239 +1,349 @@
- import gradio as gr
  from threading import Thread
- import time
- from PIL import Image
- import torch
- import spaces
  import cv2
- import numpy as np
- from transformers import (
-     AutoProcessor,
-     AutoModelForVision2Seq,
-     TextIteratorStreamer
  )
- # Helper function to return a progress bar HTML snippet.
- def progress_bar_html(label: str) -> str:
-     return f'''
-     <div style="display: flex; align-items: center;">
-         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: #FFB6C1; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-         </div>
-     </div>
-     <style>
-     @keyframes loading {{
-         0% {{ transform: translateX(-100%); }}
-         100% {{ transform: translateX(100%); }}
-     }}
-     </style>
-     '''
-
- #adding examples
- examples=[
-     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
-     [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
-     [{"text": "@video-infer Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
-     [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
-     [{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
-     [{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
-     [{"text": "@video-infer What is in the video ?", "files": ["examples/redlight.mp4"]}],
-     [{"text": "Transcription of the letter", "files": ["examples/222.png"]}],

- ]

- # Helper: Downsample video to extract a fixed number of frames.
- def downsample_video(video_path, num_frames=22):
-     cap = cv2.VideoCapture(video_path)
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = cap.get(cv2.CAP_PROP_FPS)
-     # Calculate evenly spaced frame indices.
-     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
      frames = []
-     for idx in frame_indices:
-         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
-         ret, frame = cap.read()
-         if ret:
-             # Convert BGR to RGB and then to a PIL image.
-             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             frame = Image.fromarray(frame)
-             frames.append(frame)
-     cap.release()
      return frames

- # Load processor and model.
- processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
- model = AutoModelForVision2Seq.from_pretrained(
-     "HuggingFaceTB/SmolVLM-Instruct",
-     torch_dtype=torch.bfloat16,
- ).to("cuda")
-
- @spaces.GPU
- def model_inference(
-     input_dict, history, decoding_strategy, temperature, max_new_tokens,
-     repetition_penalty, top_p
- ):
-     text = input_dict["text"]
-
-     # --- Video Inference Branch ---
-     if text.lower().startswith("@video-infer"):
-         # Remove the command prefix to get the prompt.
-         prompt_text = text[len("@video-infer"):].strip()
-         if not input_dict["files"]:
-             yield "Error: Please provide a video file for @video-infer."
-             return
-         # Assume the first file is a video.
-         video_file = input_dict["files"][0]
-         frames = downsample_video(video_file)
-         if not frames:
-             yield "Error: Could not extract frames from the video."
-             return
-         # Build a chat content: include the user prompt and then each frame labeled.
-         content = [{"type": "text", "text": prompt_text}]
-         for idx, frame in enumerate(frames):
-             content.append({"type": "text", "text": f"Frame {idx+1}:"})
-             content.append({"type": "image", "image": frame})
-         resulting_messages = [{
-             "role": "user",
-             "content": content
-         }]
-         prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-         # Process the extracted frames as images.
-         inputs = processor(text=prompt, images=[frames], return_tensors="pt")
-         inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-         # Setup generation parameters.
-         generation_args = {
-             "max_new_tokens": max_new_tokens,
-             "repetition_penalty": repetition_penalty,
-         }
-         assert decoding_strategy in ["Greedy", "Top P Sampling"]
-         if decoding_strategy == "Greedy":
-             generation_args["do_sample"] = False
-         elif decoding_strategy == "Top P Sampling":
-             generation_args["temperature"] = temperature
-             generation_args["do_sample"] = True
-             generation_args["top_p"] = top_p
-
-         generation_args.update(inputs)
-
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-         buffer = ""
-         thread = Thread(target=model.generate, kwargs=generation_args)
-         thread.start()
-         yield progress_bar_html("Processing Video with SmolVLM")
-         for new_text in streamer:
-             buffer += new_text
-             time.sleep(0.01)
-             yield buffer
          return

-     # --- Default Image Inference Branch ---
-     # Process input images if provided.
-     if len(input_dict["files"]) > 1:
-         images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
-     elif len(input_dict["files"]) == 1:
-         images = [Image.open(input_dict["files"][0]).convert("RGB")]
-     else:
-         images = []
-
-     # Validate input.
-     if text == "" and not images:
-         gr.Error("Please input a query and optionally image(s).")
-     if text == "" and images:
-         gr.Error("Please input a text query along with the image(s).")
-
-     resulting_messages = [{
-         "role": "user",
-         "content": [{"type": "image"} for _ in range(len(images))] + [
-             {"type": "text", "text": text}
-         ]
-     }]
-     prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-     inputs = processor(text=prompt, images=[images], return_tensors="pt")
-     inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-     generation_args = {
-         "max_new_tokens": max_new_tokens,
-         "repetition_penalty": repetition_penalty,
-     }
-     assert decoding_strategy in ["Greedy", "Top P Sampling"]
-     if decoding_strategy == "Greedy":
-         generation_args["do_sample"] = False
-     elif decoding_strategy == "Top P Sampling":
-         generation_args["temperature"] = temperature
-         generation_args["do_sample"] = True
-         generation_args["top_p"] = top_p
-
-     generation_args.update(inputs)
-
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-     buffer = ""
-     thread = Thread(target=model.generate, kwargs=generation_args)
-     thread.start()
-     yield progress_bar_html("Processing Video with SmolVLM")
-     for new_text in streamer:
-         buffer += new_text
-         time.sleep(0.01)
-         yield buffer
-
- # Gradio ChatInterface: Allow both image and video file types.
  demo = gr.ChatInterface(
-     fn=model_inference,
-     description="# **SmolVLM Video Infer `@video-infer for video understanding`**",
-     examples=examples,
-     textbox=gr.MultimodalTextbox(
-         label="Query Input",
-         file_types=["image", "video"],
-         file_count="multiple"
-     ),
-     stop_btn="Stop Generation",
      multimodal=True,
      additional_inputs=[
-         gr.Radio(
-             ["Top P Sampling", "Greedy"],
-             value="Greedy",
-             label="Decoding strategy",
-             info="Higher values is equivalent to sampling more low-probability tokens.",
-         ),
-         gr.Slider(
-             minimum=0.0,
-             maximum=5.0,
-             value=0.4,
-             step=0.1,
-             interactive=True,
-             label="Sampling temperature",
-             info="Higher values will produce more diverse outputs.",
-         ),
-         gr.Slider(
-             minimum=8,
-             maximum=1024,
-             value=512,
-             step=1,
-             interactive=True,
-             label="Maximum number of new tokens to generate",
-         ),
-         gr.Slider(
-             minimum=0.01,
-             maximum=5.0,
-             value=1.2,
-             step=0.01,
-             interactive=True,
-             label="Repetition penalty",
-             info="1.0 is equivalent to no penalty",
-         ),
-         gr.Slider(
-             minimum=0.01,
-             maximum=0.99,
-             value=0.8,
-             step=0.01,
-             interactive=True,
-             label="Top P",
-             info="Higher values is equivalent to sampling more low-probability tokens.",
-         )
-     ],
-     cache_examples=False
  )

- demo.launch(debug=True)

+ #!/usr/bin/env python
+
+ import os
+ import re
+ import tempfile
+ from collections.abc import Iterator
  from threading import Thread
+
  import cv2
+ import gradio as gr
+ import spaces
+ import torch
+ from loguru import logger
+ from PIL import Image
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
+
+ model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
+ processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
+ model = Gemma3ForConditionalGeneration.from_pretrained(
+     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
  )

+ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
+
+
+ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
+     image_count = 0
+     video_count = 0
+     for path in paths:
+         if path.endswith(".mp4"):
+             video_count += 1
+         else:
+             image_count += 1
+     return image_count, video_count
+
+
+ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
+     image_count = 0
+     video_count = 0
+     for item in history:
+         if item["role"] != "user" or isinstance(item["content"], str):
+             continue
+         if item["content"][0].endswith(".mp4"):
+             video_count += 1
+         else:
+             image_count += 1
+     return image_count, video_count

+
+ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
+     new_image_count, new_video_count = count_files_in_new_message(message["files"])
+     history_image_count, history_video_count = count_files_in_history(history)
+     image_count = history_image_count + new_image_count
+     video_count = history_video_count + new_video_count
+     if video_count > 1:
+         gr.Warning("Only one video is supported.")
+         return False
+     if video_count == 1:
+         if image_count > 0:
+             gr.Warning("Mixing images and videos is not allowed.")
+             return False
+         if "<image>" in message["text"]:
+             gr.Warning("Using <image> tags with video files is not supported.")
+             return False
+         # TODO: Add frame count validation for videos similar to image count limits  # noqa: FIX002, TD002, TD003
+     if video_count == 0 and image_count > MAX_NUM_IMAGES:
+         gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
+         return False
+     if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
+         gr.Warning("The number of <image> tags in the text does not match the number of images.")
+         return False
+     return True
+
+
+ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
+     vidcap = cv2.VideoCapture(video_path)
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     frame_interval = int(fps / 3)
      frames = []
+
+     for i in range(0, total_frames, frame_interval):
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+
+     vidcap.release()
      return frames

+
+ def process_video(video_path: str) -> list[dict]:
+     content = []
+     frames = downsample_video(video_path)
+     for frame in frames:
+         pil_image, timestamp = frame
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
+             pil_image.save(temp_file.name)
+             content.append({"type": "text", "text": f"Frame {timestamp}:"})
+             content.append({"type": "image", "url": temp_file.name})
+     logger.debug(f"{content=}")
+     return content
+
+
+ def process_interleaved_images(message: dict) -> list[dict]:
+     logger.debug(f"{message['files']=}")
+     parts = re.split(r"(<image>)", message["text"])
+     logger.debug(f"{parts=}")
+
+     content = []
+     image_index = 0
+     for part in parts:
+         logger.debug(f"{part=}")
+         if part == "<image>":
+             content.append({"type": "image", "url": message["files"][image_index]})
+             logger.debug(f"file: {message['files'][image_index]}")
+             image_index += 1
+         elif part.strip():
+             content.append({"type": "text", "text": part.strip()})
+         elif isinstance(part, str) and part != "<image>":
+             content.append({"type": "text", "text": part})
+     logger.debug(f"{content=}")
+     return content
+
+
+ def process_new_user_message(message: dict) -> list[dict]:
+     if not message["files"]:
+         return [{"type": "text", "text": message["text"]}]
+
+     if message["files"][0].endswith(".mp4"):
+         return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
+
+     if "<image>" in message["text"]:
+         return process_interleaved_images(message)
+
+     return [
+         {"type": "text", "text": message["text"]},
+         *[{"type": "image", "url": path} for path in message["files"]],
+     ]
+
+
+ def process_history(history: list[dict]) -> list[dict]:
+     messages = []
+     current_user_content: list[dict] = []
+     for item in history:
+         if item["role"] == "assistant":
+             if current_user_content:
+                 messages.append({"role": "user", "content": current_user_content})
+                 current_user_content = []
+             messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
+         else:
+             content = item["content"]
+             if isinstance(content, str):
+                 current_user_content.append({"type": "text", "text": content})
+             else:
+                 current_user_content.append({"type": "image", "url": content[0]})
+     return messages
+
+
+ @spaces.GPU(duration=120)
+ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
+     if not validate_media_constraints(message, history):
+         yield ""
          return

+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+     messages.extend(process_history(history))
+     messages.append({"role": "user", "content": process_new_user_message(message)})
+
+     inputs = processor.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     ).to(device=model.device, dtype=torch.bfloat16)
+
+     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     output = ""
+     for delta in streamer:
+         output += delta
+         yield output
+
+
+ examples = [
+     [
+         {
+             "text": "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations.",
+             "files": [],
+         }
+     ],
+     [
+         {
+             "text": "Write the matplotlib code to generate the same bar chart.",
+             "files": ["assets/additional-examples/barchart.png"],
+         }
+     ],
+     [
+         {
+             "text": "What is odd about this video?",
+             "files": ["assets/additional-examples/tmp.mp4"],
+         }
+     ],
+     [
+         {
+             "text": "I already have this supplement <image> and I want to buy this one <image>. Any warnings I should know about?",
+             "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
+         }
+     ],
+     [
+         {
+             "text": "Write a poem inspired by the visual elements of the images.",
+             "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
+         }
+     ],
+     [
+         {
+             "text": "Compose a short musical piece inspired by the visual elements of the images.",
+             "files": [
+                 "assets/sample-images/07-1.png",
+                 "assets/sample-images/07-2.png",
+                 "assets/sample-images/07-3.png",
+                 "assets/sample-images/07-4.png",
+             ],
+         }
+     ],
+     [
+         {
+             "text": "Write a short story about what might have happened in this house.",
+             "files": ["assets/sample-images/08.png"],
+         }
+     ],
+     [
+         {
+             "text": "Create a short story based on the sequence of images.",
+             "files": [
+                 "assets/sample-images/09-1.png",
+                 "assets/sample-images/09-2.png",
+                 "assets/sample-images/09-3.png",
+                 "assets/sample-images/09-4.png",
+                 "assets/sample-images/09-5.png",
+             ],
+         }
+     ],
+     [
+         {
+             "text": "Describe the creatures that would live in this world.",
+             "files": ["assets/sample-images/10.png"],
+         }
+     ],
+     [
+         {
+             "text": "Read text in the image.",
+             "files": ["assets/additional-examples/1.png"],
+         }
+     ],
+     [
+         {
+             "text": "When is this ticket dated and how much did it cost?",
+             "files": ["assets/additional-examples/2.png"],
+         }
+     ],
+     [
+         {
+             "text": "Read the text in the image into markdown.",
+             "files": ["assets/additional-examples/3.png"],
+         }
+     ],
+     [
+         {
+             "text": "Evaluate this integral.",
+             "files": ["assets/additional-examples/4.png"],
+         }
+     ],
+     [
+         {
+             "text": "caption this image",
+             "files": ["assets/sample-images/01.png"],
+         }
+     ],
+     [
+         {
+             "text": "What's the sign says?",
+             "files": ["assets/sample-images/02.png"],
+         }
+     ],
+     [
+         {
+             "text": "Compare and contrast the two images.",
+             "files": ["assets/sample-images/03.png"],
+         }
+     ],
+     [
+         {
+             "text": "List all the objects in the image and their colors.",
+             "files": ["assets/sample-images/04.png"],
+         }
+     ],
+     [
+         {
+             "text": "Describe the atmosphere of the scene.",
+             "files": ["assets/sample-images/05.png"],
+         }
+     ],
+ ]
+
+ DESCRIPTION = """\
+ <img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />
+
+ This is a demo of Gemma 3 12B it, a vision language model with outstanding performance on a wide range of tasks.
+ You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input.
+ """
+
  demo = gr.ChatInterface(
+     fn=run,
+     type="messages",
+     chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
+     textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True),
      multimodal=True,
      additional_inputs=[
+         gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
+         gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
+     ],
+     stop_btn=False,
+     title="Gemma 3 12B IT",
+     description=DESCRIPTION,
+     examples=examples,
+     run_examples_on_click=False,
+     cache_examples=False,
+     css_paths="style.css",
+     delete_cache=(1800, 1800),
  )

+ if __name__ == "__main__":
+     demo.launch()
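
To sanity-check the new streaming path outside the Gradio UI, the sketch below drives the updated run() generator directly with one of the bundled example prompts. It is illustrative only and not part of the commit: it assumes the revised app.py is importable as `app`, that the Gemma 3 weights and the example assets are available locally, and that the message dict mirrors what gr.MultimodalTextbox passes to the chat handler.

# Sketch only (not in the commit): call the Space's run() generator directly.
from app import run  # assumes the updated app.py above is importable as `app`

message = {"text": "caption this image", "files": ["assets/sample-images/01.png"]}
reply = ""
for partial in run(message, history=[], system_prompt="You are a helpful assistant.", max_new_tokens=128):
    reply = partial  # run() yields the accumulated text, so the last value is the full reply
print(reply)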