# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from collections import OrderedDict

import torch
import torch.nn.functional as F

from tqdm import tqdm

from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames


class SAM2VideoPredictor(SAM2Base):
    """The predictor class to handle user interactions and manage inference states."""

    def __init__(
        self,
        fill_hole_area=0,
        # whether to apply non-overlapping constraints on the output object masks
        non_overlap_masks=False,
        # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
        # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
        # NOTE(review): `clear_non_cond_mem_for_multi_obj` is not defined in this class; in the code
        # visible here the flag is applied per object regardless of how many objects exist -- confirm.
        clear_non_cond_mem_around_input=False,
        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only those initial conditioning frames
        add_all_frames_to_correct_as_cond=False,
        **kwargs,
    ):
        """
        Store the predictor-specific options; all remaining keyword arguments are
        forwarded unchanged to `SAM2Base.__init__`.

        Arguments:
          fill_hole_area: when > 0, predicted mask scores are post-processed with
            `fill_holes_in_mask_scores` using this value.
          non_overlap_masks: apply non-overlapping constraints to the video-resolution output masks.
          clear_non_cond_mem_around_input: clear non-conditioning memory around frames that receive inputs.
          add_all_frames_to_correct_as_cond: treat every frame that receives inputs as a conditioning frame.
          **kwargs: passed to the parent constructor.
        """
        super().__init__(**kwargs)
        self.fill_hole_area = fill_hole_area
        self.non_overlap_masks = non_overlap_masks
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond

    @torch.inference_mode()
    def init_state(
        self,
        video_path,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
    ):
        """
        Initialize an inference state.

        Arguments:
          video_path: forwarded to `load_video_frames` to load the video frames.
          offload_video_to_cpu: forwarded to `load_video_frames` and recorded in the state.
          offload_state_to_cpu: if True, tracking outputs are stored on the CPU
            ("storage_device"); otherwise they stay on the model's device.
          async_loading_frames: forwarded to `load_video_frames`.

        Returns:
          (dict): the inference state holding the frames, video size, devices and
          the (initially empty) per-object input/output containers.
        """
        compute_device = self.device  # device of the model
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            compute_device=compute_device,
        )
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
        # and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = compute_device
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = compute_device
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # A temporary storage to hold new outputs when user interact with a frame
        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # metadata for each tracked frame of each object (e.g. which direction it's tracked)
        inference_state["frames_tracked_per_obj"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2VideoPredictor): The loaded model.
        """
        # imported lazily inside the method, as in the original code
        from sam2.build_sam import build_sam2_video_predictor_hf

        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
        return sam_model

    def _obj_id_to_idx(self, inference_state, obj_id):
        """
        Map client-side object id to model-side object index.

        An unknown `obj_id` is registered on the fly: it receives the next free
        index and empty per-object input/output containers are created for it.

        Arguments:
          inference_state (dict): the state returned by `init_state`.
          obj_id: client-side object id (used as a dict key).

        Returns:
          (int): the model-side object index.
        """
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
        if obj_idx is not None:
            return obj_idx

        # We always allow adding new objects (including after tracking starts),
        # so an unknown id simply takes the next object slot.
        obj_idx = len(inference_state["obj_id_to_idx"])
        inference_state["obj_id_to_idx"][obj_id] = obj_idx
        inference_state["obj_idx_to_id"][obj_idx] = obj_id
        inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
        # set up input and output structures for this object
        inference_state["point_inputs_per_obj"][obj_idx] = {}
        inference_state["mask_inputs_per_obj"][obj_idx] = {}
        inference_state["output_dict_per_obj"][obj_idx] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        inference_state["temp_output_dict_per_obj"][obj_idx] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        inference_state["frames_tracked_per_obj"][obj_idx] = {}
        return obj_idx

    def _obj_idx_to_id(self, inference_state, obj_idx):
        """Map model-side object index to client-side object id."""
        return inference_state["obj_idx_to_id"][obj_idx]

    def _get_obj_num(self, inference_state):
        """Get the total number of unique object ids received so far in this session."""
        return len(inference_state["obj_idx_to_id"])

    @torch.inference_mode()
    def add_new_points_or_box(
        self,
        inference_state,
        frame_idx,
        obj_id,
        points=None,
        labels=None,
        clear_old_points=True,
        normalize_coords=True,
        box=None,
    ):
        """
        Add new points and/or a box prompt to a frame for one object and run the
        mask decoder on that frame (the memory encoder is skipped here).

        Arguments:
          inference_state (dict): the state returned by `init_state`.
          frame_idx (int): index of the frame receiving the prompt.
          obj_id: client-side object id (registered if not seen before).
          points: point coordinates; converted to a float32 tensor, a 2-D input gets a batch dimension.
          labels: point labels; converted to an int32 tensor, a 1-D input gets a batch dimension.
          clear_old_points (bool): if False, new points are appended to the points
            already stored for this frame; must be True when `box` is given.
          normalize_coords (bool): if True, coordinates are first divided by the
            original video (width, height); in both cases they are then multiplied by `self.image_size`.
          box: optional box, reshaped to two corner points with labels 2 and 3.

        Returns:
          (tuple): `(frame_idx, obj_ids, video_res_masks)` where `video_res_masks`
          holds the mask scores of all objects at the original video resolution.

        Raises:
          ValueError: if only one of `points`/`labels` is given, if neither points
            nor box is given, or if `box` is combined with `clear_old_points=False`.
        """
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if (points is not None) != (labels is not None):
            raise ValueError("points and labels must be provided together")
        if points is None and box is None:
            raise ValueError("at least one of points or box must be provided as input")

        if points is None:
            points = torch.zeros(0, 2, dtype=torch.float32)
        elif not isinstance(points, torch.Tensor):
            points = torch.tensor(points, dtype=torch.float32)
        if labels is None:
            labels = torch.zeros(0, dtype=torch.int32)
        elif not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.int32)
        if points.dim() == 2:
            points = points.unsqueeze(0)  # add batch dimension
        if labels.dim() == 1:
            labels = labels.unsqueeze(0)  # add batch dimension

        # If `box` is provided, we add it as the first two points with labels 2 and 3
        # along with the user-provided points (consistent with how SAM 2 is trained).
        if box is not None:
            if not clear_old_points:
                raise ValueError(
                    "cannot add box without clearing old points, since "
                    "box prompt must be provided before any point prompt "
                    "(please use clear_old_points=True instead)"
                )
            if not isinstance(box, torch.Tensor):
                box = torch.tensor(box, dtype=torch.float32, device=points.device)
            box_coords = box.reshape(1, 2, 2)
            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
            box_labels = box_labels.reshape(1, 2)
            points = torch.cat([box_coords, points], dim=1)
            labels = torch.cat([box_labels, labels], dim=1)

        if normalize_coords:
            video_H = inference_state["video_height"]
            video_W = inference_state["video_width"]
            points = points / torch.tensor([video_W, video_H]).to(points.device)
        # scale the (normalized) coordinates by the model's internal image size
        points = points * self.image_size
        points = points.to(inference_state["device"])
        labels = labels.to(inference_state["device"])

        if not clear_old_points:
            point_inputs = point_inputs_per_frame.get(frame_idx, None)
        else:
            point_inputs = None
        point_inputs = concat_points(point_inputs, points, labels)

        # a frame holds either point inputs or a mask input, never both
        point_inputs_per_frame[frame_idx] = point_inputs
        mask_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the inputs points are to generate segments on this frame without
        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
        # the input points will be used to correct the already tracked masks.
        obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
        is_init_cond_frame = frame_idx not in obj_frames_tracked
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = obj_frames_tracked[frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Get any previously predicted mask logits on this object and feed it along with
        # the new clicks into the SAM mask decoder.
        prev_sam_mask_logits = None
        # lookup temporary output dict first, which contains the most recent output
        # (if not found, then lookup conditioning and non-conditioning frame output)
        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
        if prev_out is None:
            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
            if prev_out is None:
                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

        if prev_out is not None and prev_out["pred_masks"] is not None:
            device = inference_state["device"]
            prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=point_inputs,
            mask_inputs=None,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after user finalize their clicks). This
            # allows us to enforce non-overlapping constraints on all objects before encoding
            # them into memory.
            run_mem_encoder=False,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks

    def add_new_points(self, *args, **kwargs):
        """
        Deprecated method. Please use `add_new_points_or_box` instead.

        All arguments are forwarded unchanged to `add_new_points_or_box` and its
        result is returned. A `DeprecationWarning` is emitted on every call.
        """
        # stacklevel=2 attributes the warning to the caller rather than to this shim
        warnings.warn(
            "`add_new_points` is deprecated; please use `add_new_points_or_box` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.add_new_points_or_box(*args, **kwargs)

    @torch.inference_mode()
    def add_new_mask(
        self,
        inference_state,
        frame_idx,
        obj_id,
        mask,
    ):
        """
        Add new mask to a frame for one object and run single-frame inference on
        it (the memory encoder is skipped here).

        Arguments:
          inference_state (dict): the state returned by `init_state`.
          frame_idx (int): index of the frame receiving the mask.
          obj_id: client-side object id (registered if not seen before).
          mask: a 2-D mask; non-tensor inputs are converted to a bool tensor. It is
            resized to the model's image size (and re-binarized at 0.5) when its
            shape differs.

        Returns:
          (tuple): `(frame_idx, obj_ids, video_res_masks)` where `video_res_masks`
          holds the mask scores of all objects at the original video resolution.
        """
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask, dtype=torch.bool)
        assert mask.dim() == 2, "mask must be 2-dimensional (H, W)"
        mask_H, mask_W = mask.shape
        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

        # resize the mask if it doesn't match the model's image size
        if mask_H != self.image_size or mask_W != self.image_size:
            mask_inputs = F.interpolate(
                mask_inputs_orig,
                size=(self.image_size, self.image_size),
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
            mask_inputs = (mask_inputs >= 0.5).float()
        else:
            mask_inputs = mask_inputs_orig

        # a frame holds either a mask input or point inputs, never both
        mask_inputs_per_frame[frame_idx] = mask_inputs
        point_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input mask is to generate segments on this frame without
        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
        # the input mask will be used to correct the already tracked masks.
        obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
        is_init_cond_frame = frame_idx not in obj_frames_tracked
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = obj_frames_tracked[frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=None,
            mask_inputs=mask_inputs,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after user finalize their clicks). This
            # allows us to enforce non-overlapping constraints on all objects before encoding
            # them into memory.
            run_mem_encoder=False,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks

    def _get_orig_video_res_output(self, inference_state, any_res_masks):
        """
        Resize the object scores to the original video resolution (video_res_masks)
        and apply non-overlapping constraints for final output.

        Returns:
          (tuple): `(any_res_masks, video_res_masks)`; the first item is the input
          moved to the compute device, the second is at the original video resolution.
        """
        device = inference_state["device"]
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        any_res_masks = any_res_masks.to(device, non_blocking=True)
        if any_res_masks.shape[-2:] == (video_H, video_W):
            # already at video resolution, no resizing needed
            video_res_masks = any_res_masks
        else:
            video_res_masks = F.interpolate(
                any_res_masks,
                size=(video_H, video_W),
                mode="bilinear",
                align_corners=False,
            )
        if self.non_overlap_masks:
            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
        return any_res_masks, video_res_masks

    def _consolidate_temp_output_across_obj(
        self,
        inference_state,
        frame_idx,
        is_cond,
        consolidate_at_video_res=False,
    ):
        """
        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
        a frame into a single mask tensor for all objects. Objects missing from the
        temporary outputs are filled from `output_dict_per_obj` (if they exist there
        for this frame) or left as placeholder values (NO_OBJ_SCORE) otherwise.

        Arguments:
          inference_state (dict): the state returned by `init_state`.
          frame_idx (int): the frame to consolidate.
          is_cond (bool): selects which temporary slot ("cond_frame_outputs" or
            "non_cond_frame_outputs") is looked up first.
          consolidate_at_video_res (bool): if True, masks are consolidated at the
            original video resolution under the key "pred_masks_video_res";
            otherwise at `self.image_size // 4` under the key "pred_masks".

        Returns:
          (dict): a dict with a single entry (key as described above) holding a
          float32 tensor of shape (num_objects, 1, H, W) on the storage device.
        """
        batch_size = self._get_obj_num(inference_state)
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Optionally, we allow consolidating the temporary outputs at the original
        # video resolution (to provide a better editing experience for mask prompts).
        if consolidate_at_video_res:
            consolidated_H = inference_state["video_height"]
            consolidated_W = inference_state["video_width"]
            consolidated_mask_key = "pred_masks_video_res"
        else:
            consolidated_H = consolidated_W = self.image_size // 4
            consolidated_mask_key = "pred_masks"

        # Initialize `consolidated_out`. Its masks are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            consolidated_mask_key: torch.full(
                size=(batch_size, 1, consolidated_H, consolidated_W),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["storage_device"],
            ),
        }
        for obj_idx in range(batch_size):
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
            # we fall back and look up its previous output in "output_dict_per_obj".
            # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
            # "output_dict_per_obj" to find a previous output for this object.
            if out is None:
                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
            if out is None:
                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
            # placeholder above).
            if out is None:
                continue
            # Add the temporary object output mask to consolidated output mask
            obj_mask = out["pred_masks"]
            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
            else:
                # Resize first if temporary object mask has a different resolution
                resized_obj_mask = F.interpolate(
                    obj_mask,
                    size=consolidated_pred_masks.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask

        return consolidated_out

    @torch.inference_mode()
    def propagate_in_video_preflight(self, inference_state):
        """
        Prepare inference_state and consolidate temporary outputs before tracking.

        For every object, the temporary outputs are moved into `output_dict_per_obj`
        (running the memory encoder on those that have no memory features yet) and
        the temporary storage is cleared.

        Raises:
          RuntimeError: if no object exists, or if some object has no conditioning
            frame output after consolidation.
        """
        # Check and make sure that every object has received input points or masks.
        batch_size = self._get_obj_num(inference_state)
        if batch_size == 0:
            raise RuntimeError(
                "No input points or masks are provided for any object; please add inputs first."
            )

        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict".
        for obj_idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            for is_cond in [False, True]:
                # Separately consolidate conditioning and non-conditioning temp outputs
                storage_key = (
                    "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
                )
                # Find all the frames that contain temporary outputs for any objects
                # (these should be the frames that have just received clicks or mask inputs
                # via `add_new_points_or_box` or `add_new_mask`)
                for frame_idx, out in obj_temp_output_dict[storage_key].items():
                    # Run memory encoder on the temporary outputs (if the memory feature is missing)
                    if out["maskmem_features"] is None:
                        high_res_masks = F.interpolate(
                            out["pred_masks"].to(inference_state["device"]),
                            size=(self.image_size, self.image_size),
                            mode="bilinear",
                            align_corners=False,
                        )
                        maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                            inference_state=inference_state,
                            frame_idx=frame_idx,
                            batch_size=1,  # run on the slice of a single object
                            high_res_masks=high_res_masks,
                            object_score_logits=out["object_score_logits"],
                            # these frames are what the user interacted with
                            is_mask_from_pts=True,
                        )
                        out["maskmem_features"] = maskmem_features
                        out["maskmem_pos_enc"] = maskmem_pos_enc

                    obj_output_dict[storage_key][frame_idx] = out
                    if self.clear_non_cond_mem_around_input:
                        # clear non-conditioning memory of the surrounding frames
                        self._clear_obj_non_cond_mem_around_input(
                            inference_state, frame_idx, obj_idx
                        )

                # clear temporary outputs in `temp_output_dict_per_obj`
                obj_temp_output_dict[storage_key].clear()

            # check and make sure that every object has received input points or masks
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            if len(obj_output_dict["cond_frame_outputs"]) == 0:
                obj_id = self._obj_idx_to_id(inference_state, obj_idx)
                raise RuntimeError(
                    f"No input points or masks are provided for object id {obj_id}; please add inputs first."
                )
            # edge case: if an output is added to "cond_frame_outputs", we remove any prior
            # output on the same frame in "non_cond_frame_outputs"
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)

    @torch.inference_mode()
    def propagate_in_video(
        self,
        inference_state,
        start_frame_idx=None,
        max_frame_num_to_track=None,
        reverse=False,
    ):
        """
        Propagate the input points across frames to track in the entire video.

        This is a generator: nothing runs until it is iterated.

        Arguments:
          inference_state (dict): the state returned by `init_state`.
          start_frame_idx (int | None): first frame to process; defaults to the
            earliest frame that has a conditioning output for any object.
          max_frame_num_to_track (int | None): number of frames to advance from the
            start frame; defaults to the number of frames in the video.
          reverse (bool): if True, process frames in decreasing index order
            (nothing is processed when the start frame is 0).

        Yields:
          (tuple): `(frame_idx, obj_ids, video_res_masks)` for each processed frame.
        """
        self.propagate_in_video_preflight(inference_state)

        obj_ids = inference_state["obj_ids"]
        num_frames = inference_state["num_frames"]
        batch_size = self._get_obj_num(inference_state)

        # set start index, end index, and processing order
        if start_frame_idx is None:
            # default: start from the earliest frame with input points
            start_frame_idx = min(
                t
                for obj_output_dict in inference_state["output_dict_per_obj"].values()
                for t in obj_output_dict["cond_frame_outputs"]
            )
        if max_frame_num_to_track is None:
            # default: track all the frames in the video
            max_frame_num_to_track = num_frames
        if reverse:
            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
            if start_frame_idx > 0:
                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
            else:
                processing_order = []  # skip reverse tracking if starting from frame 0
        else:
            end_frame_idx = min(
                start_frame_idx + max_frame_num_to_track, num_frames - 1
            )
            processing_order = range(start_frame_idx, end_frame_idx + 1)

        for frame_idx in tqdm(processing_order, desc="propagate in video"):
            pred_masks_per_obj = [None] * batch_size
            for obj_idx in range(batch_size):
                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
                # We skip those frames already in consolidated outputs (these are frames
                # that received input clicks or mask). Note that we cannot directly run
                # batched forward on them via `_run_single_frame_inference` because the
                # number of clicks on each object might be different.
                if frame_idx in obj_output_dict["cond_frame_outputs"]:
                    storage_key = "cond_frame_outputs"
                    current_out = obj_output_dict[storage_key][frame_idx]
                    device = inference_state["device"]
                    pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
                    if self.clear_non_cond_mem_around_input:
                        # clear non-conditioning memory of the surrounding frames
                        self._clear_obj_non_cond_mem_around_input(
                            inference_state, frame_idx, obj_idx
                        )
                else:
                    storage_key = "non_cond_frame_outputs"
                    current_out, pred_masks = self._run_single_frame_inference(
                        inference_state=inference_state,
                        output_dict=obj_output_dict,
                        frame_idx=frame_idx,
                        batch_size=1,  # run on the slice of a single object
                        is_init_cond_frame=False,
                        point_inputs=None,
                        mask_inputs=None,
                        reverse=reverse,
                        run_mem_encoder=True,
                    )
                    obj_output_dict[storage_key][frame_idx] = current_out

                # remember the direction in which this frame was tracked for this object
                inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
                    "reverse": reverse
                }
                pred_masks_per_obj[obj_idx] = pred_masks

            # Resize the output mask to the original video resolution (we directly use
            # the mask scores on GPU for output to avoid any CPU conversion in between)
            if len(pred_masks_per_obj) > 1:
                all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
            else:
                all_pred_masks = pred_masks_per_obj[0]
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, all_pred_masks
            )
            yield frame_idx, obj_ids, video_res_masks

    @torch.inference_mode()
    def clear_all_prompts_in_frame(
        self, inference_state, frame_idx, obj_id, need_output=True
    ):
        """
        Remove all input points or mask in a specific frame for a given object.

        Arguments:
          inference_state (dict): the state returned by `init_state`.
          frame_idx (int): the frame whose prompts are removed.
          obj_id: client-side object id.
          need_output (bool): if False, return None right after updating the state.

        Returns:
          (tuple | None): `(frame_idx, obj_ids, video_res_masks)` with the updated
          masks of all objects on this frame, or None when `need_output` is False.
        """
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)

        # Clear the conditioning information on the given frame
        inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
        inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)

        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
        temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)

        # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
        if out is not None:
            # The frame is not a conditioning frame anymore since it's not receiving inputs,
            # so we "downgrade" its output (if exists) to a non-conditioning frame output.
            obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
            inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)

        if not need_output:
            return
        # Finally, output updated masks per object (after removing the inputs above)
        obj_ids = inference_state["obj_ids"]
        is_cond = any(
            frame_idx in obj_temp_output_dict["cond_frame_outputs"]
            for obj_temp_output_dict in temp_output_dict_per_obj.values()
        )
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks

    @torch.inference_mode()
    def reset_state(self, inference_state):
        """
        Remove all input points or mask in all frames throughout the video, and
        forget every registered object id (the loaded frames are kept).
        """
        self._reset_tracking_results(inference_state)
        # Remove all object ids
        inference_state["obj_id_to_idx"].clear()
        inference_state["obj_idx_to_id"].clear()
        inference_state["obj_ids"].clear()
        inference_state["point_inputs_per_obj"].clear()
        inference_state["mask_inputs_per_obj"].clear()
        inference_state["output_dict_per_obj"].clear()
        inference_state["temp_output_dict_per_obj"].clear()
        inference_state["frames_tracked_per_obj"].clear()

    def _reset_tracking_results(self, inference_state):
        """
        Reset all tracking inputs and results across the videos.

        The per-object containers are emptied in place; the object id mappings
        themselves are left untouched.
        """
        for v in inference_state["point_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["mask_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["temp_output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["frames_tracked_per_obj"].values():
            v.clear()

    def _get_image_feature(self, inference_state, frame_idx, batch_size):
        """
        Compute the image features on a given frame.

        The backbone output of the most recently requested frame is cached in
        `inference_state["cached_features"]` and reused on a cache hit.

        Returns:
          (tuple): the expanded image followed by the items returned by
          `_prepare_backbone_features`, all expanded to `batch_size`.
        """
        # Look up in the cache first
        image, backbone_out = inference_state["cached_features"].get(
            frame_idx, (None, None)
        )
        if backbone_out is None:
            # Cache miss -- we will run inference on a single image
            device = inference_state["device"]
            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
            backbone_out = self.forward_image(image)
            # Cache the most recent frame's feature (for repeated interactions with
            # a frame; we can use an LRU cache for more frames in the future).
            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

        # expand the features to have the same dimension as the number of objects
        expanded_image = image.expand(batch_size, -1, -1, -1)
        # copy the lists so that the cached backbone output is not modified
        expanded_backbone_out = {
            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
        }
        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
                batch_size, -1, -1, -1
            )
        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
            pos = pos.expand(batch_size, -1, -1, -1)
            expanded_backbone_out["vision_pos_enc"][i] = pos

        features = self._prepare_backbone_features(expanded_backbone_out)
        features = (expanded_image,) + features
        return features

    def _run_single_frame_inference(
        self,
        inference_state,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
    ):
        """
        Run tracking on a single frame based on current inputs and previous memory.

        Returns:
          (tuple): `(compact_current_out, pred_masks_gpu)`. The first item is a dict
          with keys "maskmem_features", "maskmem_pos_enc", "pred_masks", "obj_ptr"
          and "object_score_logits" (memory features and masks moved to the storage
          device); the second is the predicted masks before that move.
        """
        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
        obj_ptr = current_out["obj_ptr"]
        object_score_logits = current_out["object_score_logits"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
            "object_score_logits": object_score_logits,
        }
        return compact_current_out, pred_masks_gpu

    def _run_memory_encoder(
        self,
        inference_state,
        frame_idx,
        batch_size,
        high_res_masks,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Run the memory encoder on `high_res_masks`. This is usually after applying
        non-overlapping constraints to object scores. Since their scores changed, their
        memory also need to be computed again with the memory encoder.

        Returns:
          (tuple): `(maskmem_features, maskmem_pos_enc)`; the features are cast to
          bfloat16 and moved to the storage device.
        """
        # Retrieve correct image features
        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
            inference_state, frame_idx, batch_size
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            object_score_logits=object_score_logits,
            is_mask_from_pts=is_mask_from_pts,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(
            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
        )
        return maskmem_features, maskmem_pos_enc

    def _get_maskmem_pos_enc(self, inference_state, current_out):
        """
        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
        a constant in the inference session to reduce session storage size.

        Returns:
          (list | None): the cached encodings expanded to the batch size of
          `current_out["maskmem_pos_enc"]`, or None when that entry is None.
        """
        model_constants = inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            expanded_maskmem_pos_enc = [
                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
            ]
        else:
            expanded_maskmem_pos_enc = None
        return expanded_maskmem_pos_enc

    @torch.inference_mode()
    def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
        """
        Remove an object id from the tracking state. If strict is True, we check whether
        the object id actually exists and raise an error if it doesn't exist.
        """
        old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
        updated_frames = []
        # Check whether this object_id to remove actually exists and possibly raise an error.
        if old_obj_idx_to_rm is None:
            if not strict:
                return inference_state["obj_ids"], updated_frames
            raise RuntimeError(
                f"Cannot remove object id {obj_id} as it doesn't exist. "
                f"All existing object ids: {inference_state['obj_ids']}."
            )

        # If this is the only remaining object id, we simply reset the state.
        if len(inference_state["obj_id_to_idx"]) == 1:
            self.reset_state(inference_state)
            return inference_state["obj_ids"], updated_frames

        # There are still remaining objects after removing this object id. In this case,
        # we need to delete the object storage from inference state tensors.
        # Step 0: clear the input on those frames where this object id has point or mask input
        # (note that this step is required as it might downgrade conditioning frames to
        # non-conditioning ones)
        obj_input_frames_inds = set()
        obj_input_frames_inds.update(
            inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
        )
        obj_input_frames_inds.update(
            inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
        )
        for frame_idx in obj_input_frames_inds:
            self.clear_all_prompts_in_frame(
                inference_state, frame_idx, obj_id, need_output=False
            )

        # Step 1: Update the object id mapping (note that it must be done after Step 0,
        # since Step 0 still requires the old object id mappings in inference_state)
        old_obj_ids = inference_state["obj_ids"]
        old_obj_inds = list(range(len(old_obj_ids)))
        remain_old_obj_inds = old_obj_inds.copy()
        remain_old_obj_inds.remove(old_obj_idx_to_rm)
        new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
        new_obj_inds = list(range(len(new_obj_ids)))
        # build new mappings
        old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
        inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
        inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
        inference_state["obj_ids"] = new_obj_ids

        # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
def _map_keys(container): new_kvs = [] for k in old_obj_inds: v = container.pop(k) if k in old_idx_to_new_idx: new_kvs.append((old_idx_to_new_idx[k], v)) container.update(new_kvs) _map_keys(inference_state["point_inputs_per_obj"]) _map_keys(inference_state["mask_inputs_per_obj"]) _map_keys(inference_state["output_dict_per_obj"]) _map_keys(inference_state["temp_output_dict_per_obj"]) _map_keys(inference_state["frames_tracked_per_obj"]) # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which # could show an updated mask for objects previously occluded by the object being removed if need_output: temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] for frame_idx in obj_input_frames_inds: is_cond = any( frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values() ) consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) updated_frames.append((frame_idx, video_res_masks)) return inference_state["obj_ids"], updated_frames def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): """ Remove the non-conditioning memory around the input frame. When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated object appearance information and could confuse the model. This method clears those non-conditioning memories surrounding the interacted frame to avoid giving the model both old and new information about the object. 
""" r = self.memory_temporal_stride_for_eval frame_idx_begin = frame_idx - r * self.num_maskmem frame_idx_end = frame_idx + r * self.num_maskmem batch_size = self._get_obj_num(inference_state) for obj_idx in range(batch_size): obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"] for t in range(frame_idx_begin, frame_idx_end + 1): non_cond_frame_outputs.pop(t, None) class SAM2VideoPredictorVOS(SAM2VideoPredictor): """Optimized for the VOS setting""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._compile_all_components() def _compile_all_components(self): print("Compiling all components for VOS setting. First time may be very slow.") self.memory_encoder.forward = torch.compile( self.memory_encoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, ) self.memory_attention.forward = torch.compile( self.memory_attention.forward, mode="max-autotune", fullgraph=True, dynamic=True, # Num. of memories varies ) self.sam_prompt_encoder.forward = torch.compile( self.sam_prompt_encoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, # Accuracy regression on True ) self.sam_mask_decoder.forward = torch.compile( self.sam_mask_decoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, # Accuracy regression on True ) def forward_image(self, img_batch: torch.Tensor): """ Identical to the corresponding method in the parent (SAM2VideoPredictor), but cloning the backbone features and pos encoding to enable compilation. 
""" backbone_out = self.image_encoder(img_batch) if self.use_high_res_features_in_sam: # precompute projected level 0 and level 1 features in SAM decoder # to avoid running it again on every SAM click backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( backbone_out["backbone_fpn"][0] ) backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( backbone_out["backbone_fpn"][1] ) # Clone to help torch.compile for i in range(len(backbone_out["backbone_fpn"])): backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone() backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][ i ].clone() return backbone_out def _forward_sam_heads( self, backbone_features, point_inputs=None, mask_inputs=None, high_res_features=None, multimask_output=False, ): """ Identical to the corresponding method in the parent (SAM2VideoPredictor), but cloning the outputs of prompt_encoder and mask_decoder to enable compilation. """ B = backbone_features.size(0) device = backbone_features.device assert backbone_features.size(1) == self.sam_prompt_embed_dim assert backbone_features.size(2) == self.sam_image_embedding_size assert backbone_features.size(3) == self.sam_image_embedding_size # a) Handle point prompts if point_inputs is not None: sam_point_coords = point_inputs["point_coords"] sam_point_labels = point_inputs["point_labels"] assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B else: # If no points are provide, pad with an empty point (with label -1) sam_point_coords = torch.zeros(B, 1, 2, device=device) sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) # b) Handle mask prompts if mask_inputs is not None: # If mask_inputs is provided, downsize it into low-res mask input if needed # and feed it as a dense mask prompt into the SAM mask encoder assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: sam_mask_prompt = 
F.interpolate( mask_inputs.float(), size=self.sam_prompt_encoder.mask_input_size, align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ) else: sam_mask_prompt = mask_inputs else: # Otherwise, simply feed None (and SAM's prompt encoder will add # a learned `no_mask_embed` to indicate no mask input in this case). sam_mask_prompt = None sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( points=(sam_point_coords, sam_point_labels), boxes=None, masks=sam_mask_prompt, ) # Clone image_pe and the outputs of sam_prompt_encoder # to enable compilation sparse_embeddings = sparse_embeddings.clone() dense_embeddings = dense_embeddings.clone() image_pe = self.sam_prompt_encoder.get_dense_pe().clone() ( low_res_multimasks, ious, sam_output_tokens, object_score_logits, ) = self.sam_mask_decoder( image_embeddings=backbone_features, image_pe=image_pe, sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, repeat_image=False, # the image is already batched high_res_features=high_res_features, ) # Clone the output of sam_mask_decoder # to enable compilation low_res_multimasks = low_res_multimasks.clone() ious = ious.clone() sam_output_tokens = sam_output_tokens.clone() object_score_logits = object_score_logits.clone() if self.pred_obj_scores: is_obj_appearing = object_score_logits > 0 # Mask used for spatial memories is always a *hard* choice between obj and no obj, # consistent with the actual mask prediction low_res_multimasks = torch.where( is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE, ) # convert masks from possibly bfloat16 (or float16) to float32 # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) low_res_multimasks = low_res_multimasks.float() high_res_multimasks = F.interpolate( low_res_multimasks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ) sam_output_token = sam_output_tokens[:, 0] if 
multimask_output: # take the best mask prediction (with the highest IoU estimation) best_iou_inds = torch.argmax(ious, dim=-1) batch_inds = torch.arange(B, device=device) low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) if sam_output_tokens.size(1) > 1: sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] else: low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks # Extract object pointer from the SAM output token (with occlusion handling) obj_ptr = self.obj_ptr_proj(sam_output_token) if self.pred_obj_scores: # Allow *soft* no obj ptr, unlike for masks if self.soft_no_obj_ptr: lambda_is_obj_appearing = object_score_logits.sigmoid() else: lambda_is_obj_appearing = is_obj_appearing.float() if self.fixed_no_obj_ptr: obj_ptr = lambda_is_obj_appearing * obj_ptr obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr return ( low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits, ) def _encode_new_memory( self, current_vision_feats, feat_sizes, pred_masks_high_res, object_score_logits, is_mask_from_pts, ): """ Identical to the corresponding method in the parent (SAM2VideoPredictor), but cloning the memories and their pos enc to enable compilation. """ B = current_vision_feats[-1].size(1) # batch size on this frame C = self.hidden_dim H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size # top-level feature, (HW)BC => BCHW pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) if self.non_overlap_masks_for_mem_enc and not self.training: # optionally, apply non-overlapping constraints to the masks (it's applied # in the batch dimension and should only be used during eval, where all # the objects come from the same video under batch size 1). 
pred_masks_high_res = self._apply_non_overlapping_constraints( pred_masks_high_res ) # scale the raw mask logits with a temperature before applying sigmoid binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts if binarize and not self.training: mask_for_mem = (pred_masks_high_res > 0).float() else: # apply sigmoid on the raw mask logits to turn them into range (0, 1) mask_for_mem = torch.sigmoid(pred_masks_high_res) # apply scale and bias terms to the sigmoid probabilities if self.sigmoid_scale_for_mem_enc != 1.0: mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc if self.sigmoid_bias_for_mem_enc != 0.0: mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc maskmem_out = self.memory_encoder( pix_feat, mask_for_mem, skip_mask_sigmoid=True, # sigmoid already applied ) # Clone the feats and pos_enc to enable compilation maskmem_features = maskmem_out["vision_features"].clone() maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] # add a no-object embedding to the spatial memory to indicate that the frame # is predicted to be occluded (i.e. no object is appearing in the frame) if self.no_obj_embed_spatial is not None: is_obj_appearing = (object_score_logits > 0).float() maskmem_features += ( 1 - is_obj_appearing[..., None, None] ) * self.no_obj_embed_spatial[..., None, None].expand( *maskmem_features.shape ) return maskmem_features, maskmem_pos_enc