File size: 10,226 Bytes
1ed7deb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
from pathlib import Path
from typing import Optional, List, Callable, Dict, Any, Tuple, Union
import warnings

import PIL.Image as pil_image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms

from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import load_object_from_string
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
    Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor


class AnnotatedObjectsDataset(Dataset):
    """Base dataset for images annotated with objects.

    This class provides the shared pipeline:
    - checking the on-disk layout,
    - loading and transforming images,
    - category filtering and numbering,
    - filtering images by object count,
    - building conditioning outputs via the conditional builders.

    It never populates ``annotations``, ``image_descriptions``, ``categories`` or
    ``image_ids`` itself; they start as ``None``. A subclass is expected to fill
    them and to implement ``get_image_description``, ``get_path_structure`` and
    ``get_image_path``.
    """

    def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
                 min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
                 crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
                 encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
                 no_object_classes: Optional[int] = None):
        """
        Args:
            data_path: Root directory. Every sub-path returned by ``get_path_structure`` is
                resolved against it and must exist, otherwise ``FileNotFoundError`` is raised.
            split: Split identifier; stored unchanged.
            keys: Keys kept in each sample returned by ``__getitem__``.
                - ``'image'`` triggers image loading.
                - ``'objects_center_points'`` / ``'objects_bbox'`` trigger the corresponding
                  conditional builder.
                - An empty list returns every computed key.
            target_image_size: Side length used by the resize/crop transforms.
            min_object_area: Stored for ``filter_object_number``, which keeps only objects with
                ``area`` strictly greater than this value.
            min_objects_per_image: Lower bound (inclusive) on objects per image, stored for
                ``filter_object_number``.
            max_objects_per_image: Upper bound (inclusive) on objects per image. Also forwarded
                to the conditional builders.
            crop_method: ``'none'``, ``'center'``, ``'random-1d'``, ``'random-2d'`` or ``None``.
                ``None`` yields no transform list at all. Any other value raises ``ValueError``.
            random_flip: Whether to append a random horizontal flip transform.
            no_tokens: Forwarded to the conditional builders.
            use_group_parameter: Forwarded to the conditional builders.
            encode_crop: Forwarded to the conditional builders.
            category_allow_list_target: Resolved with ``load_object_from_string`` into an iterable
                of ``(name, _)`` pairs. ``filter_categories`` keeps only categories whose name is
                listed. An empty string disables the allow-list.
            category_mapping_target: Resolved with ``load_object_from_string``.
                ``filter_categories`` drops categories whose ``id`` is a key of it. An empty
                string disables it.
            no_object_classes: When truthy, overrides the class count reported by ``no_classes``.
        """
        self.data_path = data_path
        self.split = split
        self.keys = keys
        self.target_image_size = target_image_size
        self.min_object_area = min_object_area
        self.min_objects_per_image = min_objects_per_image
        self.max_objects_per_image = max_objects_per_image
        self.crop_method = crop_method
        self.random_flip = random_flip
        self.no_tokens = no_tokens
        self.use_group_parameter = use_group_parameter
        self.encode_crop = encode_crop

        # Filled in by subclasses; the base class only reads them.
        self.annotations = None
        self.image_descriptions = None
        self.categories = None
        self.category_ids = None
        self.category_number = None
        self.image_ids = None
        self.transform_functions: Optional[List[Callable]] = self.setup_transform(
            target_image_size, crop_method, random_flip)
        self.paths = self.build_paths(self.data_path)
        self._conditional_builders = None
        self.category_allow_list = None
        if category_allow_list_target:
            allow_list = load_object_from_string(category_allow_list_target)
            self.category_allow_list = {name for name, _ in allow_list}
        self.category_mapping = {}
        if category_mapping_target:
            self.category_mapping = load_object_from_string(category_mapping_target)
        self.no_object_classes = no_object_classes

    def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
        """Resolve ``get_path_structure()`` under ``top_level``.

        Raises:
            FileNotFoundError: If any resolved sub-path does not exist.
        """
        top_level = Path(top_level)
        sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
        for path in sub_paths.values():
            if not path.exists():
                raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
        return sub_paths

    @staticmethod
    def load_image_from_disk(path: Union[str, Path]) -> Image:
        """Open the image at ``path`` and return it converted to RGB.

        The file is opened in a ``with`` block so its handle is closed deterministically.
        ``convert`` returns a new, fully loaded image that stays valid after the close.
        """
        with pil_image.open(path) as image:
            return image.convert('RGB')

    @staticmethod
    def setup_transform(target_image_size: int, crop_method: CropMethodType,
                        random_flip: bool) -> Optional[List[Callable]]:
        """Build the ordered list of image transforms for ``crop_method``.

        Returns:
            ``None`` when ``crop_method`` is ``None``. Otherwise a list that:
            - starts with the resize/crop steps,
            - optionally includes a random horizontal flip,
            - always ends with a ``x / 127.5 - 1`` rescale.

        Raises:
            ValueError: For an unrecognised ``crop_method``.
        """
        transform_functions = []
        if crop_method == 'none':
            transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
        elif crop_method == 'center':
            transform_functions.extend([
                transforms.Resize(target_image_size),
                CenterCropReturnCoordinates(target_image_size)
            ])
        elif crop_method == 'random-1d':
            transform_functions.extend([
                transforms.Resize(target_image_size),
                RandomCrop1dReturnCoordinates(target_image_size)
            ])
        elif crop_method == 'random-2d':
            # Unlike the other modes, this one crops first and resizes afterwards.
            transform_functions.extend([
                Random2dCropReturnCoordinates(target_image_size),
                transforms.Resize(target_image_size)
            ])
        elif crop_method is None:
            # No transforms at all; image_transform cannot be used in this configuration.
            return None
        else:
            raise ValueError(f'Received invalid crop method [{crop_method}].')
        if random_flip:
            transform_functions.append(RandomHorizontalFlipReturn())
        # Maps [0, 255] to [-1, 1], assuming convert_pil_to_tensor keeps raw 0-255 values — TODO confirm.
        transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
        return transform_functions

    def image_transform(self, x: Tensor) -> Tuple[Optional[BoundingBox], Optional[bool], Tensor]:
        """Apply ``self.transform_functions`` in order.

        Crop transforms return ``(bbox, image)`` and the flip transform returns
        ``(flipped, image)``. Those side values are captured; every other transform
        just maps the image.

        Returns:
            ``(crop_bbox, flipped, image)``. ``crop_bbox`` / ``flipped`` are ``None``
            when no crop / flip transform is configured.
        """
        crop_bbox = None
        flipped = None
        for transform in self.transform_functions:
            if isinstance(transform, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates,
                                      Random2dCropReturnCoordinates)):
                crop_bbox, x = transform(x)
            elif isinstance(transform, RandomHorizontalFlipReturn):
                flipped, x = transform(x)
            else:
                x = transform(x)
        return crop_bbox, flipped, x

    @property
    def no_classes(self) -> int:
        """Number of object classes.

        Returns ``no_object_classes`` if it is truthy, otherwise ``len(self.categories)``.
        """
        return self.no_object_classes if self.no_object_classes else len(self.categories)

    @property
    def conditional_builders(self) -> Dict[str, Union[ObjectsCenterPointsConditionalBuilder,
                                                      ObjectsBoundingBoxConditionalBuilder]]:
        """Lazily built mapping from sample key to its conditional builder.

        Construction is deferred because ``no_classes`` may read ``self.categories``,
        which is still ``None`` when this base ``__init__`` finishes. It is only
        populated later, by the subclass.
        """
        if self._conditional_builders is None:
            # Both builders take exactly the same constructor arguments.
            builder_args = (
                self.no_classes,
                self.max_objects_per_image,
                self.no_tokens,
                self.encode_crop,
                self.use_group_parameter,
                getattr(self, 'use_additional_parameters', False),
            )
            self._conditional_builders = {
                'objects_center_points': ObjectsCenterPointsConditionalBuilder(*builder_args),
                'objects_bbox': ObjectsBoundingBoxConditionalBuilder(*builder_args),
            }
        return self._conditional_builders

    def filter_categories(self) -> None:
        """Restrict ``self.categories`` using the allow-list and the category mapping.

        - With an allow-list, only categories whose ``name`` is allowed are kept.
        - With a non-empty mapping, categories whose ``id`` is a key of the mapping are removed.
        """
        if self.category_allow_list:
            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
        if self.category_mapping:
            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}

    def setup_category_id_and_number(self) -> None:
        """Assign each category id a stable index.

        Sets ``self.category_ids`` (the order) and ``self.category_number``
        (the id -> index map).

        Warns:
            UserWarning: If an allow-list is active without a category mapping and the
                number of surviving categories differs from the allow-list size.
        """
        self.category_ids = sorted(self.categories.keys())
        # This hard-coded id is always moved to the last index (its meaning is dataset-specific — confirm).
        if '/m/01s55n' in self.category_ids:
            self.category_ids.remove('/m/01s55n')
            self.category_ids.append('/m/01s55n')
        self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
        # category_mapping defaults to {} (never None), so test emptiness. An "is None" check
        # would make this warning unreachable. A mapping removes categories on purpose, so the
        # size comparison is only meaningful without one.
        if self.category_allow_list is not None and not self.category_mapping \
                and len(self.category_ids) != len(self.category_allow_list):
            warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
                          'Make sure all names in category_allow_list exist.')

    def clean_up_annotations_and_image_descriptions(self) -> None:
        """Drop annotations and image descriptions whose image id is not in ``self.image_ids``."""
        image_id_set = set(self.image_ids)
        self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
        self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}

    @staticmethod
    def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
                             min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
        """Filter images by how many sufficiently large objects they contain.

        Per image, only annotations with ``area`` strictly greater than ``min_object_area``
        are kept. The image survives only if that count lies within
        ``[min_objects_per_image, max_objects_per_image]`` (inclusive).
        """
        filtered = {}
        for image_id, annotations in all_annotations.items():
            annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
            if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
                filtered[image_id] = annotations_with_min_area
        return filtered

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, n: int) -> Dict[str, Any]:
        """Assemble sample ``n``.

        The sample starts from the subclass's image description plus ``'annotations'``.
        If ``'image'`` is requested, the following are added:
        - ``'image_path'``,
        - the transformed ``'image'``,
        - ``'crop_bbox'`` and ``'flipped'``.
        Each requested conditional key is filled by its builder. When ``self.keys`` is
        non-empty, only those keys are returned.
        """
        image_id = self.get_image_id(n)
        # Shallow copy so the keys added below never leak into a dict the subclass may keep.
        sample = dict(self.get_image_description(image_id))
        sample['annotations'] = self.get_annotation(image_id)

        if 'image' in self.keys:
            sample['image_path'] = str(self.get_image_path(image_id))
            image = convert_pil_to_tensor(self.load_image_from_disk(sample['image_path']))
            sample['crop_bbox'], sample['flipped'], image = self.image_transform(image)
            # Moves dim 0 last; presumably CHW -> HWC, assuming a channel-first tensor — confirm.
            sample['image'] = image.permute(1, 2, 0)

        for conditional, builder in self.conditional_builders.items():
            if conditional in self.keys:
                # crop_bbox/flipped exist only when the image was loaded. Otherwise pass None,
                # the same values image_transform yields when no crop/flip is configured.
                sample[conditional] = builder.build(sample['annotations'], sample.get('crop_bbox'),
                                                    sample.get('flipped'))

        if self.keys:
            # only return specified keys
            sample = {key: sample[key] for key in self.keys}
        return sample

    def get_image_id(self, no: int) -> str:
        """Image id at position ``no`` of ``self.image_ids``."""
        return self.image_ids[no]

    def get_annotation(self, image_id: str) -> List[Annotation]:
        """Annotations stored for ``image_id``."""
        return self.annotations[image_id]

    def get_textual_label_for_category_id(self, category_id: str) -> str:
        """Name of the category with id ``category_id``."""
        return self.categories[category_id].name

    def get_textual_label_for_category_no(self, category_no: int) -> str:
        """Name of the category with index ``category_no`` (see ``setup_category_id_and_number``)."""
        return self.categories[self.get_category_id(category_no)].name

    def get_category_number(self, category_id: str) -> int:
        """Index assigned to ``category_id``."""
        return self.category_number[category_id]

    def get_category_id(self, category_no: int) -> str:
        """Category id at index ``category_no``."""
        return self.category_ids[category_no]

    def get_image_description(self, image_id: str) -> Dict[str, Any]:
        """Return the base sample dict for ``image_id``; must be implemented by subclasses."""
        raise NotImplementedError()

    def get_path_structure(self) -> Dict[str, str]:
        """Return ``{name: relative sub-path}`` required under the data root; subclass hook."""
        raise NotImplementedError

    def get_image_path(self, image_id: str) -> Path:
        """Return the image file path for ``image_id``; must be implemented by subclasses."""
        raise NotImplementedError