File size: 3,979 Bytes
1ea89dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from collections.abc import Sequence

import numpy as np
import torch


class Collect(object):
    """Split a ``results`` dict into per-sequence data entries and metadata.

    Args:
        keys: Names of entries in ``results`` that each hold an iterable of
            data keys (looked up with ``results.get(field, [])``).
        meta_keys: Tuple of metadata names. It is stored on the instance and
            shown in ``repr`` but is not consulted by ``__call__``.
    """

    def __init__(
        self,
        keys,
        meta_keys=(
            "filename",
            "keyframe_idx",
            "sequence_name",
            "image_filename",
            "depth_filename",
            "image_ori_shape",
            "camera",
            "original_camera",
            "sfm",
            "image_shape",
            "resized_shape",
            "scale_factor",
            "rotation",
            "resize_factor",
            "flip",
            "flip_direction",
            "dataset_name",
            "paddings",
            "max_value",
            "log_mean",
            "log_std",
            "image_rescale",
            "focal_rescale",
            "depth_rescale",
        ),
    ):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Return a new dict with the collected data plus ``"img_metas"``.

        For every data key listed under the fields named in ``self.keys``,
        the output holds ``{sequence_key: results[key][sequence_key]}`` for
        each entry of ``results["sequence_fields"]``. Every item of
        ``results`` whose key is not a data key goes under ``"img_metas"``.
        """
        # Flatten the data keys announced by each requested field, in order.
        data_keys = []
        for field in self.keys:
            data_keys.extend(results.get(field, []))

        data = {}
        for key in data_keys:
            per_sequence = {}
            # Lookups stay inside the loop so nothing is touched when there
            # are no sequence entries to copy.
            for sequence_key in results["sequence_fields"]:
                per_sequence[sequence_key] = results[key][sequence_key]
            data[key] = per_sequence

        metas = {}
        for name, value in results.items():
            if name in data_keys:
                continue
            metas[name] = value
        data["img_metas"] = metas
        return data

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}(keys={self.keys}, meta_keys={self.meta_keys})"


class AnnotationMask(object):
    """Create or refine boolean validity masks for ground-truth fields.

    For every key in ``results["gt_fields"]`` and every entry of
    ``results["sequence_fields"]``, a mask is computed from the stored tensor
    and written to ``results[key + "_mask"][sequence_idx]``.

    Args:
        min_value: Exclusive lower bound; values (or channel norms) must be
            strictly greater than it.
        max_value: Exclusive upper bound, or ``None`` to skip the upper check.
        custom_fn: Callable invoked as ``custom_fn(mask, info=results)`` whose
            return value replaces the mask. The default returns the mask
            unchanged.
    """

    # The default must tolerate the ``info=`` keyword used in ``__call__``;
    # a plain ``lambda x: x`` raises TypeError on that call.
    def __init__(self, min_value, max_value, custom_fn=lambda x, **kwargs: x):
        self.min_value = min_value
        self.max_value = max_value
        self.custom_fn = custom_fn

    def __call__(self, results):
        """Update the masks in ``results`` in place and return ``results``.

        Keys containing ``"flow"`` that already have a registered mask are
        only restricted to locations where every component lies in
        ``[-1, 1]`` (presumably normalized flow -- confirm with the
        producer); the min/max thresholds are not applied to them.
        """
        sequence_fields = results.get("sequence_fields", [])
        for key in results.get("gt_fields", []):
            mask_key = key + "_mask"

            if mask_key in results["mask_fields"] and "flow" in key:
                for sequence_idx in sequence_fields:
                    flow = results[key][sequence_idx]
                    in_bounds = (flow >= -1) & (flow <= 1)
                    # Combine the first channel with the remaining one(s)
                    # along dim=1 so the mask keeps a single channel.
                    in_bounds = in_bounds[:, :1] & in_bounds[:, 1:]
                    results[mask_key][sequence_idx] = (
                        results[mask_key][sequence_idx].bool() & in_bounds
                    )
                continue

            for sequence_idx in sequence_fields:
                values = results[key][sequence_idx]
                # dim=1 is the channel dim: multi-channel data (e.g. xyz) is
                # thresholded on its per-location norm, computed once.
                if values.shape[1] != 1:
                    values = values.norm(dim=1, keepdim=True)

                mask = values > self.min_value
                if self.max_value is not None:
                    mask = mask & (values < self.max_value)
                mask = self.custom_fn(mask, info=results)

                masks = results.setdefault(mask_key, {})
                if sequence_idx in masks:
                    # Intersect with a mask that was already present.
                    masks[sequence_idx] = masks[sequence_idx].bool() & mask.bool()
                else:
                    masks[sequence_idx] = mask.bool()
            results["mask_fields"].add(mask_key)
        return results

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(min_value={self.min_value}, max_value={self.max_value})"
        )