File size: 6,408 Bytes
f1dd031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import random
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

from mmdet.datasets.transforms.frame_sampling import BaseFrameSample
from mmdet.registry import TRANSFORMS


@TRANSFORMS.register_module(force=True)
class MixUniformRefFrameSample(BaseFrameSample):
    """Uniformly sample reference frames for videos and single images.

    An input dict that contains ``video_length`` is treated as a video. Its
    key frame is ``video_infos["key_frame_id"]`` when present; otherwise it is
    drawn uniformly at random. ``num_ref_imgs`` reference frames are then
    sampled from a window around it.

    An input dict without ``video_length`` is treated as a single image. It
    is wrapped into a one-frame pseudo video whose reference frames are
    copies of the image itself.

    Args:
        num_ref_imgs (int): Number of reference frames to be sampled.
            Defaults to 1.
        frame_range (int | list[int] | tuple[int, int]): Range of frames to
            be sampled around the key frame. If int, the range is
            [-frame_range, frame_range]. If a pair, the range is
            [frame_range[0], frame_range[1]], and it must satisfy
            frame_range[0] <= 0 <= frame_range[1]. Defaults to 10.
        filter_key_img (bool): Whether to exclude the key frame from the
            reference-frame candidates. Defaults to True.
        collect_video_keys (list[str], optional): The keys of video info to
            be copied into every sampled frame. When None, defaults to
            ``["video_id", "video_length"]``.
    """

    def __init__(
        self,
        num_ref_imgs: int = 1,
        frame_range: Union[int, List[int], Tuple[int, int]] = 10,
        filter_key_img: bool = True,
        collect_video_keys: Optional[List[str]] = None,
    ):
        self.num_ref_imgs = num_ref_imgs
        self.filter_key_img = filter_key_img
        if isinstance(frame_range, int):
            assert frame_range >= 0, "frame_range can not be a negative value."
            frame_range = [-frame_range, frame_range]
        elif isinstance(frame_range, (list, tuple)):
            assert len(frame_range) == 2, "The length must be 2."
            # Validate element types before the sign comparison below, so a
            # non-int element yields the intended message rather than a
            # TypeError from `<=`.
            for bound in frame_range:
                assert isinstance(bound, int), "Each element must be int."
            assert (
                frame_range[0] <= 0 and frame_range[1] >= 0
            ), "frame_range must satisfy frame_range[0] <= 0 <= frame_range[1]."
            frame_range = list(frame_range)
        else:
            raise TypeError("The type of frame_range must be int, list or tuple.")
        self.frame_range = frame_range
        # Avoid a shared mutable default: each instance gets its own list,
        # and later mutation of a caller-supplied list cannot leak in.
        if collect_video_keys is None:
            collect_video_keys = ["video_id", "video_length"]
        super().__init__(collect_video_keys=list(collect_video_keys))

    def sampling_frames(
        self, video_length: int, key_frame_id: int
    ) -> Tuple[List[int], List[bool]]:
        """Sample the key frame together with its reference frames.

        Args:
            video_length (int): The number of frames in the video.
            key_frame_id (int): Index of the key frame.

        Returns:
            tuple[list[int], list[bool]]: The sorted sampled frame indices
            (key frame included). Also a list of flags of the same length
            that is True only at the position of the key frame.
        """
        if video_length > 1:
            left = max(0, key_frame_id + self.frame_range[0])
            right = min(key_frame_id + self.frame_range[1], video_length - 1)
            # Build only the candidate window instead of the whole video.
            valid_ids = list(range(left, right + 1))
            if self.filter_key_img and key_frame_id in valid_ids:
                valid_ids.remove(key_frame_id)
            assert (
                len(valid_ids) > 0
            ), "After filtering key frame, there are no valid frames"
            if len(valid_ids) < self.num_ref_imgs:
                # Too few candidates: repeat them so that sampling without
                # replacement can still return num_ref_imgs indices.
                valid_ids = valid_ids * self.num_ref_imgs
            ref_frame_ids = random.sample(valid_ids, self.num_ref_imgs)
        else:
            # Single-frame input: every reference frame is the key frame.
            ref_frame_ids = [key_frame_id] * self.num_ref_imgs

        sampled_frames_ids = sorted([key_frame_id] + ref_frame_ids)

        # `index` returns the first occurrence; duplicates of the key frame
        # (possible when it is not filtered) are identical anyway.
        key_frames_ind = sampled_frames_ids.index(key_frame_id)
        key_frame_flags = [False] * len(sampled_frames_ids)
        key_frame_flags[key_frames_ind] = True
        return sampled_frames_ids, key_frame_flags

    def transform(self, video_infos: dict) -> Optional[Dict[str, List]]:
        """Transform the video (or single-image) information.

        Args:
            video_infos (dict): The whole video information. Alternatively,
                a single image's info without a ``video_length`` key, in
                which case ``img_id`` is used as the pseudo video id.

        Returns:
            dict: The collated data information of the sampled frames, plus
            a ``key_frame_flags`` list marking the key frame.
        """
        if "video_length" not in video_infos:
            # Single image: wrap it into a one-frame pseudo video.
            # NOTE(review): only video_id/video_length are provided here, so
            # any extra entries in collect_video_keys would raise KeyError in
            # prepare_data. Confirm this is never configured for images.
            key_frame_id = 0
            video_infos = {
                "video_id": video_infos["img_id"],
                "video_length": 1,
                "key_frame_id": key_frame_id,
                "images": [video_infos],
            }
        elif "key_frame_id" in video_infos:
            key_frame_id = video_infos["key_frame_id"]
            assert isinstance(key_frame_id, int)
        else:
            # Uniformly pick a key frame. This consumes one _randbelow(n)
            # draw, the same as the previous random.sample(range(n), 1).
            key_frame_id = random.randrange(video_infos["video_length"])

        sampled_frames_ids, key_frame_flags = self.sampling_frames(
            video_infos["video_length"], key_frame_id=key_frame_id
        )
        results = self.prepare_data(video_infos, sampled_frames_ids)
        results["key_frame_flags"] = key_frame_flags
        return results

    def prepare_data(
        self, video_infos: dict, sampled_inds: List[int]
    ) -> Dict[str, List]:
        """Prepare data for the subsequent pipeline.

        Args:
            video_infos (dict): The whole video information; its ``images``
                entry holds the per-frame info dicts.
            sampled_inds (list[int]): The sampled frame indices.

        Returns:
            dict: A ``defaultdict(list)`` mapping each per-frame key to the
            list of its values across the sampled frames.
        """
        frames_anns = video_infos["images"]
        final_data_info = defaultdict(list)
        for index in sampled_inds:
            # Deep-copy because indices may repeat and the loop below writes
            # into `data`. Copying keeps the frames from aliasing each other
            # and leaves the source annotations unmodified.
            data = copy.deepcopy(frames_anns[index])
            # Copy the video-level info into the image-level dict.
            for key in self.collect_video_keys:
                if key == "video_length":
                    data["ori_video_length"] = video_infos[key]
                    data["video_length"] = len(sampled_inds)
                else:
                    data[key] = video_infos[key]
            # Collate data_list (list of dict -> dict of list).
            for key, value in data.items():
                final_data_info[key].append(value)

        return final_data_info

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f"(num_ref_imgs={self.num_ref_imgs}, "
        repr_str += f"frame_range={self.frame_range}, "
        repr_str += f"filter_key_img={self.filter_key_img}, "
        repr_str += f"collect_video_keys={self.collect_video_keys})"
        return repr_str