from collections.abc import Sequence

import numpy as np
import torch


class Collect(object):
    """Split a pipeline ``results`` dict into data entries and metadata.

    Args:
        keys: Names of entries in ``results`` that each hold a collection of
            data keys (e.g. a "fields" set). Every key listed there is
            collected as data.
        meta_keys: Stored on the instance and shown in ``repr``.
            NOTE(review): ``__call__`` does not filter by ``meta_keys``;
            every entry that is not a data key ends up in ``img_metas``.
    """

    def __init__(
        self,
        keys,
        meta_keys=(
            "filename",
            "keyframe_idx",
            "sequence_name",
            "image_filename",
            "depth_filename",
            "image_ori_shape",
            "camera",
            "original_camera",
            "sfm",
            "image_shape",
            "resized_shape",
            "scale_factor",
            "rotation",
            "resize_factor",
            "flip",
            "flip_direction",
            "dataset_name",
            "paddings",
            "max_value",
            "log_mean",
            "log_std",
            "image_rescale",
            "focal_rescale",
            "depth_rescale",
        ),
    ):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Return ``{data_key: {sequence_key: value}, ..., "img_metas": {...}}``.

        Args:
            results: Mapping that contains ``"sequence_fields"`` and, for
                each collected data key, a per-sequence mapping.

        Returns:
            A new dict with one per-sequence dict for each data key, plus
            ``"img_metas"`` holding every other entry of ``results``.
        """
        data_keys = [key for field in self.keys for key in results.get(field, [])]
        data = {
            key: {
                sequence_key: results[key][sequence_key]
                for sequence_key in results["sequence_fields"]
            }
            for key in data_keys
        }
        # Set lookup keeps the metadata filter linear in the size of results.
        collected = set(data_keys)
        data["img_metas"] = {
            key: value for key, value in results.items() if key not in collected
        }
        return data

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(keys={self.keys}, meta_keys={self.meta_keys})"
        )


class AnnotationMask(object):
    """Build or refine validity masks for every ground-truth field.

    Args:
        min_value: Exclusive lower bound on the value (single channel) or on
            the norm over dim 1 (multi channel).
        max_value: Exclusive upper bound applied the same way; ``None``
            disables the upper bound.
        custom_fn: Callable invoked as ``custom_fn(mask, info=results)`` and
            expected to return the (possibly modified) mask. The default
            returns the mask unchanged.
    """

    # The default must tolerate the ``info=`` keyword that __call__ always
    # passes; a plain ``lambda x: x`` raises TypeError on first use.
    def __init__(self, min_value, max_value, custom_fn=lambda x, **kwargs: x):
        self.min_value = min_value
        self.max_value = max_value
        self.custom_fn = custom_fn

    def __call__(self, results):
        """Update ``results`` in place with ``<key>_mask`` entries and return it.

        Flow fields that already have a registered mask are only restricted
        to components inside ``[-1, 1]``. Every other field gets a range mask
        from ``min_value`` / ``max_value`` which is AND-ed with an existing
        mask when one is present.
        """
        sequence_fields = results.get("sequence_fields", [])
        for key in results.get("gt_fields", []):
            mask_name = key + "_mask"

            if mask_name in results["mask_fields"]:
                if "flow" in key:
                    for sequence_idx in sequence_fields:
                        flow = results[key][sequence_idx]
                        boundaries = (flow >= -1) & (flow <= 1)
                        # Both flow components must be in range; indexing
                        # assumes two channels along dim 1 - TODO confirm.
                        boundaries = boundaries[:, :1] & boundaries[:, 1:]
                        results[mask_name][sequence_idx] = (
                            results[mask_name][sequence_idx].bool() & boundaries
                        )
                    # NOTE(review): placement of this ``continue`` was
                    # reconstructed from whitespace-collapsed source.
                    continue

            for sequence_idx in sequence_fields:
                value = results[key][sequence_idx]
                # take care of xyz or flow, dim=1 is the channel dim
                if value.shape[1] == 1:
                    magnitude = value
                else:
                    magnitude = value.norm(dim=1, keepdim=True)

                mask = magnitude > self.min_value
                if self.max_value is not None:
                    mask = mask & (magnitude < self.max_value)
                mask = self.custom_fn(mask, info=results)

                per_sequence = results.setdefault(mask_name, {})
                if sequence_idx not in per_sequence:
                    per_sequence[sequence_idx] = mask.bool()
                else:
                    per_sequence[sequence_idx] = (
                        per_sequence[sequence_idx].bool() & mask.bool()
                    )
                results["mask_fields"].add(mask_name)
        return results

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(min_value={self.min_value}, max_value={self.max_value})"
        )