import os
from typing import Optional
import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from gradio_image_prompter import ImagePrompter
from utils.models import (
    load_models, CHECKPOINT_NAMES, MODE_NAMES,
    MASK_GENERATION_MODE, BOX_PROMPT_MODE, VIDEO_SEGMENTATION_MODE
)
from utils.video import create_directory, generate_unique_name
from sam2.build_sam import build_sam2_video_predictor
MARKDOWN = """
# Segment Anything Model 2 🔥
<div>
<a href="https://github.com/facebookresearch/segment-anything-2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/what-is-segment-anything-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
</a>
</div>
Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable
visual segmentation in both images and videos. **Video segmentation will be available
soon.**
"""
EXAMPLES = [
["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-4.jpeg", None],
]
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
SCALE_FACTOR = 0.5
TARGET_DIRECTORY = "tmp"
# creating video results directory
create_directory(directory_path=TARGET_DIRECTORY)
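
# toggles component visibility so only the widgets relevant to the selected
# mode (box prompt, mask generation or video segmentation) are shown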
def on_mode_dropdown_change(text):
return [
gr.Image(visible=text == MASK_GENERATION_MODE),
ImagePrompter(visible=text == BOX_PROMPT_MODE),
gr.Video(visible=text == VIDEO_SEGMENTATION_MODE),
ImagePrompter(visible=text == VIDEO_SEGMENTATION_MODE),
gr.Button(visible=text != VIDEO_SEGMENTATION_MODE),
gr.Button(visible=text == VIDEO_SEGMENTATION_MODE),
gr.Image(visible=text != VIDEO_SEGMENTATION_MODE),
gr.Video(visible=text == VIDEO_SEGMENTATION_MODE)
]
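
# extracts the first frame of the uploaded video, scales it down and converts
# it from BGR to RGB so it can be used as the prompting frame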
def on_video_input_change(video_input):
if not video_input:
return None
frames_generator = sv.get_video_frames_generator(video_input)
frame = next(frames_generator)
frame = sv.scale_image(frame, SCALE_FACTOR)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
return {'image': frame, 'points': []}
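
# runs single-image inference: box-prompted segmentation or automatic mask
# generation, depending on the selected mode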
def process_image(
checkpoint_dropdown,
mode_dropdown,
image_input,
image_prompter_input
) -> Optional[Image.Image]:
if mode_dropdown == BOX_PROMPT_MODE:
image_input = image_prompter_input["image"]
prompt = image_prompter_input["points"]
if len(prompt) == 0:
return image_input
model = IMAGE_PREDICTORS[checkpoint_dropdown]
image = np.array(image_input.convert("RGB"))
box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])
model.set_image(image)
masks, _, _ = model.predict(box=box, multimask_output=False)
        # with multiple box prompts and multimask_output=False, predict returns
        # masks with an extra singleton dimension; squeeze it away
        if len(masks.shape) == 4:
            masks = np.squeeze(masks)
detections = sv.Detections(
xyxy=sv.mask_to_xyxy(masks=masks),
mask=masks.astype(bool)
)
return MASK_ANNOTATOR.annotate(image_input, detections)
if mode_dropdown == MASK_GENERATION_MODE:
model = MASK_GENERATORS[checkpoint_dropdown]
image = np.array(image_input.convert("RGB"))
result = model.generate(image)
detections = sv.Detections.from_sam(result)
return MASK_ANNOTATOR.annotate(image_input, detections)
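
# runs video inference: splits the uploaded video into scaled frames, prompts
# SAM2 on the first frame with the drawn points, then propagates the masks
# through the remaining frames (work in progress)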
def process_video(
checkpoint_dropdown,
mode_dropdown,
video_input,
video_prompter_input,
progress=gr.Progress(track_tqdm=True)
) -> str:
if mode_dropdown != VIDEO_SEGMENTATION_MODE:
return str(video_input)
name = generate_unique_name()
frame_directory_path = os.path.join(TARGET_DIRECTORY, name)
frames_sink = sv.ImageSink(
target_dir_path=frame_directory_path,
image_name_pattern="{:05d}.jpeg"
)
video_info = sv.VideoInfo.from_video_path(video_input)
frames_generator = sv.get_video_frames_generator(video_input)
with frames_sink:
for frame in tqdm(
frames_generator,
total=video_info.total_frames,
desc="splitting video into frames"
):
frame = sv.scale_image(frame, SCALE_FACTOR)
frames_sink.save_image(frame)
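    # NOTE: the tiny checkpoint is hard-coded below; `checkpoint_dropdown` is
    # not honoured for video inference yet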
model = build_sam2_video_predictor(
"sam2_hiera_t.yaml",
"checkpoints/sam2_hiera_tiny.pt",
device=DEVICE
)
inference_state = model.init_state(
video_path=frame_directory_path,
offload_video_to_cpu=DEVICE == torch.device('cpu'),
offload_state_to_cpu=DEVICE == torch.device('cpu'),
)
prompt = video_prompter_input["points"]
points = np.array([[x1, y1] for x1, y1, _, _, _, _ in prompt])
    labels = np.ones(len(points), dtype=np.int32)  # all points are positive prompts
_, object_ids, mask_logits = model.add_new_points(
inference_state=inference_state,
frame_idx=0,
obj_id=1,
points=points,
labels=labels,
)
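    # NOTE: the block below is a minimal sketch of the intended propagation /
    # rendering step, assuming SAM2's `propagate_in_video` generator and
    # supervision's `VideoSink`; it also assumes the frame dimensions scale
    # cleanly by SCALE_FACTOR.
    video_path = os.path.join(TARGET_DIRECTORY, f"{name}.mp4")
    scaled_video_info = sv.VideoInfo(
        width=int(video_info.width * SCALE_FACTOR),
        height=int(video_info.height * SCALE_FACTOR),
        fps=video_info.fps
    )
    frames_generator = sv.get_video_frames_generator(video_input)
    masks_generator = model.propagate_in_video(inference_state)
    with sv.VideoSink(video_path, video_info=scaled_video_info) as sink:
        for frame, (_, _, mask_logits) in tqdm(
            zip(frames_generator, masks_generator),
            total=video_info.total_frames,
            desc="rendering segmented video"
        ):
            frame = sv.scale_image(frame, SCALE_FACTOR)
            # threshold the logits to boolean masks and drop the singleton dim
            masks = (mask_logits > 0.0).cpu().numpy()
            masks = np.squeeze(masks, axis=1)
            detections = sv.Detections(
                xyxy=sv.mask_to_xyxy(masks=masks),
                mask=masks
            )
            sink.write_frame(MASK_ANNOTATOR.annotate(frame, detections))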
    del inference_state
    del model
    return video_path
with gr.Blocks() as demo:
gr.Markdown(MARKDOWN)
with gr.Row():
checkpoint_dropdown_component = gr.Dropdown(
choices=CHECKPOINT_NAMES,
value=CHECKPOINT_NAMES[0],
label="Checkpoint", info="Select a SAM2 checkpoint to use.",
interactive=True
)
mode_dropdown_component = gr.Dropdown(
choices=MODE_NAMES,
value=MODE_NAMES[0],
label="Mode",
info="Select a mode to use. `box prompt` if you want to generate masks for "
"selected objects, `mask generation` if you want to generate masks "
"for the whole image, and `video segmentation` if you want to track "
"object on video.",
interactive=True
)
with gr.Row():
with gr.Column():
image_input_component = gr.Image(
type='pil', label='Upload image', visible=False)
image_prompter_input_component = ImagePrompter(
type='pil', label='Prompt image')
video_input_component = gr.Video(
label='Step 1: Upload video', visible=False)
video_prompter_input_component = ImagePrompter(
type='pil', label='Step 2: Prompt frame', visible=False)
submit_image_button_component = gr.Button(
value='Submit', variant='primary')
submit_video_button_component = gr.Button(
value='Submit', variant='primary', visible=False)
with gr.Column():
image_output_component = gr.Image(type='pil', label='Image output')
video_output_component = gr.Video(
                label='Step 3: Video output', visible=False)
with gr.Row():
gr.Examples(
fn=process_image,
examples=EXAMPLES,
inputs=[
checkpoint_dropdown_component,
mode_dropdown_component,
image_input_component,
image_prompter_input_component,
],
outputs=[image_output_component],
run_on_click=True
)
mode_dropdown_component.change(
on_mode_dropdown_change,
inputs=[mode_dropdown_component],
outputs=[
image_input_component,
image_prompter_input_component,
video_input_component,
video_prompter_input_component,
submit_image_button_component,
submit_video_button_component,
image_output_component,
video_output_component
]
)
video_input_component.change(
fn=on_video_input_change,
inputs=[video_input_component],
outputs=[video_prompter_input_component]
)
submit_image_button_component.click(
fn=process_image,
inputs=[
checkpoint_dropdown_component,
mode_dropdown_component,
image_input_component,
image_prompter_input_component,
],
outputs=[image_output_component]
)
submit_video_button_component.click(
fn=process_video,
inputs=[
checkpoint_dropdown_component,
mode_dropdown_component,
video_input_component,
video_prompter_input_component,
],
outputs=[video_output_component]
)
demo.launch(debug=False, show_error=True, max_threads=1)