File size: 10,522 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
from curses import raw
from .data_augment import TrainTransform, ValTransform
from .datasets.coco import COCODataset
from .datasets.mm_coco import MM_COCODataset
from .datasets.mosaicdetection import MosaicDetection
from utils.common.others import HiddenPrints
import os
import json
from tqdm import tqdm 
from utils.common.log import logger

from .norm_categories_index import ensure_index_start_from_1_and_successive


def get_default_yolox_coco_dataset(data_dir, json_file_path, img_size=416, train=True):
    """Build the COCO-format detection dataset described by ``json_file_path``.

    Args:
        data_dir: Forwarded to ``COCODataset`` as ``data_dir``.
        json_file_path: Forwarded to ``COCODataset`` as ``json_file``; also
            written to the log.
        img_size: Side length; the datasets receive ``(img_size, img_size)``.
        train: When truthy, the ``COCODataset`` is built with
            ``TrainTransform`` and wrapped in ``MosaicDetection`` (mosaic and
            mixup both enabled with probability 1.0). Otherwise a plain
            ``COCODataset`` using ``ValTransform(legacy=False)`` is returned.

    Returns:
        A ``MosaicDetection`` instance for training, a ``COCODataset``
        instance otherwise.
    """
    logger.info(f'[get yolox dataset] "{json_file_path}"')

    input_size = (img_size, img_size)

    # NOTE(review): HiddenPrints presumably silences stdout while the
    # annotation file is loaded -- confirm against utils.common.others.
    if not train:
        with HiddenPrints():
            return COCODataset(
                data_dir=data_dir,
                json_file=json_file_path,
                name='',
                img_size=input_size,
                preproc=ValTransform(legacy=False),
            )

    with HiddenPrints():
        base_dataset = COCODataset(
            data_dir=data_dir,
            json_file=json_file_path,
            name='',
            img_size=input_size,
            preproc=TrainTransform(
                max_labels=50,
                flip_prob=0.5,
                hsv_prob=1.0),
            cache=False,
        )

    # The augmentation wrapper is created outside the HiddenPrints block.
    return MosaicDetection(
        base_dataset,
        mosaic=True,
        img_size=input_size,
        preproc=TrainTransform(
            max_labels=120,
            flip_prob=0.5,
            hsv_prob=1.0),
        degrees=10.0,
        translate=0.1,
        mosaic_scale=(0.1, 2),
        mixup_scale=(0.5, 1.5),
        shear=2.0,
        enable_mixup=True,
        mosaic_prob=1.0,
        mixup_prob=1.0,
        only_return_xy=True
    )

def get_yolox_coco_dataset_with_caption(data_dir, json_file_path, img_size=416, transform=None, train=True, classes=None):
    """Build a COCO-format detection dataset wrapped in ``MM_COCODataset``.

    Args:
        data_dir: Forwarded to ``COCODataset`` as ``data_dir``.
        json_file_path: Forwarded to ``COCODataset`` as ``json_file``; also
            written to the log.
        img_size: Side length; the datasets receive ``(img_size, img_size)``.
        transform: Forwarded to ``MM_COCODataset`` as ``transform``.
        train: When truthy, the ``COCODataset`` is built with
            ``TrainTransform``, wrapped in ``MosaicDetection`` and then in
            ``MM_COCODataset`` with ``split='train'``. Otherwise a
            ``COCODataset`` using ``ValTransform(legacy=False)`` is wrapped
            in ``MM_COCODataset`` with ``split='val'``.
        classes: Forwarded to ``MM_COCODataset`` as ``classes``.

    Returns:
        The ``MM_COCODataset`` instance.
    """
    logger.info(f'[get yolox dataset] "{json_file_path}"')

    input_size = (img_size, img_size)

    # NOTE(review): HiddenPrints presumably silences stdout while the
    # annotation file is loaded -- confirm against utils.common.others.
    if not train:
        with HiddenPrints():
            val_dataset = COCODataset(
                data_dir=data_dir,
                json_file=json_file_path,
                name='',
                img_size=input_size,
                preproc=ValTransform(legacy=False),
            )
        return MM_COCODataset(val_dataset, transform=transform, split='val', coco=val_dataset.coco, classes=classes)

    with HiddenPrints():
        base_dataset = COCODataset(
            data_dir=data_dir,
            json_file=json_file_path,
            name='',
            img_size=input_size,
            preproc=TrainTransform(
                max_labels=50,
                flip_prob=0.5,
                hsv_prob=1.0),
            cache=False,
        )

    # Read the COCO handle from the raw dataset before it is wrapped.
    coco = base_dataset.coco
    mosaic_dataset = MosaicDetection(
        base_dataset,
        mosaic=True,
        img_size=input_size,
        preproc=TrainTransform(
            max_labels=120,
            flip_prob=0.5,
            hsv_prob=1.0),
        degrees=10.0,
        translate=0.1,
        mosaic_scale=(0.1, 2),
        mixup_scale=(0.5, 1.5),
        shear=2.0,
        enable_mixup=True,
        mosaic_prob=1.0,
        mixup_prob=1.0,
        only_return_xy=True
    )
    return MM_COCODataset(mosaic_dataset, transform=transform, split='train', coco=coco, classes=classes)

import hashlib

def _hash(o):
    """Return the MD5 hex digest of ``str(o)`` after ordering ``o``.

    Lists and sets are replaced by a sorted list and dicts are rebuilt with
    their keys in sorted order, so element/key order does not influence the
    digest. Any other object is stringified as-is.
    """
    if isinstance(o, (list, set)):
        canonical = sorted(o)
    elif isinstance(o, dict):
        canonical = {key: o[key] for key in sorted(o)}
    else:
        canonical = o

    return hashlib.md5(str(canonical).encode('utf-8')).hexdigest()


# Debug switch for the on-disk caches in this module: while truthy,
# remap_dataset() deletes an existing cached annotation file and regenerates
# it, and coco_split() ignores existing split files and splits again.
DEBUG = True


def remap_dataset(json_file_path, ignore_classes, category_idx_map):
    """Write a filtered, re-indexed copy of a COCO annotation file.

    Categories whose name is in ``ignore_classes`` are dropped together with
    their annotations, images left without any annotation are dropped, and
    the remaining category ids are rewritten through ``category_idx_map``.
    The ``segmentation`` field is removed from every kept annotation.

    Args:
        json_file_path: Path of the source annotation JSON.
        ignore_classes: Iterable of category names to remove.
        category_idx_map: Mapping from a category's position in the original
            ``categories`` list to its new id. Keys and values are indices
            into the category list, not the original ``id`` fields.

    Returns:
        ``json_file_path`` unchanged when there is nothing to ignore and no
        map is given. Otherwise the path of the generated file, which is
        ``json_file_path`` followed by ``.`` and an MD5 hex digest derived
        from the two other arguments.

    Raises:
        KeyError: If the ``_ZQL_NUMC`` environment variable (the length of
            the rewritten ``categories`` list) is not set.
        TypeError: If ``category_idx_map`` is ``None`` while categories
            remain to be remapped.
    """
    # k and v in category_idx_map indicates the index of categories, not 'id' of categories
    ignore_classes = sorted(ignore_classes)

    if not ignore_classes and category_idx_map is None:
        return json_file_path

    # The MD5 digest goes into the file name directly. Wrapping it in the
    # builtin hash() made the name differ between interpreter runs (str
    # hashes are salted per process), so the cache could never be found again.
    hash_str = _hash(f'yolox_dataset_cache_{_hash(ignore_classes)}_{_hash(category_idx_map)}')
    cached_json_file_path = f'{json_file_path}.{hash_str}'

    # TODO: while DEBUG is truthy an existing cache file is discarded and rebuilt.
    if os.path.exists(cached_json_file_path):
        if DEBUG:
            os.remove(cached_json_file_path)
        else:
            logger.info(f'get cached dataset in {cached_json_file_path}')
            return cached_json_file_path

    with open(json_file_path, 'r') as f:
        raw_ann = json.load(f)
    # Position of every category in the original list, keyed by its original id.
    id_to_idx_map = {c['id']: i for i, c in enumerate(raw_ann['categories'])}

    ignored_ids = {c['id'] for c in raw_ann['categories'] if c['name'] in ignore_classes}
    raw_ann['categories'] = [c for c in raw_ann['categories'] if c['id'] not in ignored_ids]
    raw_ann['annotations'] = [ann for ann in raw_ann['annotations'] if ann['category_id'] not in ignored_ids]
    # Keep only images that still own at least one annotation.
    annotated_image_ids = {ann['image_id'] for ann in raw_ann['annotations']}
    raw_ann['images'] = [img for img in raw_ann['images'] if img['id'] in annotated_image_ids]

    # NOTE: category idx starts from 0 or 1? 1
    # NOTE: reshuffle "categories"
    # Slots not claimed by a real category keep a "dummy-<i>" placeholder.
    new_categories = [{"id": i, "name": f"dummy-{i}"} for i in range(int(os.environ['_ZQL_NUMC']))]
    for c in raw_ann['categories']:
        new_idx = category_idx_map[id_to_idx_map[c['id']]]
        new_categories[new_idx] = c
        c['id'] = new_idx
    raw_ann['categories'] = new_categories
    for ann in raw_ann['annotations']:
        ann['category_id'] = category_idx_map[id_to_idx_map[ann['category_id']]]
        if 'segmentation' in ann:
            del ann['segmentation']

    with open(cached_json_file_path, 'w') as f:
        json.dump(raw_ann, f)

    return cached_json_file_path


def coco_split(ann_json_file_path, ratio=0.8):
    """Randomly split a COCO annotation file into two files by image.

    The images are shuffled, the first ``int(len(images) * ratio)`` go to the
    first split and the rest to the second; each annotation follows its
    image. Every other top-level key of the source file is carried over to
    both outputs.

    Args:
        ann_json_file_path: Path of the source annotation JSON.
        ratio: Fraction of images assigned to the first split.

    Returns:
        A tuple ``(split1_path, split2_path)`` where the paths are
        ``ann_json_file_path`` plus ``.{ratio}.split1`` / ``.{ratio}.split2``.
        Existing files are reused only when the first one exists and the
        module-level ``DEBUG`` flag is falsy.
    """
    split1_path = ann_json_file_path + f'.{ratio}.split1'
    split2_path = ann_json_file_path + f'.{ratio}.split2'

    if os.path.exists(split1_path) and not DEBUG:
        return split1_path, split2_path

    with open(ann_json_file_path, 'r') as f:
        raw_ann = json.load(f)

    import random
    import torch

    images = raw_ann['images']

    # NOTE(review): the shuffle is unseeded here, so every regeneration can
    # produce a different partition unless the caller seeds `random`.
    random.shuffle(images)
    # The shuffled order is still dumped next to the source file as before;
    # nothing in this function reads it back.
    torch.save(images, ann_json_file_path + '.tmp-cached-shuffled-images')

    cut = int(len(images) * ratio)
    images1, images2 = images[:cut], images[cut:]
    images1_id = {img['id'] for img in images1}
    images2_id = {img['id'] for img in images2}
    ann1 = [ann for ann in raw_ann['annotations'] if ann['image_id'] in images1_id]
    ann2 = [ann for ann in raw_ann['annotations'] if ann['image_id'] in images2_id]

    # Shallow copies are enough: only 'images' and 'annotations' differ
    # between the outputs, and both are replaced with fresh lists.
    res_ann1 = {**raw_ann, 'images': images1, 'annotations': ann1}
    res_ann2 = {**raw_ann, 'images': images2, 'annotations': ann2}

    from utils.common.data_record import write_json
    write_json(split1_path, res_ann1, indent=0, backup=False)
    write_json(split2_path, res_ann2, indent=0, backup=False)

    return split1_path, split2_path


def coco_train_val_test_split(ann_json_file_path, split):
    """Return the annotation path of one of the train/val/test subsets.

    The source file is first split with ``coco_split`` into a train+val part
    and a test part; unless the test part is requested, the train+val part
    is split once more into train and val.

    Args:
        ann_json_file_path: Path of the full annotation JSON.
        split: ``'test'`` or ``'train'``; any other value selects the
            validation subset.

    Returns:
        Path of the annotation file for the requested subset.
    """
    trainval_path, test_path = coco_split(ann_json_file_path)
    if split == 'test':
        return test_path

    # Second-level split; only reached for the train and val subsets.
    train_path, val_path = coco_split(trainval_path)
    if split == 'train':
        return train_path
    return val_path


def coco_train_val_split(train_ann_p, split):
    """Return the annotation path of the train or val subset.

    Args:
        train_ann_p: Path of the annotation JSON to split with ``coco_split``.
        split: ``'train'`` selects the first part; any other value selects
            the second (validation) part.

    Returns:
        Path of the annotation file for the requested subset.
    """
    first_part, second_part = coco_split(train_ann_p)
    if split == 'train':
        return first_part
    return second_part


def visualize_coco_dataset(dataset, num_images, res_save_p, cxcy):
    """Save a grid of randomly drawn samples with their boxes painted on.

    Args:
        dataset: Indexable wrapper whose items start with ``(image, labels)``
            and which exposes the underlying detection dataset as
            ``dataset.dataset``. Images are transposed with ``(1, 2, 0)``
            before drawing; each label row is ``[class_index, *box]``.
        num_images: Number of samples to draw (sampled with replacement).
        res_save_p: Output path handed to ``plt.savefig``.
        cxcy: When truthy, boxes are read as ``(cx, cy, w, h)`` and converted
            to corner form before drawing; otherwise they are drawn as given.
    """
    from torchvision.transforms import ToTensor
    from torchvision.utils import make_grid
    from PIL import Image, ImageDraw
    import matplotlib.pyplot as plt
    import numpy as np
    import random

    # Unwrap down to the dataset that owns the COCO handle and class ids.
    inner = dataset.dataset
    if inner.__class__.__name__ == 'MosaicDetection':
        inner = inner._dataset
    class_ids = inner.class_ids  # category_id

    def category_name(label):
        return inner.coco.loadCats(class_ids[int(label)])[0]['name']

    def to_corners(box):
        cx, cy, w, h = box
        return cx - w/2, cy - h/2, cx + w/2, cy + h/2

    def annotate(image, box, caption):
        canvas = Image.fromarray(image)
        painter = ImageDraw.Draw(canvas)
        painter.rectangle(box, outline=(255, 0, 0), width=6)
        painter.text((box[0], box[1]), caption)
        return np.array(canvas)

    rendered = []
    for _ in range(num_images):
        sample_idx = random.randint(0, len(dataset) - 1)
        image, labels = dataset[sample_idx][:2]
        image = np.uint8(image.transpose(1, 2, 0))

        for row in labels:
            # A row whose box entries sum to zero is padding: stop here.
            if sum(row[1:]) == 0:
                break

            label, box = row[0], row[1:]
            if cxcy:
                box = to_corners(box)

            image = annotate(image, box, str(label) + '-' + category_name(label))

        rendered.append(image)

    tensors = [ToTensor()(image) for image in rendered]
    grid = make_grid(tensors, normalize=True, nrow=2)
    plt.axis('off')
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.savefig(res_save_p, dpi=300)
    plt.clf()