# Hugging Face Spaces status banner: "Running on Zero"
import gradio as gr | |
import torch | |
import numpy as np | |
import cv2 | |
import time | |
import re | |
import spaces | |
from PIL import Image | |
from threading import Thread | |
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer | |
##################################### | |
# 1. Load Model & Processor | |
##################################### | |
MODEL_ID = "google/gemma-3-12b-it"  # Adjust to your needs

# The processor is used below both for chat templating and for turning the
# prompt text plus frame images into model inputs.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load the weights in bfloat16, place them on the GPU and put the module in
# eval mode, all in one chained expression.
model = (
    Gemma3ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    .to("cuda")
    .eval()
)
##################################### | |
# 2. Helper Function: Capture Live Frames | |
##################################### | |
def capture_live_frames(duration=5, num_frames=10):
    """Capture evenly spaced frames from the default webcam.

    Reads frames from camera index 0 and keeps ``num_frames`` of them, spread
    evenly over the expected frame count (``duration * fps``).

    Args:
        duration: Capture length in seconds.
        num_frames: Number of frames to sample across the capture window.

    Returns:
        A list of ``(PIL image, timestamp)`` tuples, where ``timestamp`` is the
        frame's nominal offset in seconds (frame index / fps, rounded to two
        decimals). The list is empty when either argument is not positive or
        the webcam cannot be opened, and may hold fewer than ``num_frames``
        items if a read fails or the wall-clock budget runs out first.
    """
    # Nothing to sample: avoid opening the camera at all.
    if duration <= 0 or num_frames <= 0:
        return []

    cap = cv2.VideoCapture(0)  # Use default webcam
    try:
        if not cap.isOpened():
            return []

        # Try to get FPS, default to 30 if not available.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30

        total_frames_to_capture = int(duration * fps)
        # Build the set of wanted frame indices once so the per-frame
        # membership test is O(1) instead of a scan over a NumPy array.
        wanted_indices = set(
            np.linspace(
                0, total_frames_to_capture - 1, num_frames, dtype=int
            ).tolist()
        )

        captured_frames = []
        start_time = time.monotonic()
        for frame_index in range(total_frames_to_capture):
            ret, frame = cap.read()
            if not ret:
                break
            if frame_index in wanted_indices:
                # Convert BGR (OpenCV) to RGB (PIL)
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)
                timestamp = round(frame_index / fps, 2)
                captured_frames.append((pil_image, timestamp))
            # Stop once the real elapsed time exceeds the requested duration,
            # even if fewer frames than expected were delivered.
            if time.monotonic() - start_time > duration:
                break
        return captured_frames
    finally:
        # Always hand the device back, including on early return or error.
        cap.release()
##################################### | |
# 3. Live Inference Function | |
##################################### | |
def live_inference(duration=5):
    """Capture webcam frames and return the model's description of them.

    Captures up to 10 frames via ``capture_live_frames``, builds a single
    user message that interleaves a "Frame at N seconds:" caption with each
    image, and generates up to 512 new tokens.

    Args:
        duration: Capture length in seconds. ``None`` (for example an empty
            numeric input) or a non-positive value is rejected.

    Returns:
        The generated text, or a human-readable error message when the
        duration is invalid or no frames could be captured.

    NOTE(review): ``spaces`` is imported at file level but nothing visible
    here applies ``@spaces.GPU``; confirm whether this function needs that
    decorator on the target hardware.
    """
    # Reject unusable durations up front instead of failing in arithmetic.
    if duration is None or duration <= 0:
        return "Capture duration must be a positive number of seconds."

    frames = capture_live_frames(duration=duration, num_frames=10)
    if not frames:
        return "Could not capture live frames from the webcam."

    # Build prompt using the captured frames.
    content = [
        {"type": "text", "text": "Please describe what's happening in this live video."}
    ]
    for image, ts in frames:
        content.append({"type": "text", "text": f"Frame at {ts} seconds:"})
        content.append({"type": "image", "image": image})
    messages = [{"role": "user", "content": content}]

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt],
        images=[image for image, _ in frames],
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # Generate in a worker thread and drain the streamer here. The chunks
    # are only concatenated, so no artificial delay between them is needed.
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = "".join(streamer)
    # Wait for the worker so it does not outlive this call.
    thread.join()
    return generated_text
##################################### | |
# 4. Build Gradio Live App | |
##################################### | |
def build_live_app():
    """Assemble and return the Gradio Blocks UI for the live-video demo.

    The page has a duration field, a "Start" button, an output textbox and an
    initially hidden "Start Again" button. Both buttons run the same handler.
    """
    intro_markdown = "# **Live Video Analysis**\n\nPress **Start** to capture a few seconds of live video from your webcam and analyze the content."

    def start_inference(duration):
        # Run one capture-and-describe cycle, then reveal the restart button.
        return live_inference(duration), gr.update(visible=True)

    with gr.Blocks() as demo:
        gr.Markdown(intro_markdown)
        with gr.Column():
            duration_input = gr.Number(
                label="Capture Duration (seconds)", value=5, precision=0
            )
            start_btn = gr.Button("Start")
            output_text = gr.Textbox(label="Model Output")
            restart_btn = gr.Button("Start Again", visible=False)

        # Same wiring for both buttons, registered in the original order.
        for button in (start_btn, restart_btn):
            button.click(
                fn=start_inference,
                inputs=duration_input,
                outputs=[output_text, restart_btn],
            )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it with debug output enabled.
    build_live_app().launch(debug=True)