import gradio as gr import torch import numpy as np import cv2 import matplotlib.pyplot as plt import random import spaces import time import re from PIL import Image from threading import Thread from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer from transformers.image_utils import load_image ##################################### # 1. Load Model & Processor ##################################### MODEL_ID = "google/gemma-3-12b-it" # Example model ID (adjust to your needs) processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) model = Gemma3ForConditionalGeneration.from_pretrained( MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16 ).to("cuda") model.eval() ##################################### # 2. Helper Function: Downsample Video ##################################### def downsample_video(video_path, num_frames=10): """ Downsamples the video file to `num_frames` evenly spaced frames. Each frame is converted to a PIL Image along with its timestamp. """ vidcap = cv2.VideoCapture(video_path) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = vidcap.get(cv2.CAP_PROP_FPS) frames = [] if total_frames <= 0 or fps <= 0: vidcap.release() return frames frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) for i in frame_indices: vidcap.set(cv2.CAP_PROP_POS_FRAMES, i) success, image = vidcap.read() if success: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image) timestamp = round(i / fps, 2) frames.append((pil_image, timestamp)) vidcap.release() return frames ##################################### # 3. The Inference Function ##################################### @spaces.GPU def video_inference(video_file, duration): """ - Takes a recorded video file and a chosen duration (string). - Downsamples the video, passes frames to the model for inference. - Returns model-generated text + a bar chart based on the text. """ if video_file is None: return "No video provided.", None # 3.1: Downsample the recorded video frames = downsample_video(video_file) if not frames: return "Could not read frames from video.", None # 3.2: Construct prompt messages = [ { "role": "user", "content": [{"type": "text", "text": "Please describe what's happening in this video."}] } ] # Add frames (with timestamp) to the messages for (image, ts) in frames: messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"}) messages[0]["content"].append({"type": "image", "image": image}) # Prepare final prompt prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # Gather images for the model frame_images = [img for (img, _) in frames] inputs = processor( text=[prompt], images=frame_images, return_tensors="pt", padding=True ).to("cuda") # 3.3: Generate text output (streaming) streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() generated_text = "" for new_text in streamer: generated_text += new_text time.sleep(0.01) # 3.4: Build a bar chart based on top keywords from the generated text # (Naive approach: frequency of top 5 words) words = re.findall(r'\w+', generated_text.lower()) freq = {} for w in words: freq[w] = freq.get(w, 0) + 1 # Sort words by frequency (descending) sorted_items = sorted(freq.items(), key=lambda x: x[1], reverse=True) # Pick top 5 words (if fewer than 5, pick all) top5 = sorted_items[:5] if not top5: # If there's no text or no valid words, return no chart return generated_text, None categories = [item[0] for item in top5] values = [item[1] for item in top5] # Create the figure fig, ax = plt.subplots() colors = ["#4B0082", "#9370DB", "#8A2BE2", "#DA70D6", "#BA55D3"] # Purple-ish palette # Make sure we have enough colors for the number of bars color_list = colors[: len(categories)] ax.bar(categories, values, color=color_list) ax.set_title("Top Keywords in Generated Description") ax.set_ylabel("Frequency") ax.set_xlabel("Keyword") # Return the final text and the figure return generated_text, fig ##################################### # 4. Build a Professional Gradio UI ##################################### def build_app(): with gr.Blocks() as demo: gr.Markdown(""" # **Gemma-3 (Example) Live Video Analysis** Record a video (from webcam or file), then click **Stop**. Next, click **Analyze** to run the model and see textual + chart outputs. """) with gr.Row(): with gr.Column(): duration = gr.Radio( choices=["5", "10", "20", "30"], value="5", label="Suggested Recording Duration (seconds)", info="Select how long you plan to record before pressing Stop." ) # For older Gradio versions, avoid `source="webcam"`. video = gr.Video( label="Webcam Recording (press the Record button, then Stop)", format="mp4" ) analyze_btn = gr.Button("Analyze", variant="primary") with gr.Column(): output_text = gr.Textbox(label="Model Output") output_plot = gr.Plot(label="Analytics Chart") analyze_btn.click( fn=video_inference, inputs=[video, duration], outputs=[output_text, output_plot] ) return demo if __name__ == "__main__": app = build_app() app.launch(debug=True)