Doc-VLMs-OCR / app.py
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from threading import Thread
from transformers import (
    TextIteratorStreamer,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
import spaces
import time
# Load Model & Processor
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
MODEL_ID,
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to("cuda")
model.eval()
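# Illustrative sketch (not wired into the UI below): single-image Q&A with the same
# processor/model pair loaded above. The function name, the question argument, and the
# 128-token cap are assumptions made for this example; on Spaces it would likely also
# need the @spaces.GPU decorator before being called.
def example_single_image_inference(image: Image.Image, question: str) -> str:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to("cuda")
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=128)
    # Decode only the newly generated tokens, skipping the prompt portion.
    answer_ids = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(answer_ids, skip_special_tokens=True)[0]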
# Helper Function: Downsample Video
def downsample_video(video_path, max_duration=10, num_frames=10):
"""
Downsamples the video to `num_frames` evenly spaced frames within the first `max_duration` seconds.
Returns a list of (PIL Image, timestamp) tuples.
"""
vidcap = cv2.VideoCapture(video_path)
fps = vidcap.get(cv2.CAP_PROP_FPS)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
if fps <= 0 or total_frames <= 0:
vidcap.release()
return []
# Limit to first `max_duration` seconds
max_frames = min(int(fps * max_duration), total_frames)
frame_indices = np.linspace(0, max_frames - 1, num_frames, dtype=int)
frames = []
for i in frame_indices:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
success, image = vidcap.read()
if success:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
timestamp = round(i / fps, 2)
frames.append((pil_image, timestamp))
vidcap.release()
return frames
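# Example usage (sketch only; "clip.mp4" is a placeholder path, not part of the app):
#   frames = downsample_video("clip.mp4", max_duration=10, num_frames=10)
#   # -> [(<PIL.Image>, 0.0), (<PIL.Image>, 1.1), ..., (<PIL.Image>, 9.9)]
#   # i.e. ten (frame, timestamp-in-seconds) pairs spread over the first 10 s.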
# Inference Function
@spaces.GPU
def video_inference(video_file):
"""
Processes the video file and generates a text description based on the first 10 seconds.
Returns the generated text.
"""
if video_file is None:
return "No video provided."
frames = downsample_video(video_file, max_duration=10, num_frames=10)
if not frames:
return "Could not read frames from video."
# Construct prompt
messages = [
{
"role": "user",
"content": [{"type": "text", "text": "Please describe what's happening in this video."}]
}
]
for (image, ts) in frames:
messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
messages[0]["content"].append({"type": "image", "image": image})
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
frame_images = [img for (img, _) in frames]
inputs = processor(
text=[prompt],
images=frame_images,
return_tensors="pt",
padding=True
).to("cuda")
# Generate text with streaming
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
time.sleep(0.01)
return generated_text
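# Streaming variant (sketch, not wired in): turning the loop above into a generator
# would let Gradio refresh the textbox as tokens arrive. With the same `streamer`
# and `thread` setup, the tail of video_inference would become roughly:
#
#     buffer = ""
#     for new_text in streamer:
#         buffer += new_text
#         yield buffer
#
# The handler that calls it (process_video below) would then also need to yield
# partial results instead of returning once.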
# Button Toggle Function
def toggle_button(has_result):
"""
Returns visibility states for start_again_btn and start_btn based on has_result.
"""
if has_result:
return gr.update(visible=True), gr.update(visible=False)
else:
return gr.update(visible=False), gr.update(visible=True)
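# Quick reference for the mapping used by has_result.change below
# (outputs are [start_again_btn, start_btn]):
#   toggle_button(True)  -> "Start Again" visible, "Start" hidden
#   toggle_button(False) -> "Start" visible, "Start Again" hidden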
# Build the Gradio App
def build_app():
with gr.Blocks() as demo:
gr.Markdown("""
# **Qwen2-VL Live Video Analysis**
Press **Start**, then record a short clip with the webcam; only the first 10 seconds are analyzed. Stop recording to see the analysis.
After the result appears, press **Start Again** to clear the clip and analyze another one.
""")
# State to track if a result has been generated
has_result = gr.State(value=False)
with gr.Row():
with gr.Column():
video = gr.Video(
sources=["webcam"],
label="Webcam Recording",
format="mp4"
)
# Two buttons: one for Start, one for Start Again
start_btn = gr.Button("Start", visible=True)
start_again_btn = gr.Button("Start Again", visible=False)
with gr.Column():
output_text = gr.Textbox(label="Model Output")
# When video is recorded and stopped, process it
def process_video(video_file, has_result_state):
if video_file is None:
return "Please record a video.", has_result_state
result = video_inference(video_file)
return result, True
video.change(
fn=process_video,
inputs=[video, has_result],
outputs=[output_text, has_result]
)
# Update button visibility based on has_result
has_result.change(
fn=toggle_button,
inputs=has_result,
outputs=[start_again_btn, start_btn]
)
# Clicking either button resets the video and output
def reset_state():
return None, "", False
start_btn.click(
fn=reset_state,
inputs=None,
outputs=[video, output_text, has_result]
)
start_again_btn.click(
fn=reset_state,
inputs=None,
outputs=[video, output_text, has_result]
)
return demo
if __name__ == "__main__":
app = build_app()
app.launch(debug=True)
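# Note (assumption, not part of the original app): if the generator-based streaming
# sketch above were adopted, older Gradio versions would typically need the queue
# enabled before launching, e.g.:
#     app = build_app()
#     app.queue().launch(debug=True)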