Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -6,15 +6,16 @@ import matplotlib.pyplot as plt
 import random
 import spaces
 import time
+import re
 from PIL import Image
 from threading import Thread
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 from transformers.image_utils import load_image
 
 #####################################
-# 1. Load
+# 1. Load Model & Processor
 #####################################
-MODEL_ID = "google/gemma-3-12b-it"  # Example
+MODEL_ID = "google/gemma-3-12b-it"  # Example model ID (adjust to your needs)
 
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model = Gemma3ForConditionalGeneration.from_pretrained(
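The `from_pretrained(` call is cut off at the hunk boundary, so its arguments are not visible in this diff. For orientation only, a common way such a call is finished on a GPU Space is sketched below; the dtype and device settings are assumptions, not what this commit ships.

```python
import torch

# Assumed continuation for illustration -- the commit's actual arguments are not shown.
model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # assumption: half precision to fit a 12B model
    device_map="auto",           # assumption: let accelerate place the weights
).eval()
```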
@@ -52,29 +53,6 @@ def downsample_video(video_path, num_frames=10):
     vidcap.release()
     return frames
 
-#####################################
-# 2.5: Parse Categories from Model Output
-#####################################
-def parse_inferred_categories(generated_text):
-    """
-    A naive parser that looks for lines starting with 'Category:'
-    and collects the text after that as the category name.
-    Example lines in model output:
-    Category: Nutrition
-    Category: Outdoor Scenes
-    Returns a list of category strings.
-    """
-    categories = []
-    for line in generated_text.split("\n"):
-        line = line.strip()
-        # Check if the line starts with 'Category:' (case-insensitive)
-        if line.lower().startswith("category:"):
-            # Extract everything after 'Category:'
-            cat = line.split(":", 1)[1].strip()
-            if cat:
-                categories.append(cat)
-    return categories
-
 #####################################
 # 3. The Inference Function
 #####################################
@@ -82,8 +60,8 @@ def parse_inferred_categories(generated_text):
 def video_inference(video_file, duration):
     """
     - Takes a recorded video file and a chosen duration (string).
-    - Downsamples the video, passes frames to the
-    - Returns model-generated text + a bar chart
+    - Downsamples the video, passes frames to the model for inference.
+    - Returns model-generated text + a bar chart based on the text.
     """
     if video_file is None:
         return "No video provided.", None
@@ -100,6 +78,7 @@ def video_inference(video_file, duration):
             "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
         }
     ]
+
     # Add frames (with timestamp) to the messages
     for (image, ts) in frames:
         messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
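The loop interleaves a timestamp caption with each frame. The image entry itself is appended outside this hunk; assuming it follows the usual `{"type": "image"}` convention for Gemma-3 chat templates, the finished structure for two frames would look roughly like this (timestamps illustrative):

```python
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe what's happening in this video."},
            {"type": "text", "text": "Frame at 0.0 seconds:"},
            {"type": "image"},  # assumption: placeholder paired with frame_images[0]
            {"type": "text", "text": "Frame at 2.5 seconds:"},
            {"type": "image"},  # assumption: placeholder paired with frame_images[1]
        ],
    }
]
```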
@@ -108,7 +87,7 @@ def video_inference(video_file, duration):
     # Prepare final prompt
     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
-    #
+    # Gather images for the model
     frame_images = [img for (img, _) in frames]
 
     inputs = processor(
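Downstream of this hunk the code accumulates `new_text` from a streamer (visible as context in the next hunk). The standard `TextIteratorStreamer` pattern that produces that loop, sketched with an assumed `max_new_tokens`:

```python
from threading import Thread

from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(
    processor.tokenizer, skip_prompt=True, skip_special_tokens=True
)
# generate() runs in a background thread so tokens can be consumed as they arrive
Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512)).start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text  # same accumulation as in the diff
```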
@@ -130,23 +109,37 @@ def video_inference(video_file, duration):
         generated_text += new_text
         time.sleep(0.01)
 
-    # 3.4:
-
-
-
-
+    # 3.4: Build a bar chart based on top keywords from the generated text
+    # (Naive approach: frequency of top 5 words)
+    words = re.findall(r'\w+', generated_text.lower())
+    freq = {}
+    for w in words:
+        freq[w] = freq.get(w, 0) + 1
+
+    # Sort words by frequency (descending)
+    sorted_items = sorted(freq.items(), key=lambda x: x[1], reverse=True)
+    # Pick top 5 words (if fewer than 5, pick all)
+    top5 = sorted_items[:5]
 
-
-
+    if not top5:
+        # If there's no text or no valid words, return no chart
+        return generated_text, None
 
-
+    categories = [item[0] for item in top5]
+    values = [item[1] for item in top5]
+
+    # Create the figure
     fig, ax = plt.subplots()
-
-
-
-
-
+    colors = ["#4B0082", "#9370DB", "#8A2BE2", "#DA70D6", "#BA55D3"]  # Purple-ish palette
+    # Make sure we have enough colors for the number of bars
+    color_list = colors[: len(categories)]
+
+    ax.bar(categories, values, color=color_list)
+    ax.set_title("Top Keywords in Generated Description")
+    ax.set_ylabel("Frequency")
+    ax.set_xlabel("Keyword")
 
+    # Return the final text and the figure
     return generated_text, fig
 
 #####################################
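The chart logic added above is self-contained and easy to sanity-check without running the model. A minimal standalone sketch of the same naive approach, using `collections.Counter` in place of the commit's hand-rolled dict purely as a convenience:

```python
import re
from collections import Counter

import matplotlib.pyplot as plt


def keyword_chart(text, top_n=5):
    """Bar chart of the most frequent words in text (naive: no stopword filtering)."""
    words = re.findall(r"\w+", text.lower())
    if not words:
        return None
    top = Counter(words).most_common(top_n)
    labels = [w for w, _ in top]
    counts = [c for _, c in top]
    fig, ax = plt.subplots()
    ax.bar(labels, counts, color=["#4B0082", "#9370DB", "#8A2BE2", "#DA70D6", "#BA55D3"][: len(labels)])
    ax.set_title("Top Keywords in Generated Description")
    ax.set_xlabel("Keyword")
    ax.set_ylabel("Frequency")
    return fig


fig = keyword_chart("The dog chases the ball and the dog barks.")
# Without a stopword list, filler words like "the" dominate the chart.
```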
@@ -155,7 +148,7 @@ def video_inference(video_file, duration):
 def build_app():
     with gr.Blocks() as demo:
         gr.Markdown("""
-        # **
+        # **Gemma-3 (Example) Live Video Analysis**
         Record a video (from webcam or file), then click **Stop**.
         Next, click **Analyze** to run the model and see textual + chart outputs.
         """)
@@ -168,8 +161,9 @@ def build_app():
             label="Suggested Recording Duration (seconds)",
             info="Select how long you plan to record before pressing Stop."
         )
+        # For older Gradio versions, avoid `source="webcam"`.
         video = gr.Video(
-            label="Webcam Recording (press Record, then Stop)",
+            label="Webcam Recording (press the Record button, then Stop)",
             format="mp4"
         )
         analyze_btn = gr.Button("Analyze", variant="primary")
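The new comment concerns a parameter rename: Gradio 3.x accepted `gr.Video(source="webcam")`, while 4.x replaced it with the plural `sources=[...]`. Omitting the argument, as this commit does, keeps both upload and webcam available on either version. If webcam-only capture is wanted, a version guard along these lines is one option (a sketch, assuming only the 3.x/4.x split matters):

```python
import gradio as gr

def webcam_video(label="Webcam Recording (press the Record button, then Stop)"):
    # Gradio 4.x: plural `sources`; Gradio 3.x: singular `source`.
    if int(gr.__version__.split(".")[0]) >= 4:
        return gr.Video(sources=["webcam"], label=label, format="mp4")
    return gr.Video(source="webcam", label=label, format="mp4")
```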