Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
from sam2.build_sam import build_sam2_video_predictor, build_sam2 | |
from sam2.sam2_image_predictor import SAM2ImagePredictor | |
import cv2 | |
import traceback | |
import matplotlib.pyplot as plt | |
# --- CUDA / precision setup ---------------------------------------------------
# Enter a bfloat16 autocast context and deliberately never exit it, so the
# model construction below runs under autocast.
# NOTE(review): PyTorch documents autocast state as thread-local, so this only
# affects the importing thread; request handlers that may run on other threads
# open their own autocast blocks.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Allow TF32 math in cuDNN and CUDA matmuls when device 0 reports a compute
# capability major version of 8 or higher.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
# --- Model initialisation -----------------------------------------------------
# SAM2 "hiera large": one config/checkpoint pair feeds both the video predictor
# (mask propagation) and the image predictor (box -> mask on a single frame).
sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
image_predictor = SAM2ImagePredictor(sam2_model)

# Florence-2 large, used for text-prompted box detection; loaded in bfloat16,
# switched to inference mode and moved to the GPU.
# NOTE(review): trust_remote_code=True executes Python shipped with the Hub
# repository; pin a `revision=` if the upstream repo is not fully trusted.
model_id = 'microsoft/Florence-2-large'
florence_model = (
    AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    .eval()
    .cuda()
)
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
def apply_color_mask(frame, mask, obj_id):
    """Paint *mask* onto *frame* in a colour chosen by *obj_id*.

    The colour is entry ``obj_id % 10`` of matplotlib's ``tab10`` colormap
    (RGB components in [0, 1], scaled to [0, 255] below), so object ids cycle
    through ten colours. After resizing, a mask value of 1 replaces the pixel
    with the colour, 0 keeps the frame pixel, and fractional values produced by
    the bilinear resize blend the two.

    Args:
        frame: Image array of shape (H, W, 3); the caller in this file passes RGB.
        mask: Object mask. A 4-D mask is squeezed, and a 3-D mask whose first
            axis has length 1 is reduced to 2-D, before it is resized to (H, W).
        obj_id: Integer object id used to pick the colour.

    Returns:
        A floating-point array with the frame's shape. It is not cast back to
        uint8; callers do that.
    """
    palette = plt.get_cmap("tab10")
    rgb = np.array(palette(obj_id % 10)[:3])  # modulo 10 cycles the palette

    # Bring the mask down to two dimensions.
    if mask.ndim == 4:
        mask = mask.squeeze()
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]

    height, width = frame.shape[:2]
    # cv2.resize takes (width, height). The float32 cast lets INTER_LINEAR
    # produce fractional weights along the mask boundary.
    alpha = cv2.resize(mask.astype(np.float32), (width, height), interpolation=cv2.INTER_LINEAR)
    alpha = alpha[:, :, np.newaxis]  # (H, W) -> (H, W, 1) to broadcast over channels

    tint = alpha * rgb.reshape(1, 1, 3)
    return frame * (1 - alpha) + tint * 255
def run_florence(image, text_input):
    """Locate *text_input* in *image* with Florence-2 and return the first box.

    Runs Florence-2's ``<OPEN_VOCABULARY_DETECTION>`` task with *text_input*
    appended to the task token. Generation uses beam search (3 beams, no
    sampling, up to 1024 new tokens). The decoded text is post-processed
    against the image size.

    Args:
        image: PIL image. Its ``.width`` and ``.height`` are passed to
            post-processing.
        text_input: Free-text description of the object to find.

    Returns:
        The first element of the parsed ``bboxes`` list. The caller feeds it to
        ``SAM2ImagePredictor.predict(box=...)``, presumably as ``[x1, y1, x2, y2]``
        in pixels (TODO confirm against the Florence-2 processor).

    Raises:
        ValueError: if Florence-2 returns no bounding box for the prompt. A bare
            ``IndexError`` used to surface here instead.
    """
    task_prompt = '<OPEN_VOCABULARY_DETECTION>'
    prompt = task_prompt + text_input

    # torch.autocast(device_type="cuda") replaces the deprecated
    # torch.cuda.amp.autocast and matches the module-level setup. Autocast state
    # is thread-local, so it is opened here rather than relying on the global one.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.bfloat16)
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"].cuda(),
            pixel_values=inputs["pixel_values"].cuda(),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )

    # Keep special tokens: post_process_generation parses them into boxes.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

    bboxes = (parsed_answer.get(task_prompt) or {}).get('bboxes') or []
    if not bboxes:
        raise ValueError(f"Florence-2 found no bounding box for prompt {text_input!r}")
    return bboxes[0]
def remove_directory_contents(directory):
    """Delete every file and subdirectory beneath *directory*, keeping *directory*.

    The walk is bottom-up (``topdown=False``), so each subdirectory has already
    been emptied by the time ``os.rmdir`` is called on it. A path that does not
    exist is a no-op, because ``os.walk`` yields nothing for it.
    """
    for parent, subdirs, filenames in os.walk(directory, topdown=False):
        for path in (os.path.join(parent, entry) for entry in filenames):
            os.remove(path)
        for path in (os.path.join(parent, entry) for entry in subdirs):
            os.rmdir(path)
def _read_chunks(video, chunk_size):
    """Yield consecutive lists of up to *chunk_size* RGB frames read from *video*.

    Frames are read strictly in order until ``video.read()`` reports failure.
    Chunking therefore depends on neither ``CAP_PROP_FRAME_COUNT`` nor
    ``CAP_PROP_POS_FRAMES`` seeking, whose accuracy depends on the container
    and the OpenCV backend.
    """
    while True:
        chunk = []
        while len(chunk) < chunk_size:
            ok, frame = video.read()
            if not ok:
                break
            chunk.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if chunk:
            yield chunk
        if len(chunk) < chunk_size:  # short (or empty) chunk => end of stream
            return


def _segment_chunk(frames, prompt, chunk_start):
    """Detect *prompt* on ``frames[0]``, propagate it through the chunk, and overlay it.

    Florence-2 supplies a box on the first frame. SAM2's image predictor turns
    that box into a mask, and SAM2's video predictor propagates the mask over
    the chunk. Returns one uint8 RGB frame per input frame.
    """
    import tempfile

    first_frame = Image.fromarray(frames[0])
    mask_box = np.array(run_florence(first_frame, prompt))
    print(f"Chunk starting at {chunk_start}: mask box {mask_box}")

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        image_predictor.set_image(first_frame)
        masks, _, _ = image_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=mask_box[None, :],
            multimask_output=False,
        )
    init_mask = masks.squeeze().astype(bool)

    # A uniquely named temp dir per call prevents collisions between
    # concurrent requests, and it is removed even if SAM2 raises.
    with tempfile.TemporaryDirectory(prefix=f"temp_frames_{chunk_start}_") as temp_dir:
        # init_state is given this directory, so write the chunk as numbered JPEGs.
        for i, frame in enumerate(frames):
            cv2.imwrite(os.path.join(temp_dir, f"{i:04d}.jpg"), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            inference_state = video_predictor.init_state(video_path=temp_dir)
            video_predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=1,
                mask=init_mask,
            )
            video_segments = {
                out_frame_idx: {
                    out_obj_id: (out_mask_logits[j] > 0.0).cpu().numpy()
                    for j, out_obj_id in enumerate(out_obj_ids)
                }
                for out_frame_idx, out_obj_ids, out_mask_logits
                in video_predictor.propagate_in_video(inference_state)
            }
    print('segmenting for main vid done')

    segmented = []
    for idx, frame in enumerate(frames):
        for out_obj_id, obj_mask in video_segments.get(idx, {}).items():
            frame = apply_color_mask(frame, obj_mask, out_obj_id)
        segmented.append(frame.astype(np.uint8))
    return segmented


def process_video(video_path, prompt, chunk_size=30):
    """Segment the object described by *prompt* throughout a video.

    The video is read sequentially in chunks of *chunk_size* frames. Each chunk
    is re-detected with Florence-2 on its first frame and propagated with SAM2.
    Overlaid frames are streamed into an mp4v-encoded ``.mp4`` as they are
    produced, rather than all being held in memory.

    Args:
        video_path: Path to a video file that OpenCV can open.
        prompt: Text description of the object to segment.
        chunk_size: Frames per detection/propagation chunk; must be >= 1.

    Returns:
        Path of the written video, a unique file in the system temp directory so
        concurrent requests do not overwrite each other. Returns ``None`` on any
        failure; the error and traceback are printed instead of raised.
    """
    import tempfile

    video = None
    writer = None
    output_path = None
    try:
        if chunk_size < 1:
            raise ValueError("chunk_size must be at least 1")

        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Unable to open video file")
        fps = video.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # 0 or NaN would make the writer fail; also catches NaN
            fps = 30.0

        fd, output_path = tempfile.mkstemp(prefix="segmented_video_", suffix=".mp4")
        os.close(fd)  # cv2.VideoWriter opens the path itself

        frames_written = 0
        chunk_start = 0
        for frames in _read_chunks(video, chunk_size):
            for frame in _segment_chunk(frames, prompt, chunk_start):
                if writer is None:
                    # Size the writer from the first processed frame.
                    height, width = frame.shape[:2]
                    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
                    if not writer.isOpened():
                        raise ValueError("Unable to open video writer")
                writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                frames_written += 1
            chunk_start += len(frames)

        if frames_written == 0:
            raise ValueError("No frames were processed successfully")
        return output_path  # writer is released (file finalised) in `finally`
    except Exception as e:
        print(f"Error in process_video: {str(e)}")
        print(traceback.format_exc())  # This will print the full stack trace
        # Release before deleting so the partial file is closed first.
        if writer is not None:
            writer.release()
            writer = None
        if output_path is not None:
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort cleanup; the request already failed
        return None
    finally:
        if writer is not None:
            writer.release()
        if video is not None:
            video.release()
def segment_video(video_file, prompt, chunk_size):
    """Gradio callback: validate the UI inputs, then delegate to ``process_video``.

    Args:
        video_file: Value of the ``gr.Video`` input; ``None`` when nothing was uploaded.
        prompt: Text from the prompt textbox.
        chunk_size: Slider value. It is converted with ``int()`` because the
            slider value may arrive as a float.

    Returns:
        Path of the segmented video. Returns ``None`` when there is no video,
        when the prompt is empty, or when processing failed.
    """
    if video_file is None:
        return None
    # An empty prompt gives Florence-2 nothing to detect, so skip the GPU run.
    if prompt is None or not prompt.strip():
        return None
    return process_video(video_file, prompt.strip(), int(chunk_size))
# --- Gradio UI ----------------------------------------------------------------
# Inputs: uploaded video, text prompt, and chunk size (frames per Florence-2
# re-detection). Output: the segmented video path returned by segment_video.
demo = gr.Interface(
    fn=segment_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Enter prompt (e.g., 'a gymnast')"),
        gr.Slider(
            label="Chunk Size (frames)",
            minimum=10,
            maximum=100,
            step=10,
            value=30,
        ),
    ],
    outputs=gr.Video(label="Segmented Video"),
    title="Video Object Segmentation with Florence and SAM2",
    description="Upload a video and provide a text prompt to segment a specific object throughout the video.",
)

# Launch the server at import time.
demo.launch()