Spaces: Running on Zero
File size: 6,194 Bytes
import gradio as gr
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import spaces
import time
import re
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
#####################################
# 1. Load Model & Processor
#####################################
MODEL_ID = "google/gemma-3-12b-it" # Example model ID (adjust to your needs)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda")
model.eval()
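# For reference, the same processor/model pair also handles a plain single-image prompt.
# A minimal sketch of the API this script relies on (illustrative only; `pil_img` is an
# assumed, pre-loaded PIL.Image and is not defined in this file):
#
#   msgs = [{"role": "user",
#            "content": [{"type": "image", "image": pil_img},
#                        {"type": "text", "text": "Describe this image."}]}]
#   prompt = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
#   batch = processor(text=[prompt], images=[pil_img], return_tensors="pt").to("cuda")
#   output_ids = model.generate(**batch, max_new_tokens=64)
#   print(processor.decode(output_ids[0], skip_special_tokens=True))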
#####################################
# 2. Helper Function: Downsample Video
#####################################
def downsample_video(video_path, num_frames=10):
"""
Downsamples the video file to `num_frames` evenly spaced frames.
Each frame is converted to a PIL Image along with its timestamp.
"""
vidcap = cv2.VideoCapture(video_path)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = vidcap.get(cv2.CAP_PROP_FPS)
frames = []
if total_frames <= 0 or fps <= 0:
vidcap.release()
return frames
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
for i in frame_indices:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
success, image = vidcap.read()
if success:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
timestamp = round(i / fps, 2)
frames.append((pil_image, timestamp))
vidcap.release()
return frames
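# Illustrative usage of downsample_video, kept as a comment so nothing runs on import
# ("clip.mp4" is a placeholder path, not a file shipped with this Space):
#
#   frames = downsample_video("clip.mp4", num_frames=4)
#   for img, ts in frames:
#       print(img.size, f"@ {ts}s")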
#####################################
# 3. The Inference Function
#####################################
@spaces.GPU
def video_inference(video_file, duration):
"""
- Takes a recorded video file and a chosen duration (string).
- Downsamples the video, passes frames to the model for inference.
- Returns model-generated text + a bar chart based on the text.
"""
if video_file is None:
return "No video provided.", None
# 3.1: Downsample the recorded video
frames = downsample_video(video_file)
if not frames:
return "Could not read frames from video.", None
# 3.2: Construct prompt
messages = [
{
"role": "user",
"content": [{"type": "text", "text": "Please describe what's happening in this video."}]
}
]
# Add frames (with timestamp) to the messages
for (image, ts) in frames:
messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
messages[0]["content"].append({"type": "image", "image": image})
# Prepare final prompt
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Gather images for the model
frame_images = [img for (img, _) in frames]
inputs = processor(
text=[prompt],
images=frame_images,
return_tensors="pt",
padding=True
).to("cuda")
    # 3.3: Generate text output (streaming)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        time.sleep(0.01)
    thread.join()  # ensure the generation thread has finished before post-processing
    # 3.4: Build a bar chart from the top keywords in the generated text
    # (naive approach: frequency of the top 5 words)
    words = re.findall(r'\w+', generated_text.lower())
    freq = {}
    for w in words:
        freq[w] = freq.get(w, 0) + 1
    # Sort words by frequency (descending)
    sorted_items = sorted(freq.items(), key=lambda x: x[1], reverse=True)
    # Pick the top 5 words (or all of them if there are fewer than 5)
    top5 = sorted_items[:5]
    if not top5:
        # No text or no valid words: return without a chart
        return generated_text, None
    categories = [item[0] for item in top5]
    values = [item[1] for item in top5]
    # Create the figure
    fig, ax = plt.subplots()
    colors = ["#4B0082", "#9370DB", "#8A2BE2", "#DA70D6", "#BA55D3"]  # purple-ish palette
    # Use only as many colors as there are bars
    color_list = colors[: len(categories)]
    ax.bar(categories, values, color=color_list)
    ax.set_title("Top Keywords in Generated Description")
    ax.set_ylabel("Frequency")
    ax.set_xlabel("Keyword")
    # Return the final text and the figure
    return generated_text, fig
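# Design note: the word-frequency loop above is intentionally naive. The same count
# could be written with collections.Counter, e.g. (sketch, assuming `generated_text`
# holds the model output):
#
#   from collections import Counter
#   top5 = Counter(re.findall(r'\w+', generated_text.lower())).most_common(5)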
#####################################
# 4. Build a Professional Gradio UI
#####################################
def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # **Gemma-3 (Example) Live Video Analysis**
        Record a video (from your webcam or a file), then click **Stop**.
        Next, click **Analyze** to run the model and see the text and chart outputs.
        """)
        with gr.Row():
            with gr.Column():
                duration = gr.Radio(
                    choices=["5", "10", "20", "30"],
                    value="5",
                    label="Suggested Recording Duration (seconds)",
                    info="Select how long you plan to record before pressing Stop."
                )
                # For older Gradio versions, avoid `source="webcam"`.
                video = gr.Video(
                    label="Webcam Recording (press the Record button, then Stop)",
                    format="mp4"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Model Output")
                output_plot = gr.Plot(label="Analytics Chart")
        analyze_btn.click(
            fn=video_inference,
            inputs=[video, duration],
            outputs=[output_text, output_plot]
        )
    return demo
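# Quick check without the UI (illustrative; "sample.mp4" is a placeholder path and the
# call still requests a GPU through the @spaces.GPU decorator):
#
#   text, chart = video_inference("sample.mp4", "5")
#   print(text)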
if __name__ == "__main__":
    app = build_app()
    app.launch(debug=True)
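# If long-running Analyze calls start timing out on Spaces, enabling Gradio's request
# queue is a common tweak (optional; not required for this demo):
#
#   app.queue().launch(debug=True)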