Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -14,16 +14,17 @@ from loguru import logger
 from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
+# Load model and processor.
 model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
 processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
+model.eval()  # Ensure the model is in evaluation mode.
 
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
-
-css= '''h1 {
+css = '''h1 {
     text-align: center;
     display: block;
 }
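For context, a minimal sketch of how the processor and model loaded above are typically driven for a single text-only turn; the prompt and decoding parameters here are illustrative and not part of this commit.

# Illustrative only: exercising the objects loaded above outside Gradio.
messages = [{"role": "user", "content": [{"type": "text", "text": "Describe Gemma 3 in one sentence."}]}]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(output_ids[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True))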
@@ -37,7 +38,6 @@ css= '''h1 {
 '''
 
 
-
 def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -77,7 +77,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
         if "<image>" in message["text"]:
             gr.Warning("Using <image> tags with video files is not supported.")
             return False
-    # TODO: Add frame count validation for videos similar to image count limits  # noqa: FIX002, TD002, TD003
     if video_count == 0 and image_count > MAX_NUM_IMAGES:
         gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
         return False
@@ -91,19 +90,17 @@ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    frame_interval = int(fps / 3)
+    # Calculate frame interval (approximately one frame every 1/3 second).
+    frame_interval = int(fps / 3) if fps > 0 else 1
     frames = []
-
     for i in range(0, total_frames, frame_interval):
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
+            timestamp = round(i / fps, 2) if fps > 0 else 0
             frames.append((pil_image, timestamp))
-
     vidcap.release()
     return frames
 
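process_video() itself is outside this diff; as a rough sketch (an assumption about the Space's helper, not its actual code), the (frame, timestamp) pairs returned by downsample_video() can be expanded into interleaved text/image content like this:

import tempfile

def frames_to_content(video_path: str) -> list[dict]:
    # Sketch only: the real process_video() in app.py is not shown in this diff.
    content = []
    for frame, timestamp in downsample_video(video_path):
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        frame.save(tmp.name)  # frame is a PIL.Image
        content.append({"type": "text", "text": f"Frame at {timestamp}s:"})
        content.append({"type": "image", "url": tmp.name})
    return content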
@@ -125,7 +122,6 @@ def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{message['files']=}")
     parts = re.split(r"(<image>)", message["text"])
     logger.debug(f"{parts=}")
-
     content = []
     image_index = 0
     for part in parts:
@@ -145,13 +141,10 @@ def process_interleaved_images(message: dict) -> list[dict]:
 def process_new_user_message(message: dict) -> list[dict]:
     if not message["files"]:
         return [{"type": "text", "text": message["text"]}]
-
     if message["files"][0].endswith(".mp4"):
         return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
-
     if "<image>" in message["text"]:
         return process_interleaved_images(message)
-
     return [
         {"type": "text", "text": message["text"]},
         *[{"type": "image", "url": path} for path in message["files"]],
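For reference, with the fall-through branch above, a message carrying plain image files expands into one text entry followed by one image entry per file (paths below are made up):

message = {"text": "Compare these two photos.", "files": ["/tmp/a.png", "/tmp/b.png"]}
# process_new_user_message(message) ->
# [
#     {"type": "text", "text": "Compare these two photos."},
#     {"type": "image", "url": "/tmp/a.png"},
#     {"type": "image", "url": "/tmp/b.png"},
# ]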
@@ -173,9 +166,18 @@ def process_history(history: list[dict]) -> list[dict]:
                 current_user_content.append({"type": "text", "text": content})
             else:
                 current_user_content.append({"type": "image", "url": content[0]})
+    if current_user_content:
+        messages.append({"role": "user", "content": current_user_content})
     return messages
 
 
+def generate_thread(generate_kwargs):
+    # Empty cache to free up memory and run generation under no_grad.
+    torch.cuda.empty_cache()
+    with torch.no_grad():
+        model.generate(**generate_kwargs)
+
+
 @spaces.GPU(duration=120)
 def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     if not validate_media_constraints(message, history):
@@ -198,11 +200,12 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
 
     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        inputs,
+        inputs=inputs,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    # Launch generation in a separate thread using our no_grad wrapper.
+    t = Thread(target=generate_thread, kwargs={"generate_kwargs": generate_kwargs})
     t.start()
 
     output = ""
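The remainder of run() is outside this hunk; the TextIteratorStreamer started above is typically drained like this (a sketch of the usual pattern, not necessarily the exact code in app.py):

# Sketch: yield the accumulated text as the background thread produces tokens.
for delta in streamer:
    output += delta
    yield output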