Luigi Piccinelli
init demo
1ea89dd
from collections.abc import Sequence
import numpy as np
import torch
class Collect(object):
    """Split a pipeline ``results`` dict into per-sequence data and metadata.

    ``keys`` names entries of ``results`` that each hold an iterable of data
    keys (for example a set of field names). Every data key found this way is
    copied into the output, restricted to the indices listed in
    ``results["sequence_fields"]``. Every other entry of ``results`` is
    bundled under ``"img_metas"``.

    Args:
        keys: Iterable of names of ``results`` entries that list the data
            keys to collect. Names missing from ``results`` are skipped.
        meta_keys: Tuple of metadata names. It is stored on the instance and
            shown by ``repr``; ``__call__`` does not consult it.
    """

    def __init__(
        self,
        keys,
        meta_keys=(
            "filename",
            "keyframe_idx",
            "sequence_name",
            "image_filename",
            "depth_filename",
            "image_ori_shape",
            "camera",
            "original_camera",
            "sfm",
            "image_shape",
            "resized_shape",
            "scale_factor",
            "rotation",
            "resize_factor",
            "flip",
            "flip_direction",
            "dataset_name",
            "paddings",
            "max_value",
            "log_mean",
            "log_std",
            "image_rescale",
            "focal_rescale",
            "depth_rescale",
        ),
    ):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Return the collected data dict, with leftovers under ``"img_metas"``.

        Args:
            results: Mapping produced by the earlier pipeline steps.

        Returns:
            A new dict mapping each selected data key to
            ``{sequence_key: results[key][sequence_key]}``, plus an
            ``"img_metas"`` dict holding every ``results`` entry whose key was
            not selected as data.

        Raises:
            KeyError: If at least one data key is selected and
                ``results["sequence_fields"]``, the data key itself, or one
                of its sequence entries is missing.
        """
        # Flatten the key groups named by self.keys, preserving their order.
        selected = []
        for field in self.keys:
            selected.extend(results.get(field, []))

        data = {}
        for name in selected:
            per_sequence = {}
            # Looked up per key on purpose: with no selected keys the
            # "sequence_fields" entry is never required.
            for sequence_key in results["sequence_fields"]:
                per_sequence[sequence_key] = results[name][sequence_key]
            data[name] = per_sequence

        metas = {}
        for name, value in results.items():
            if name not in selected:
                metas[name] = value
        data["img_metas"] = metas
        return data

    def __repr__(self):
        return f"{self.__class__.__name__}(keys={self.keys}, meta_keys={self.meta_keys})"
class AnnotationMask(object):
    """Derive boolean validity masks for ground-truth fields from value ranges.

    For every key in ``results["gt_fields"]`` and every index in
    ``results["sequence_fields"]`` a mask is computed that keeps locations
    whose magnitude is strictly greater than ``min_value`` and, when
    ``max_value`` is not ``None``, strictly smaller than ``max_value``.
    Dim 1 is treated as the channel dim: single-channel entries are compared
    directly, multi-channel entries (e.g. xyz or flow) by their norm over
    dim 1. The mask is passed through ``custom_fn``, stored in
    ``results[key + "_mask"]`` (AND-ed with a mask already present for that
    index) and ``key + "_mask"`` is registered in ``results["mask_fields"]``.

    Args:
        min_value: Exclusive lower bound on the magnitude.
        max_value: Exclusive upper bound on the magnitude, or ``None`` to
            skip the upper-bound check.
        custom_fn: Callable invoked as ``custom_fn(mask, info=results)`` that
            returns the (possibly modified) mask. The default returns the
            mask unchanged and ignores any extra arguments.
    """

    # The default must tolerate the ``info=`` keyword that __call__ always
    # passes; a one-argument lambda would raise TypeError on first use.
    def __init__(self, min_value, max_value, custom_fn=lambda x, *args, **kwargs: x):
        self.min_value = min_value
        self.max_value = max_value
        self.custom_fn = custom_fn

    def __call__(self, results):
        """Add or refine ``<key>_mask`` entries in ``results`` and return it.

        Args:
            results: Pipeline dict. ``"mask_fields"`` must be present and
                support ``in`` and ``.add``; ``"gt_fields"`` and
                ``"sequence_fields"`` default to empty when absent.

        Returns:
            The same ``results`` object, modified in place.
        """
        for key in results.get("gt_fields", []):
            mask_key = key + "_mask"

            if mask_key in results["mask_fields"]:
                if "flow" in key:
                    for sequence_idx in results.get("sequence_fields", []):
                        value = results[key][sequence_idx]
                        # Keep only locations where every channel lies in
                        # [-1, 1]: the first channel's test is AND-ed with
                        # the remaining channel(s).
                        boundaries = (value >= -1) & (value <= 1)
                        boundaries = boundaries[:, :1] & boundaries[:, 1:]
                        results[mask_key][sequence_idx] = (
                            results[mask_key][sequence_idx].bool() & boundaries
                        )
                        # NOTE(review): as nested here this `continue` only
                        # ends the current inner iteration, so the key still
                        # goes through the range masking below. Confirm
                        # whether skipping the whole key was intended.
                        continue

            for sequence_idx in results.get("sequence_fields", []):
                value = results[key][sequence_idx]
                # dim=1 is the channel dim: use raw values for a single
                # channel, the per-location norm otherwise (xyz or flow).
                if value.shape[1] == 1:
                    magnitude = value
                else:
                    magnitude = value.norm(dim=1, keepdim=True)

                mask = magnitude > self.min_value
                if self.max_value is not None:
                    mask = mask & (magnitude < self.max_value)
                mask = self.custom_fn(mask, info=results)

                if mask_key not in results:
                    results[mask_key] = {}
                key_masks = results[mask_key]
                if sequence_idx in key_masks:
                    # Never widen a mask that an earlier step produced.
                    key_masks[sequence_idx] = (
                        key_masks[sequence_idx].bool() & mask.bool()
                    )
                else:
                    key_masks[sequence_idx] = mask.bool()
                results["mask_fields"].add(mask_key)
        return results

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(min_value={self.min_value}, max_value={self.max_value})"
        )