Commit: ruff

app.py CHANGED
@@ -1,12 +1,15 @@
 #!/usr/bin/env python
 
+import re
+import tempfile
 from collections.abc import Iterator
 from threading import Thread
 
+import cv2
 import gradio as gr
 import spaces
 import torch
-import re
+from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
 model_id = "google/gemma-3-12b-it"
@@ -15,17 +18,13 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
 
-import cv2
-from PIL import Image
-import numpy as np
-import tempfile
 
 def downsample_video(video_path):
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    frame_interval = int(fps / 3)
+
+    frame_interval = int(fps / 3)
     frames = []
 
     for i in range(0, total_frames, frame_interval):
@@ -34,7 +33,7 @@ def downsample_video(video_path):
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
+            timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
 
     vidcap.release()
@@ -46,8 +45,8 @@ def process_new_user_message(message: dict) -> list[dict]:
         if "<image>" in message["text"]:
             content = []
             print("message[files]", message["files"])
-            parts = re.split(r
-            image_index = 0
+            parts = re.split(r"(<image>)", message["text"])
+            image_index = 0
             print("parts", parts)
             for part in parts:
                 print("part", part)
@@ -55,29 +54,30 @@ def process_new_user_message(message: dict) -> list[dict]:
                     content.append({"type": "image", "url": message["files"][image_index]})
                     print("file", message["files"][image_index])
                     image_index += 1
-                elif part.strip():
+                elif part.strip():
                     content.append({"type": "text", "text": part.strip()})
                 elif isinstance(part, str) and not part == "<image>":
                     content.append({"type": "text", "text": part})
             print(content)
             return content
-
+        if message["files"][0].endswith(".mp4"):
             content = []
             video = message["files"].pop(0)
            frames = downsample_video(video)
             for frame in frames:
                 pil_image, timestamp = frame
-                with tempfile.NamedTemporaryFile(delete=False, suffix=
+                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                     pil_image.save(temp_file.name)
                     content.append({"type": "text", "text": f"Frame {timestamp}:"})
                     content.append({"type": "image", "url": temp_file.name})
             print(content)
             return content
-
-
-
-
-
+        # non interleaved images
+        return [
+            {"type": "text", "text": message["text"]},
+            *[{"type": "image", "url": path} for path in message["files"]],
+        ]
+    return [{"type": "text", "text": message["text"]}]
 
 
 def process_history(history: list[dict]) -> list[dict]:
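
A note on the interleaving logic this commit reformats: process_new_user_message relies on re.split with a capturing group, which returns the "<image>" delimiters as list items, so text and image placeholders come back in their original order. A minimal sketch of that behavior, using a hypothetical prompt rather than one from the Space:

import re

# Hypothetical user prompt mixing text with two image placeholders.
text = "Compare <image> with <image> and describe the differences."

# The capturing group makes re.split keep the delimiters in the output,
# so "<image>" tokens remain interleaved with the surrounding text pieces.
parts = re.split(r"(<image>)", text)
print(parts)
# ['Compare ', '<image>', ' with ', '<image>', ' and describe the differences.']

Each "<image>" part is then mapped to the next entry in message["files"], while non-empty text parts become plain text entries.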
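
Likewise, downsample_video keeps roughly three frames per second: frame_interval = int(fps / 3) steps through every tenth frame of a 30 fps clip, and each kept frame is labeled with its timestamp round(i / fps, 2). A small sketch of that arithmetic, using an assumed 30 fps, 3-second clip rather than a real video:

# Assumed example values; in app.py they come from cv2.CAP_PROP_FPS and CAP_PROP_FRAME_COUNT.
fps = 30.0
total_frames = 90  # a 3-second clip at 30 fps

frame_interval = int(fps / 3)  # 10 -> keep about 3 frames per second
timestamps = [round(i / fps, 2) for i in range(0, total_frames, frame_interval)]
print(timestamps)  # [0.0, 0.33, 0.67, 1.0, 1.33, 1.67, 2.0, 2.33, 2.67]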