Update app.py

app.py CHANGED
@@ -14,13 +14,11 @@ from loguru import logger
 from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
-# Load model and processor.
 model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
 processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
-model.eval()  # Set model to evaluation mode.
 
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
@@ -37,7 +35,6 @@ css = '''h1 {
 }
 '''
 
-
 def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -48,7 +45,6 @@ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
             image_count += 1
     return image_count, video_count
 
-
 def count_files_in_history(history: list[dict]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -61,7 +57,6 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
            image_count += 1
     return image_count, video_count
 
-
 def validate_media_constraints(message: dict, history: list[dict]) -> bool:
     new_image_count, new_video_count = count_files_in_new_message(message["files"])
     history_image_count, history_video_count = count_files_in_history(history)
@@ -85,26 +80,30 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
         return False
     return True
 
-
 def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-
+
+    max_frames = 5  # Limit to 5 frames to prevent memory overload
+    if total_frames <= max_frames:
+        indices = list(range(total_frames))
+    else:
+        indices = [int(i * (total_frames - 1) / (max_frames - 1)) for i in range(max_frames)]
+
     frames = []
-    for i in
+    for i in indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
+            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
+
     vidcap.release()
     return frames
 
-
 def process_video(video_path: str) -> list[dict]:
     content = []
     frames = downsample_video(video_path)
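Note: the hunk above caps video input at five evenly spaced frames rather than decoding a larger set. A minimal standalone sketch of the index computation it adds (the values passed in below are illustrative, not taken from the Space):

```python
# Evenly spaced frame indices, capped at max_frames; mirrors the hunk above.
def sample_indices(total_frames: int, max_frames: int = 5) -> list[int]:
    if total_frames <= max_frames:
        return list(range(total_frames))
    # Spread max_frames indices across [0, total_frames - 1], endpoints included.
    return [int(i * (total_frames - 1) / (max_frames - 1)) for i in range(max_frames)]


print(sample_indices(300))  # [0, 74, 149, 224, 299]
print(sample_indices(3))    # [0, 1, 2]
```

For a 30 fps clip, those indices for a 300-frame video map to timestamps of roughly 0.0, 2.47, 4.97, 7.47 and 9.97 seconds via the `round(i / fps, 2)` line.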
@@ -117,11 +116,11 @@ def process_video(video_path: str) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{message['files']=}")
     parts = re.split(r"(<image>)", message["text"])
     logger.debug(f"{parts=}")
+
     content = []
     image_index = 0
     for part in parts:
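Note: in the `re.split(r"(<image>)", ...)` call shown above, the capturing group keeps the `<image>` delimiters in the split result, which is what lets the loop substitute one uploaded file per tag. A quick illustration with made-up text:

```python
import re

# The capturing group keeps the "<image>" delimiters in the result list.
parts = re.split(r"(<image>)", "Compare <image> with <image> and summarize.")
print(parts)
# ['Compare ', '<image>', ' with ', '<image>', ' and summarize.']
```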
@@ -137,20 +136,21 @@ def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_new_user_message(message: dict) -> list[dict]:
     if not message["files"]:
         return [{"type": "text", "text": message["text"]}]
+
     if message["files"][0].endswith(".mp4"):
         return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
+
     if "<image>" in message["text"]:
         return process_interleaved_images(message)
+
     return [
         {"type": "text", "text": message["text"]},
         *[{"type": "image", "url": path} for path in message["files"]],
     ]
 
-
 def process_history(history: list[dict]) -> list[dict]:
     messages = []
     current_user_content: list[dict] = []
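Note: `process_new_user_message` and `process_history` both emit content parts in the chat-template format used throughout this file, i.e. dicts with `"type": "text"` or `"type": "image"`. A hypothetical single-turn example of the structure handed to `processor.apply_chat_template` (the path and prompt are placeholders):

```python
# Hypothetical messages structure; the file path and prompt text are made up.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image", "url": "/tmp/upload/example.png"},
        ],
    },
]
```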
@@ -166,18 +166,8 @@ def process_history(history: list[dict]) -> list[dict]:
                current_user_content.append({"type": "text", "text": content})
            else:
                current_user_content.append({"type": "image", "url": content[0]})
-    if current_user_content:
-        messages.append({"role": "user", "content": current_user_content})
     return messages
 
-
-def generate_thread(generate_kwargs):
-    # Clear cache and run generation under no_grad.
-    torch.cuda.empty_cache()
-    with torch.no_grad():
-        model.generate(**generate_kwargs)
-
-
 @spaces.GPU(duration=120)
 def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     if not validate_media_constraints(message, history):
@@ -190,21 +180,21 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     messages.extend(process_history(history))
     messages.append({"role": "user", "content": process_new_user_message(message)})
 
-
-    raw_inputs = processor.apply_chat_template(
+    inputs = processor.apply_chat_template(
         messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
-    )
-    inputs = {k: v.to(device=model.device, dtype=torch.bfloat16) for k, v in raw_inputs.items()}
+    ).to(device=model.device, dtype=torch.bfloat16)
 
     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
-
-
-
-
+    generate_kwargs = dict(
+        inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     output = ""
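Note: the rewritten generation path keeps the streaming pattern the request handler relies on: `model.generate` runs in a worker thread and the handler yields text pieces from the `TextIteratorStreamer`. A minimal text-only sketch of that producer/consumer loop (the `gpt2` checkpoint and prompt are placeholders, not what the Space uses):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model; the Space itself loads Gemma3ForConditionalGeneration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming generation works by", return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs in a background thread
# while this thread consumes the streamer as tokens arrive.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=30))
thread.start()

output = ""
for delta in streamer:
    output += delta
    print(delta, end="", flush=True)
thread.join()
```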
@@ -212,7 +202,6 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
         output += delta
         yield output
 
-
 examples = [
     [
         {
@@ -339,7 +328,7 @@ DESCRIPTION = """\
 <img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />
 
 This is a demo of Gemma 3 12B it, a vision language model with outstanding performance on a wide range of tasks.
-You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input.
+You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input. For videos, up to 5 frames will be extracted and processed.
 """
 
 demo = gr.ChatInterface(