import gradio as gr
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import time
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
#####################################
# 1. Load Qwen2.5-VL Model & Processor
#####################################
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # or "Qwen/Qwen2.5-VL-3B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda")
model.eval()
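
# Note: the hard-coded .to("cuda") assumes a GPU is available (as on the Space).
# For local CPU-only testing, a minimal sketch of a fallback would be:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   dtype = torch.bfloat16 if device == "cuda" else torch.float32
#   model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#       MODEL_ID, trust_remote_code=True, torch_dtype=dtype
#   ).to(device)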
#####################################
# 2. Helper Function: Downsample Video
#####################################
def downsample_video(video_path, num_frames=10):
    """
    Downsamples the video file to `num_frames` evenly spaced frames.
    Each frame is converted to a PIL Image and paired with its timestamp in seconds.
    Returns a list of (PIL.Image, timestamp) tuples; the list is empty if the video
    cannot be read.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV returns BGR; convert for PIL
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
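
# Illustrative usage (the path "sample.mp4" is a placeholder, not part of this app):
#   frames = downsample_video("sample.mp4", num_frames=4)
#   for img, ts in frames:
#       print(img.size, "at", ts, "seconds")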
#####################################
# 3. The Inference Function
#####################################
def video_inference(video_file, duration):
    """
    - Takes a recorded video file and the chosen duration (a string; currently
      informational only and not used to trim the video).
    - Downsamples the video and passes the frames to Qwen2.5-VL for inference.
    - Returns the model-generated text plus a dummy bar chart as example analytics.
    """
    if video_file is None:
        return "No video provided.", None

    # 3.1: Downsample the recorded video into evenly spaced frames
    frames = downsample_video(video_file)
    if not frames:
        return "Could not read frames from video.", None
    # 3.2: Construct the Qwen2.5-VL prompt: a simple instruction followed by
    # the timestamped frames.
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
        }
    ]
    # Interleave a timestamp caption with each frame
    for (image, ts) in frames:
        messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
        messages[0]["content"].append({"type": "image", "image": image})

    # Build the final chat prompt for the model
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The processor expects the images in the same order they appear in the prompt
    frame_images = [img for (img, _) in frames]
    inputs = processor(
        text=[prompt],
        images=frame_images,
        return_tensors="pt",
        padding=True
    ).to("cuda")
    # 3.3: Generate text output, streaming tokens from a background thread.
    # The full text is collected here before returning, so the UI receives
    # one final string rather than a live stream.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Collect the streamed text as it is produced
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # Sleep briefly to yield control to the generation thread
        time.sleep(0.01)
    thread.join()
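
    # Sketch only (not used here): to surface tokens incrementally in the Gradio UI,
    # this function could instead be written as a generator that yields partial output:
    #   for new_text in streamer:
    #       generated_text += new_text
    #       yield generated_text, None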
    # 3.4: Dummy bar chart for demonstration purposes
    fig, ax = plt.subplots()
    categories = ["Category A", "Category B", "Category C"]
    values = [random.randint(1, 10) for _ in categories]
    ax.bar(categories, values, color=["#4B0082", "#9370DB", "#4B0082"])
    ax.set_title("Example Analytics Chart")
    ax.set_ylabel("Value")
    ax.set_xlabel("Category")

    # Return the generated text and the figure
    return generated_text, fig
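
# Quick local sanity check (illustrative only; "clip.mp4" is a placeholder path):
#   text, fig = video_inference("clip.mp4", "5")
#   print(text)
#   fig.savefig("analytics.png")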
#####################################
# 4. Build a Professional Gradio UI
#####################################
def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # **Qwen2.5-VL-7B-Instruct Live Video Analysis**
        Record your webcam for a chosen duration, then click **Stop** to finalize.
        After that, click **Analyze** to run Qwen2.5-VL and see text and chart outputs.
        """)
        with gr.Row():
            with gr.Column():
                duration = gr.Radio(
                    choices=["5", "10", "20", "30"],
                    value="5",
                    label="Suggested Recording Duration (seconds)",
                    info="Select how long you plan to record before pressing Stop."
                )
                video = gr.Video(
                    sources=["webcam"],  # Gradio 4.x; older 3.x releases used source="webcam"
                    format="mp4",
                    label="Webcam Recording (press Record, then Stop)"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Model Output")
                output_plot = gr.Plot(label="Analytics Chart")

        analyze_btn.click(
            fn=video_inference,
            inputs=[video, duration],
            outputs=[output_text, output_plot]
        )
    return demo
if __name__ == "__main__":
    app = build_app()
    app.launch(debug=True)