from collections import defaultdict
import glob
import json
import os
from typing import Callable, Dict, List, Tuple

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageNet

from virtex.data import transforms as T


class ImageNetDataset(ImageNet):
    r"""
    A simple wrapper over torchvision's ImageNet dataset with support for
    restricting the dataset size for the semi-supervised learning setup
    (data-efficiency ablations).

    We also handle the image transform here instead of passing it to the
    super class.

    Parameters
    ----------
    data_root: str, optional (default = "datasets/imagenet")
        Path to the dataset root directory. This must contain directories
        ``train``, ``val`` with per-category sub-directories.
    split: str, optional (default = "train")
        Which split to read from. One of ``{"train", "val"}``.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`,
        to be applied to the image.
    percentage: float, optional (default = 100)
        Percentage of the dataset to keep. The first K% of images per class
        are retained, so the class label distribution stays roughly the same.
        This is 100% by default, and is ignored if ``split`` is ``val``.
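
    Examples
    --------
    A minimal usage sketch, assuming ImageNet is extracted at the default
    ``data_root`` (the 10% subset below is only illustrative):

    >>> dataset = ImageNetDataset(split="train", percentage=10)
    >>> loader = torch.utils.data.DataLoader(
    ...     dataset, batch_size=32, collate_fn=dataset.collate_fn
    ... )
    >>> batch = next(iter(loader))  # {"image": float CHW batch, "label": long}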
    """

    def __init__(
        self,
        data_root: str = "datasets/imagenet",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        percentage: float = 100,
    ):
        super().__init__(data_root, split)
        assert percentage > 0, "Cannot load dataset with 0 percent original size."

        self.image_transform = image_transform

        # The super class has `imgs` and `targets` lists. Make a dict mapping
        # class ID to instance indices in these lists, then pick the first K%.
        if split == "train" and percentage < 100:
            label_to_indices: Dict[int, List[int]] = defaultdict(list)
            for index, target in enumerate(self.targets):
                label_to_indices[target].append(index)

            # Trim list of indices per label.
            for label in label_to_indices:
                retain = int(len(label_to_indices[label]) * (percentage / 100))
                label_to_indices[label] = label_to_indices[label][:retain]

            # Trim `self.imgs` and `self.targets` as per indices we have.
            retained_indices: List[int] = [
                index
                for indices_per_label in label_to_indices.values()
                for index in indices_per_label
            ]
            # Shorter dataset with size K% of the original, but almost the
            # same class label distribution. The super class handles the rest.
            self.imgs = [self.imgs[i] for i in retained_indices]
            self.targets = [self.targets[i] for i in retained_indices]
            self.samples = self.imgs

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        image, label = super().__getitem__(idx)

        # Apply transformation to image and convert to CHW format.
        image = self.image_transform(image=np.array(image))["image"]
        image = np.transpose(image, (2, 0, 1))
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }


class INaturalist2018Dataset(Dataset):
    r"""
    A dataset which provides image-label pairs from the iNaturalist 2018 dataset.

    Parameters
    ----------
    data_root: str, optional (default = "datasets/inaturalist")
        Path to the dataset root directory. This must contain images and
        annotations (``train2018``, ``val2018`` and ``annotations`` directories).
    split: str, optional (default = "train")
        Which split to read from. One of ``{"train", "val"}``.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`,
        to be applied to the image.
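
    Examples
    --------
    A minimal usage sketch, assuming the images and annotations are downloaded
    at the default ``data_root``:

    >>> dataset = INaturalist2018Dataset(split="val")
    >>> loader = torch.utils.data.DataLoader(
    ...     dataset, batch_size=32, collate_fn=dataset.collate_fn
    ... )
    >>> batch = next(iter(loader))  # {"image": float CHW batch, "label": long}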
    """

    def __init__(
        self,
        data_root: str = "datasets/inaturalist",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.split = split
        self.image_transform = image_transform

        annotations = json.load(
            open(os.path.join(data_root, "annotations", f"{split}2018.json"))
        )
        # Make a map from image ID to file path.
        self.image_id_to_file_path = {
            ann["id"]: os.path.join(data_root, ann["file_name"])
            for ann in annotations["images"]
        }
        # Make a list of instances: (image_id, category_id) tuples.
        self.instances = [
            (ann["image_id"], ann["category_id"])
            for ann in annotations["annotations"]
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_id, label = self.instances[idx]
        image_path = self.image_id_to_file_path[image_id]

        # Read the image, apply the transformation, and convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }


class VOC07ClassificationDataset(Dataset):
    r"""
    A dataset which provides image-label pairs from the PASCAL VOC 2007 dataset.

    Parameters
    ----------
    data_root: str, optional (default = "datasets/VOC2007")
        Path to the dataset root directory. This must contain directories
        ``Annotations``, ``ImageSets`` and ``JPEGImages``.
    split: str, optional (default = "trainval")
        Which split to read from. One of ``{"trainval", "test"}``.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`,
        to be applied to the image.
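
    Examples
    --------
    A minimal usage sketch, assuming VOC 2007 is extracted at the default
    ``data_root``. Note that ``label`` is a multi-label target of shape
    ``(20, )`` with values in ``{-1, 0, 1}`` (ignore / absent / present):

    >>> dataset = VOC07ClassificationDataset(split="trainval")
    >>> instance = dataset[0]
    >>> instance["label"].shape
    torch.Size([20])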
    """

    def __init__(
        self,
        data_root: str = "datasets/VOC2007",
        split: str = "trainval",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.split = split
        self.image_transform = image_transform

        ann_paths = sorted(
            glob.glob(os.path.join(data_root, "ImageSets", "Main", f"*_{split}.txt"))
        )
        # A list like: ["aeroplane", "bicycle", "bird", ...]
        self.class_names = [
            os.path.basename(path).split("_")[0] for path in ann_paths
        ]

        # Construct a map from image name to a label tensor of shape
        # (num_classes, ) with values in {-1, 0, 1}, following the train
        # target convention below: 1: present, 0: not present, -1: ignore.
        image_names_to_labels: Dict[str, torch.Tensor] = defaultdict(
            lambda: -torch.ones(len(self.class_names), dtype=torch.int32)
        )
        for cls_num, ann_path in enumerate(ann_paths):
            with open(ann_path, "r") as fopen:
                for line in fopen:
                    img_name, orig_label_str = line.strip().split()
                    orig_label = int(orig_label_str)

                    # In VOC data, -1 (not present): set to 0 as train target
                    # In VOC data, 0 (ignore): set to -1 as train target.
                    orig_label = (
                        0 if orig_label == -1 else -1 if orig_label == 0 else 1
                    )
                    image_names_to_labels[img_name][cls_num] = orig_label

        # Convert the dict to a list of tuples for easy indexing.
        # Replace image name with full image path.
        self.instances: List[Tuple[str, List[int]]] = [
            (
                os.path.join(data_root, "JPEGImages", f"{image_name}.jpg"),
                label.tolist(),
            )
            for image_name, label in image_names_to_labels.items()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_path, label = self.instances[idx]

        # Read the image, apply the transformation, and convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }


class ImageDirectoryDataset(Dataset):
    r"""
    A dataset which reads images from any directory. This class is useful for
    running image captioning inference with our models on arbitrary images.

    Parameters
    ----------
    data_root: str
        Path to a directory containing images. Every file in this directory
        is assumed to be a readable image.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`,
        to be applied to the image.
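
    Examples
    --------
    A minimal usage sketch; ``photos/`` is a hypothetical directory containing
    only readable image files:

    >>> dataset = ImageDirectoryDataset("photos")
    >>> instance = dataset[0]  # {"image_id": str, "image": (3, H, W) tensor}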
    """

    def __init__(
        self, data_root: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM
    ):
        self.image_paths = glob.glob(os.path.join(data_root, "*"))
        self.image_transform = image_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        image_path = self.image_paths[idx]
        # Remove extension from image name to use as image_id.
        image_id = os.path.splitext(os.path.basename(image_path))[0]

        # Read the image, apply the transformation, and convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        # Return image_id as a string so the collate_fn does not cast it to a
        # tensor.
        return {"image_id": str(image_id), "image": torch.tensor(image)}