Doc-VLMs / app.py
prithivMLmods's picture
Update app.py
8716c2f verified
raw
history blame
5.72 kB
import gradio as gr
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import time
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
#####################################
# 1. Load Qwen2.5-VL Model & Processor
#####################################
# Checkpoint to serve; the commented alternative is the smaller 3B variant.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # or "Qwen/Qwen2.5-VL-3B-Instruct"

# The processor and model are created once, at import time, and are shared by
# every call to the inference function below.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Weights are loaded in bfloat16, moved to the GPU and switched to eval mode.
# `.to()` and `.eval()` both return the module itself, so `model` is the same
# object the chain started from.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda").eval()
#####################################
# 2. Helper Function: Downsample Video
#####################################
def downsample_video(video_path, num_frames=10):
    """
    Sample up to `num_frames` evenly spaced frames from a video file.

    Args:
        video_path: Path to a video file that OpenCV can open.
        num_frames: Maximum number of frames to sample. When the video has
            fewer frames than this, each available frame is returned at most
            once instead of being duplicated.

    Returns:
        A list of ``(PIL.Image.Image, timestamp_seconds)`` tuples in
        chronological order, with timestamps rounded to two decimals. The list
        is empty when OpenCV reports no frames or no FPS for the file, or when
        ``num_frames <= 0``. Frames that fail to decode are skipped.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0 or fps <= 0 or num_frames <= 0:
            return frames

        # linspace repeats indices when the clip is shorter than num_frames;
        # np.unique drops the repeats (the sequence is already sorted).
        frame_indices = np.unique(
            np.linspace(0, total_frames - 1, num_frames, dtype=int)
        )
        for raw_index in frame_indices:
            index = int(raw_index)  # plain int for OpenCV and the timestamp
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, image = vidcap.read()
            if not success:
                continue
            # OpenCV decodes frames as BGR; PIL expects RGB.
            rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            frames.append((Image.fromarray(rgb), round(index / fps, 2)))
    finally:
        # Release the capture on every path, including decode errors.
        vidcap.release()
    return frames
#####################################
# 3. The Inference Function
#####################################
def _build_dummy_chart():
    """Return a placeholder bar chart with random values for the analytics panel."""
    # Build the Figure directly instead of through pyplot: figures created with
    # plt.subplots() stay registered in pyplot's global manager until closed,
    # which leaks one figure per request in a long-running server.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    categories = ["Category A", "Category B", "Category C"]
    values = [random.randint(1, 10) for _ in categories]
    ax.bar(categories, values, color=["#4B0082", "#9370DB", "#4B0082"])
    ax.set_title("Example Analytics Chart")
    ax.set_ylabel("Value")
    ax.set_xlabel("Category")
    return fig


def video_inference(video_file, duration):
    """
    Describe a recorded video with Qwen2.5-VL and attach a placeholder chart.

    Args:
        video_file: Path of the recorded video as passed in from the UI, or
            ``None`` when nothing was recorded.
        duration: The recording duration selected in the UI (a string such as
            ``"5"``). It is part of the UI wiring but is not used here.

    Returns:
        ``(text, figure)``: the model's description and a matplotlib figure.
        On failure, a short message and ``None`` in place of the figure.
    """
    if video_file is None:
        return "No video provided.", None

    # 3.1: Downsample the recorded video into (image, timestamp) pairs.
    frames = downsample_video(video_file)
    if not frames:
        return "Could not read frames from video.", None

    # 3.2: One user turn: the instruction, then every frame preceded by a text
    # label carrying its timestamp.
    content = [
        {"type": "text", "text": "Please describe what's happening in this video."}
    ]
    for image, ts in frames:
        content.append({"type": "text", "text": f"Frame at {ts} seconds:"})
        content.append({"type": "image", "image": image})
    messages = [{"role": "user", "content": content}]
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Images must be passed in the same order as they appear in `content`.
    frame_images = [image for image, _ in frames]
    inputs = processor(
        text=[prompt],
        images=frame_images,
        return_tensors="pt",
        padding=True,
    ).to(model.device)  # follow the model's device instead of hard-coding it

    # 3.3: Run generation in a worker thread and drain the streamer here.
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    errors = []

    def _generate():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # reported to the caller after join()
            errors.append(exc)
            # Without this, the consumer loop below would block forever
            # waiting for tokens that will never arrive.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    generated_text = "".join(streamer)
    thread.join()
    if errors:
        return f"Generation failed: {errors[0]}", None

    # 3.4: Dummy bar chart for demonstration.
    return generated_text, _build_dummy_chart()
#####################################
# 4. Build a Professional Gradio UI
#####################################
def _webcam_video(**kwargs):
    """Create a webcam-sourced ``gr.Video`` under both Gradio 3.x and 4+ APIs."""
    import inspect

    # Gradio 4 replaced `source="webcam"` with `sources=["webcam"]`; the old
    # keyword raises TypeError there, so pick whichever this install accepts.
    params = inspect.signature(gr.Video.__init__).parameters
    if "sources" in params:
        return gr.Video(sources=["webcam"], **kwargs)
    return gr.Video(source="webcam", **kwargs)


def build_app():
    """
    Assemble the Gradio Blocks UI.

    The left column holds the duration selector, the webcam recorder and the
    Analyze button. The right column holds the model's text output and the
    chart. Clicking Analyze calls ``video_inference(video, duration)``.

    Returns:
        The ``gr.Blocks`` app, ready for ``launch()``.
    """
    # Same Markdown text as before, written as concatenated literals so the
    # rendered string does not pick up the function's indentation.
    intro = (
        "\n"
        "# **Qwen2.5-VL-7B-Instruct Live Video Analysis**\n"
        "Record your webcam for a chosen duration, then click **Stop** to finalize.\n"
        "After that, click **Analyze** to run Qwen2.5-VL and see textual + chart outputs.\n"
    )
    with gr.Blocks() as demo:
        gr.Markdown(intro)
        with gr.Row():
            with gr.Column():
                duration = gr.Radio(
                    choices=["5", "10", "20", "30"],
                    value="5",
                    label="Suggested Recording Duration (seconds)",
                    info="Select how long you plan to record before pressing Stop.",
                )
                video = _webcam_video(
                    format="mp4",
                    label="Webcam Recording (press the Record button, then Stop)",
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Model Output")
                output_plot = gr.Plot(label="Analytics Chart")
        analyze_btn.click(
            fn=video_inference,
            inputs=[video, duration],
            outputs=[output_text, output_plot],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server with debug enabled.
    build_app().launch(debug=True)