import os
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from gradio_image_prompter import ImagePrompter

from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
    MASK_GENERATION_MODE, BOX_PROMPT_MODE, VIDEO_SEGMENTATION_MODE
from utils.video import create_directory, generate_unique_name
from sam2.build_sam import build_sam2_video_predictor

MARKDOWN = """
# Segment Anything Model 2 🔥

GitHub · Colab · Roboflow · YouTube

Segment Anything Model 2 (SAM 2) is a foundation model designed to address
promptable visual segmentation in both images and videos.
**Video segmentation will be available soon.**
"""

EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-4.jpeg", None],
]

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
SCALE_FACTOR = 0.5
TARGET_DIRECTORY = "tmp"

# creating video results directory
create_directory(directory_path=TARGET_DIRECTORY)


def on_mode_dropdown_change(text):
    # toggle component visibility depending on the selected mode
    return [
        gr.Image(visible=text == MASK_GENERATION_MODE),
        ImagePrompter(visible=text == BOX_PROMPT_MODE),
        gr.Video(visible=text == VIDEO_SEGMENTATION_MODE),
        ImagePrompter(visible=text == VIDEO_SEGMENTATION_MODE),
        gr.Button(visible=text != VIDEO_SEGMENTATION_MODE),
        gr.Button(visible=text == VIDEO_SEGMENTATION_MODE),
        gr.Image(visible=text != VIDEO_SEGMENTATION_MODE),
        gr.Video(visible=text == VIDEO_SEGMENTATION_MODE)
    ]


def on_video_input_change(video_input):
    # seed the frame prompter with the first (scaled) frame of the uploaded video
    if not video_input:
        return None
    frames_generator = sv.get_video_frames_generator(video_input)
    frame = next(frames_generator)
    frame = sv.scale_image(frame, SCALE_FACTOR)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = Image.fromarray(frame)
    return {'image': frame, 'points': []}


def process_image(
    checkpoint_dropdown,
    mode_dropdown,
    image_input,
    image_prompter_input
) -> Optional[Image.Image]:
    if mode_dropdown == BOX_PROMPT_MODE:
        image_input = image_prompter_input["image"]
        prompt = image_prompter_input["points"]
        if len(prompt) == 0:
            return image_input

        model = IMAGE_PREDICTORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
        box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])

        model.set_image(image)
        masks, _, _ = model.predict(box=box, multimask_output=False)

        # dirty fix; remove this later
        if len(masks.shape) == 4:
            masks = np.squeeze(masks)

        detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=masks),
            mask=masks.astype(bool)
        )
        return MASK_ANNOTATOR.annotate(image_input, detections)

    if mode_dropdown == MASK_GENERATION_MODE:
        model = MASK_GENERATORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
        result = model.generate(image)
        detections = sv.Detections.from_sam(result)
        return MASK_ANNOTATOR.annotate(image_input, detections)


def process_video(
    checkpoint_dropdown,
    mode_dropdown,
    video_input,
    video_prompter_input,
    progress=gr.Progress(track_tqdm=True)
) -> str:
    if mode_dropdown != VIDEO_SEGMENTATION_MODE:
        return str(video_input)

    # split the uploaded video into scaled frames for the video predictor
    name = generate_unique_name()
    frame_directory_path = os.path.join(TARGET_DIRECTORY, name)
    frames_sink = sv.ImageSink(
        target_dir_path=frame_directory_path,
        image_name_pattern="{:05d}.jpeg"
    )
    video_info = sv.VideoInfo.from_video_path(video_input)
    frames_generator = sv.get_video_frames_generator(video_input)
    with frames_sink:
        for frame in tqdm(
            frames_generator,
            total=video_info.total_frames,
            desc="splitting video into frames"
        ):
            frame = sv.scale_image(frame, SCALE_FACTOR)
            frames_sink.save_image(frame)

    model = build_sam2_video_predictor(
        "sam2_hiera_t.yaml",
        "checkpoints/sam2_hiera_tiny.pt",
        device=DEVICE
    )
    inference_state = model.init_state(
        video_path=frame_directory_path,
        offload_video_to_cpu=DEVICE == torch.device('cpu'),
        offload_state_to_cpu=DEVICE == torch.device('cpu'),
    )

    prompt = video_prompter_input["points"]
    points = np.array([[x1, y1] for x1, y1, _, _, _, _ in prompt])
    labels = np.ones(len(points))
    _, object_ids, mask_logits = model.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        points=points,
        labels=labels,
    )

    # NOTE: mask propagation and rendering are not wired up yet; the demo
    # currently returns the input video unchanged.
    del inference_state
    del model

    video_path = os.path.join(TARGET_DIRECTORY, f"{name}.mp4")
    return str(video_input)


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            label="Checkpoint",
            info="Select a SAM2 checkpoint to use.",
            interactive=True
        )
        mode_dropdown_component = gr.Dropdown(
            choices=MODE_NAMES,
            value=MODE_NAMES[0],
            label="Mode",
            info="Select a mode to use. `box prompt` if you want to generate masks for "
                 "selected objects, `mask generation` if you want to generate masks "
                 "for the whole image, and `video segmentation` if you want to track "
                 "objects in a video.",
            interactive=True
        )
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image', visible=False)
            image_prompter_input_component = ImagePrompter(
                type='pil', label='Prompt image')
            video_input_component = gr.Video(
                label='Step 1: Upload video', visible=False)
            video_prompter_input_component = ImagePrompter(
                type='pil', label='Step 2: Prompt frame', visible=False)
            submit_image_button_component = gr.Button(
                value='Submit', variant='primary')
            submit_video_button_component = gr.Button(
                value='Submit', variant='primary', visible=False)
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image output')
            video_output_component = gr.Video(
                label='Step 2: Video output', visible=False)

    with gr.Row():
        gr.Examples(
            fn=process_image,
            examples=EXAMPLES,
            inputs=[
                checkpoint_dropdown_component,
                mode_dropdown_component,
                image_input_component,
                image_prompter_input_component,
            ],
            outputs=[image_output_component],
            run_on_click=True
        )

    mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[
            image_input_component,
            image_prompter_input_component,
            video_input_component,
            video_prompter_input_component,
            submit_image_button_component,
            submit_video_button_component,
            image_output_component,
            video_output_component
        ]
    )
    video_input_component.change(
        fn=on_video_input_change,
        inputs=[video_input_component],
        outputs=[video_prompter_input_component]
    )
    submit_image_button_component.click(
        fn=process_image,
        inputs=[
            checkpoint_dropdown_component,
            mode_dropdown_component,
            image_input_component,
            image_prompter_input_component,
        ],
        outputs=[image_output_component]
    )
    submit_video_button_component.click(
        fn=process_video,
        inputs=[
            checkpoint_dropdown_component,
            mode_dropdown_component,
            video_input_component,
            video_prompter_input_component,
        ],
        outputs=[video_output_component]
    )

demo.launch(debug=False, show_error=True, max_threads=1)
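
# ---------------------------------------------------------------------------
# Sketch: finishing the video segmentation path.
# `process_video` above stops after `add_new_points` and returns the input
# video unchanged (the UI notes that video segmentation is coming soon). One
# possible way to complete it, assuming the SAM2 video predictor's
# `propagate_in_video` generator and supervision's `VideoSink`, is roughly:
#
#     scaled_video_info = sv.VideoInfo(
#         width=int(video_info.width * SCALE_FACTOR),
#         height=int(video_info.height * SCALE_FACTOR),
#         fps=video_info.fps
#     )
#     frames_generator = sv.get_video_frames_generator(video_input)
#     with sv.VideoSink(video_path, video_info=scaled_video_info) as sink:
#         for (_, object_ids, mask_logits), frame in zip(
#             model.propagate_in_video(inference_state), frames_generator
#         ):
#             frame = sv.scale_image(frame, SCALE_FACTOR)
#             masks = (mask_logits > 0.0).cpu().numpy()
#             masks = np.squeeze(masks, axis=1).astype(bool)
#             detections = sv.Detections(
#                 xyxy=sv.mask_to_xyxy(masks=masks),
#                 mask=masks
#             )
#             sink.write_frame(MASK_ANNOTATOR.annotate(frame, detections))
#     return video_path
#
# This would replace the `del inference_state` / `del model` stub and return
# the rendered `video_path` instead of the original upload.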