import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time
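# Assumed dependencies (not pinned in this file): a transformers release with Gemma 3
# support (4.50 or newer), plus gradio, opencv-python, pillow, and the `spaces`
# package used on ZeroGPU hardware.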
# Load Model & Processor | |
MODEL_ID = "google/gemma-3-12b-it" | |
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) | |
model = Gemma3ForConditionalGeneration.from_pretrained( | |
MODEL_ID, | |
trust_remote_code=True, | |
torch_dtype=torch.bfloat16 | |
).to("cuda") | |
model.eval() | |
# Helper Function: Downsample Video
def downsample_video(video_path, max_duration=10, num_frames=10):
    """
    Downsamples the video to `num_frames` evenly spaced frames within the first `max_duration` seconds.
    Returns a list of (PIL Image, timestamp) tuples.
    """
    vidcap = cv2.VideoCapture(video_path)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps <= 0 or total_frames <= 0:
        vidcap.release()
        return []
    # Limit to the first `max_duration` seconds
    max_frames = min(int(fps * max_duration), total_frames)
    frame_indices = np.linspace(0, max_frames - 1, num_frames, dtype=int)
    frames = []
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
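# Illustrative only: downsample_video("clip.mp4") would return up to 10 (frame, timestamp)
# pairs from the first 10 seconds, e.g. [(<PIL.Image>, 0.0), (<PIL.Image>, 1.11), ...].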
# Inference Function
@spaces.GPU  # required on ZeroGPU Spaces so a GPU is attached for this call
def video_inference(video_file):
    """
    Processes the video file and generates a text description based on the first 10 seconds.
    Returns the generated text.
    """
    if video_file is None:
        return "No video provided."
    frames = downsample_video(video_file, max_duration=10, num_frames=10)
    if not frames:
        return "Could not read frames from video."
    # Construct the prompt: one user turn with an instruction followed by
    # alternating timestamp captions and frames
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
        }
    ]
    for (image, ts) in frames:
        messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
        messages[0]["content"].append({"type": "image", "image": image})
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    frame_images = [img for (img, _) in frames]
    inputs = processor(
        text=[prompt],
        images=frame_images,
        return_tensors="pt",
        padding=True
    ).to("cuda")
    # Generate text with streaming
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        time.sleep(0.01)
    thread.join()
    return generated_text
# Button Toggle Function
def toggle_button(has_result):
    """
    Returns visibility states for start_again_btn and start_btn based on has_result.
    """
    if has_result:
        return gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True)
# Build the Gradio App
def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # **Gemma-3 Live Video Analysis**
        Press **Start** to record a short video clip (up to 10 seconds). Stop recording to see the analysis.
        After the result, press **Start Again** to analyze another clip.
        """)
        # State to track if a result has been generated
        has_result = gr.State(value=False)

        with gr.Row():
            with gr.Column():
                video = gr.Video(
                    sources=["webcam"],
                    label="Webcam Recording",
                    format="mp4"
                )
                # Two buttons: one for Start, one for Start Again
                start_btn = gr.Button("Start", visible=True)
                start_again_btn = gr.Button("Start Again", visible=False)
            with gr.Column():
                output_text = gr.Textbox(label="Model Output")

        # When video is recorded and stopped, process it
        def process_video(video_file, has_result_state):
            if video_file is None:
                return "Please record a video.", has_result_state
            result = video_inference(video_file)
            return result, True

        video.change(
            fn=process_video,
            inputs=[video, has_result],
            outputs=[output_text, has_result]
        )

        # Update button visibility based on has_result
        has_result.change(
            fn=toggle_button,
            inputs=has_result,
            outputs=[start_again_btn, start_btn]
        )

        # Clicking either button resets the video and output
        def reset_state():
            return None, "", False

        start_btn.click(
            fn=reset_state,
            inputs=None,
            outputs=[video, output_text, has_result]
        )
        start_again_btn.click(
            fn=reset_state,
            inputs=None,
            outputs=[video, output_text, has_result]
        )
    return demo


if __name__ == "__main__":
    app = build_app()
    app.launch(debug=True)