# Provenance note: extracted from a LINC-BIT Hugging Face upload
# ("Upload 1912 files", commit b84549f). Web-page residue removed so the
# module parses as Python.
from curses import raw
from .data_augment import TrainTransform, ValTransform
from .datasets.coco import COCODataset
from .datasets.mm_coco import MM_COCODataset
from .datasets.mosaicdetection import MosaicDetection
from utils.common.others import HiddenPrints
import os
import json
from tqdm import tqdm
from utils.common.log import logger
from .norm_categories_index import ensure_index_start_from_1_and_successive
def get_default_yolox_coco_dataset(data_dir, json_file_path, img_size=416, train=True):
    """Build a YOLOX-style COCO dataset.

    In training mode the base ``COCODataset`` is wrapped in ``MosaicDetection``
    with the standard mosaic/mixup augmentation pipeline; in eval mode a plain
    ``COCODataset`` with ``ValTransform`` is returned.

    Args:
        data_dir: root directory of the images.
        json_file_path: path to the COCO annotation json.
        img_size: square input resolution (default 416).
        train: True for the augmented training dataset, False for eval.

    Returns:
        A ``MosaicDetection`` (train) or ``COCODataset`` (eval) instance.
    """
    logger.info(f'[get yolox dataset] "{json_file_path}"')

    if not train:
        # Evaluation path: no augmentation, just resize/normalize.
        with HiddenPrints():
            return COCODataset(
                data_dir=data_dir,
                json_file=json_file_path,
                name='',
                img_size=(img_size, img_size),
                preproc=ValTransform(legacy=False),
            )

    # Training path: base dataset + mosaic/mixup wrapper.
    # HiddenPrints suppresses pycocotools' verbose loading output.
    with HiddenPrints():
        base_dataset = COCODataset(
            data_dir=data_dir,
            json_file=json_file_path,
            name='',
            img_size=(img_size, img_size),
            preproc=TrainTransform(
                max_labels=50,
                flip_prob=0.5,
                hsv_prob=1.0),
            cache=False,
        )
        return MosaicDetection(
            base_dataset,
            mosaic=True,
            img_size=(img_size, img_size),
            # Mosaic stitches 4 images, so allow more labels per sample.
            preproc=TrainTransform(
                max_labels=120,
                flip_prob=0.5,
                hsv_prob=1.0),
            degrees=10.0,
            translate=0.1,
            mosaic_scale=(0.1, 2),
            mixup_scale=(0.5, 1.5),
            shear=2.0,
            enable_mixup=True,
            mosaic_prob=1.0,
            mixup_prob=1.0,
            only_return_xy=True
        )
def get_yolox_coco_dataset_with_caption(data_dir, json_file_path, img_size=416, transform=None, train=True, classes=None):
    """Build a YOLOX COCO dataset wrapped in ``MM_COCODataset``.

    Same construction as :func:`get_default_yolox_coco_dataset`, but the
    result is additionally wrapped in ``MM_COCODataset`` (multi-modal view,
    e.g. for captions), which receives the underlying pycocotools ``coco``
    handle and the class list.

    Args:
        data_dir: root directory of the images.
        json_file_path: path to the COCO annotation json.
        img_size: square input resolution (default 416).
        transform: extra transform forwarded to ``MM_COCODataset``.
        train: True for the augmented training dataset, False for eval.
        classes: class names forwarded to ``MM_COCODataset``.

    Returns:
        An ``MM_COCODataset`` instance.
    """
    logger.info(f'[get yolox dataset] "{json_file_path}"')

    if not train:
        # Evaluation path: plain dataset, then the multi-modal wrapper.
        with HiddenPrints():
            base_dataset = COCODataset(
                data_dir=data_dir,
                json_file=json_file_path,
                name='',
                img_size=(img_size, img_size),
                preproc=ValTransform(legacy=False),
            )
            return MM_COCODataset(base_dataset, transform=transform, split='val',
                                  coco=base_dataset.coco, classes=classes)

    with HiddenPrints():
        base_dataset = COCODataset(
            data_dir=data_dir,
            json_file=json_file_path,
            name='',
            img_size=(img_size, img_size),
            preproc=TrainTransform(
                max_labels=50,
                flip_prob=0.5,
                hsv_prob=1.0),
            cache=False,
        )
        # Grab the pycocotools handle before the mosaic wrapper hides it.
        coco_handle = base_dataset.coco
        mosaic_dataset = MosaicDetection(
            base_dataset,
            mosaic=True,
            img_size=(img_size, img_size),
            # Mosaic stitches 4 images, so allow more labels per sample.
            preproc=TrainTransform(
                max_labels=120,
                flip_prob=0.5,
                hsv_prob=1.0),
            degrees=10.0,
            translate=0.1,
            mosaic_scale=(0.1, 2),
            mixup_scale=(0.5, 1.5),
            shear=2.0,
            enable_mixup=True,
            mosaic_prob=1.0,
            mixup_prob=1.0,
            only_return_xy=True
        )
        return MM_COCODataset(mosaic_dataset, transform=transform, split='train',
                              coco=coco_handle, classes=classes)
import hashlib
def _hash(o):
if isinstance(o, list):
o = sorted(o)
elif isinstance(o, dict):
o = {k: o[k] for k in sorted(o)}
elif isinstance(o, set):
o = sorted(list(o))
# else:
# print(type(o))
obj = hashlib.md5()
obj.update(str(o).encode('utf-8'))
return obj.hexdigest()
DEBUG = True
def remap_dataset(json_file_path, ignore_classes, category_idx_map):
# k and v in category_idx_map indicates the index of categories, not 'id' of categories
ignore_classes = sorted(list(ignore_classes))
# print(ignore_classes, category_idx_map)
if len(ignore_classes) == 0 and category_idx_map is None:
return json_file_path
# hash_str = '_'.join(ignore_classes) + str(category_idx_map)
hash_str = _hash(f'yolox_dataset_cache_{_hash(ignore_classes)}_{_hash(category_idx_map)}')
cached_json_file_path = f'{json_file_path}.{hash(hash_str)}'
# TODO:
if os.path.exists(cached_json_file_path):
if DEBUG:
os.remove(cached_json_file_path)
else:
logger.info(f'get cached dataset in {cached_json_file_path}')
return cached_json_file_path
with open(json_file_path, 'r') as f:
raw_ann = json.load(f)
id_to_idx_map = {c['id']: i for i, c in enumerate(raw_ann['categories'])}
ignore_classes_id = [c['id'] for c in raw_ann['categories'] if c['name'] in ignore_classes]
raw_ann['categories'] = [c for c in raw_ann['categories'] if c['id'] not in ignore_classes_id]
raw_ann['annotations'] = [ann for ann in raw_ann['annotations'] if ann['category_id'] not in ignore_classes_id]
ann_img_map = {ann['image_id']: 1 for ann in raw_ann['annotations']}
raw_ann['images'] = [img for img in raw_ann['images'] if img['id'] in ann_img_map.keys()]
# print(category_idx_map, id_to_idx_map)
# NOTE: category idx starts from 0 or 1? 1
# NOTE: reshuffle "categories"
new_categories = [{"id": i, "name": f"dummy-{i}"} for i in range(int(os.environ['_ZQL_NUMC']))]
for c in raw_ann['categories']:
# print(c)
# print(id_to_idx_map, c['id'], category_idx_map)
new_idx = category_idx_map[id_to_idx_map[c['id']]]
new_categories[new_idx] = c
c['id'] = new_idx
raw_ann['categories'] = new_categories
for ann in raw_ann['annotations']:
ann['category_id'] = category_idx_map[id_to_idx_map[ann['category_id']]]
if 'segmentation' in ann:
del ann['segmentation']
with open(cached_json_file_path, 'w') as f:
json.dump(raw_ann, f)
return cached_json_file_path
def coco_split(ann_json_file_path, ratio=0.8):
    """Randomly split a COCO annotation file into two image-disjoint parts.

    The image list is shuffled, the first ``ratio`` fraction goes to split 1
    and the remainder to split 2; each split keeps only the annotations of its
    own images. Results are written next to the input as
    ``<path>.<ratio>.split1`` / ``.split2`` and reused on later calls unless
    DEBUG is set. The shuffled image order is also persisted via ``torch.save``
    so a past shuffle could be inspected/replayed.

    Returns:
        (split1_path, split2_path) tuple of annotation file paths.
    """
    split1_path = ann_json_file_path + f'.{ratio}.split1'
    split2_path = ann_json_file_path + f'.{ratio}.split2'
    if os.path.exists(split1_path) and not DEBUG:
        return split1_path, split2_path

    with open(ann_json_file_path, 'r') as f:
        raw_ann = json.load(f)

    import copy
    import torch
    res_ann1 = copy.deepcopy(raw_ann)
    res_ann2 = copy.deepcopy(raw_ann)

    images = raw_ann['images']
    cache_images_path = ann_json_file_path + '.tmp-cached-shuffled-images'
    import random
    random.shuffle(images)
    # Persist the shuffled order for potential replay/debugging.
    torch.save(images, cache_images_path)

    cut = int(len(images) * ratio)
    images1, images2 = images[:cut], images[cut:]
    # Dicts (not lists) for O(1) membership tests below.
    ids1 = {img['id']: 0 for img in images1}
    ids2 = {img['id']: 0 for img in images2}

    res_ann1['images'] = images1
    res_ann1['annotations'] = [a for a in raw_ann['annotations'] if a['image_id'] in ids1.keys()]
    res_ann2['images'] = images2
    res_ann2['annotations'] = [a for a in raw_ann['annotations'] if a['image_id'] in ids2.keys()]

    from utils.common.data_record import write_json
    write_json(split1_path, res_ann1, indent=0, backup=False)
    write_json(split2_path, res_ann2, indent=0, backup=False)
    return split1_path, split2_path
def coco_train_val_test_split(ann_json_file_path, split):
    """Return the annotation path for the requested split.

    First split carves off the test set; the remainder is split again into
    train and val. ``split`` is one of 'train', 'val', 'test'.
    """
    trainval_p, test_p = coco_split(ann_json_file_path)
    if split == 'test':
        return test_p
    train_p, val_p = coco_split(trainval_p)
    if split == 'train':
        return train_p
    return val_p
def coco_train_val_split(train_ann_p, split):
    """Split a training annotation file and return the 'train' or 'val' part."""
    new_train_p, val_p = coco_split(train_ann_p)
    if split == 'train':
        return new_train_p
    return val_p
def visualize_coco_dataset(dataset, num_images, res_save_p, cxcy):
    """Render ``num_images`` random samples with their boxes to ``res_save_p``.

    Args:
        dataset: a wrapped COCO dataset (``dataset.dataset`` is the inner
            dataset, possibly a MosaicDetection around the raw COCODataset).
        num_images: number of random samples to draw.
        res_save_p: output image path (saved via matplotlib at 300 dpi).
        cxcy: if True, label boxes are (cx, cy, w, h) and are converted to
            corner coordinates before drawing.
    """
    from torchvision.transforms import ToTensor
    from torchvision.utils import make_grid
    from PIL import Image, ImageDraw
    import matplotlib.pyplot as plt
    import numpy as np
    import random

    # Unwrap to the raw COCODataset to reach class_ids / the coco handle.
    inner = dataset.dataset
    if inner.__class__.__name__ == 'MosaicDetection':
        inner = inner._dataset
    class_ids = inner.class_ids  # position -> COCO category_id

    def category_name(label):
        return inner.coco.loadCats(class_ids[int(label)])[0]['name']

    def to_corners(box):
        cx, cy, w, h = box
        return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

    def with_bbox(arr, box, text):
        pil_img = Image.fromarray(arr)
        painter = ImageDraw.Draw(pil_img)
        painter.rectangle(box, outline=(255, 0, 0), width=6)
        painter.text((box[0], box[1]), text)
        return np.array(pil_img)

    rendered = []
    for _ in range(num_images):
        img, labels = dataset[random.randint(0, len(dataset) - 1)][:2]
        img = np.uint8(img.transpose(1, 2, 0))  # CHW -> HWC for PIL
        for info in labels:
            if sum(info[1:]) == 0:  # all-zero box marks the padding labels
                break
            label, box = info[0], info[1:]
            if cxcy:
                box = to_corners(box)
            img = with_bbox(img, box, str(label) + '-' + category_name(label))
        rendered.append(img)

    tensors = [ToTensor()(im) for im in rendered]
    grid = make_grid(tensors, normalize=True, nrow=2)
    plt.axis('off')
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.savefig(res_save_p, dpi=300)
    plt.clf()