#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import bisect
import copy
import os
import random
from abc import ABCMeta, abstractmethod
from functools import partial, wraps
from multiprocessing.pool import ThreadPool

import numpy as np
import psutil
from loguru import logger
from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
from torch.utils.data.dataset import Dataset as torchDataset
from tqdm import tqdm
class ConcatDataset(torchConcatDataset):
    """Concatenation of datasets that also forwards ``pull_item``.

    Behaves like :class:`torch.utils.data.ConcatDataset`, but mirrors the
    first child's ``input_dim`` (when that child exposes one) and routes
    ``pull_item`` calls to the child dataset owning the global index.
    """

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        first = self.datasets[0]
        if hasattr(first, "input_dim"):
            self._input_dim = first.input_dim
            self.input_dim = first.input_dim

    def pull_item(self, idx):
        """Fetch the raw sample at global index ``idx`` from the owning child."""
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx += len(self)
        # cumulative_sizes is maintained by the torch base class.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = 0 if dataset_idx == 0 else self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].pull_item(idx - offset)
class MixConcatDataset(torchConcatDataset):
    """Concatenation of mix-style datasets.

    Mix-style callers index this dataset with a ``(mosaic_flag, idx, extra)``
    tuple; the middle element is the global sample index. Plain ``int``
    indices are also accepted.
    """

    def __init__(self, datasets):
        super(MixConcatDataset, self).__init__(datasets)
        first = self.datasets[0]
        if hasattr(first, "input_dim"):
            self._input_dim = first.input_dim
            self.input_dim = first.input_dim

    def __getitem__(self, index):
        # FIX: the original left `idx` unbound (NameError) for int indices and
        # never remapped an int index to the child-local index.
        idx = index if isinstance(index, int) else index[1]
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Rebuild the index with the child-local position before delegating.
        if isinstance(index, int):
            index = sample_idx
        else:
            index = (index[0], sample_idx, index[2])
        return self.datasets[dataset_idx][index]
class Dataset(torchDataset):
    """ This class is a subclass of the base :class:`torch.utils.data.Dataset`,
    that enables on the fly resizing of the ``input_dim``.

    Args:
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
        mosaic (bool): whether mosaic augmentation starts enabled
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        # Keep only (width, height); callers may pass a longer tuple.
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.
        This allows transforms to have a single source of truth
        for the input dimension of the network.

        Return:
            list: Tuple containing the current width,height
        """
        # FIX: restored the @property decorator — the hasattr fallback pattern
        # and the docstring only make sense for attribute-style access.
        # A dynamically assigned `_input_dim` (e.g. multi-scale training)
        # overrides the construction-time default.
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def mosaic_getitem(getitem_fn):
        """
        Decorator method that needs to be used around the ``__getitem__`` method. |br|
        This decorator enables the closing mosaic

        Example:
            >>> class CustomSet(ln.data.Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @ln.data.Dataset.mosaic_getitem
            ...     def __getitem__(self, index):
            ...         return self.enable_mosaic
        """

        # FIX: @staticmethod (the decorator takes no `self`) and @wraps so the
        # wrapped __getitem__ keeps its metadata.
        @wraps(getitem_fn)
        def wrapper(self, index):
            if not isinstance(index, int):
                # index is (enable_mosaic_flag, real_index): record the flag.
                self.enable_mosaic = index[0]
                index = index[1]
            ret_val = getitem_fn(self, index)
            return ret_val

        return wrapper
class CacheDataset(Dataset, metaclass=ABCMeta):
    """ This class is a subclass of the base :class:`yolox.data.datasets.Dataset`,
    that enables cache images to ram or disk.

    Args:
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
        num_imgs (int): dataset size
        data_dir (str): the root directory of the dataset, e.g. `/path/to/COCO`.
        cache_dir_name (str): the name of the directory to cache to disk,
            e.g. `"custom_cache"`. The files cached to disk will be saved
            under `/path/to/COCO/custom_cache`.
        path_filename (str): a list of paths to the data relative to the `data_dir`,
            e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`,
            then `path_filename = ['train/1.jpg', ' train/2.jpg']`.
        cache (bool): whether to cache the images to ram or disk.
        cache_type (str): the type of cache,
            "ram" : Caching imgs to ram for fast training.
            "disk": Caching imgs to disk for fast training.
    """

    def __init__(
        self,
        input_dimension,
        num_imgs=None,
        data_dir=None,
        cache_dir_name=None,
        path_filename=None,
        cache=False,
        cache_type="ram",
    ):
        super().__init__(input_dimension)
        self.cache = cache
        self.cache_type = cache_type
        # FIX: always define `imgs`; previously only the "ram" branch created
        # it, so disk caching hit `self.imgs` in cache_images() and raised
        # AttributeError.
        self.imgs = None

        if self.cache and self.cache_type == "disk":
            self.cache_dir = os.path.join(data_dir, cache_dir_name)
            self.path_filename = path_filename

        if self.cache:
            self.cache_images(
                num_imgs=num_imgs,
                data_dir=data_dir,
                cache_dir_name=cache_dir_name,
                path_filename=path_filename,
            )

    def __del__(self):
        # Release the RAM cache eagerly. getattr/hasattr guard against a
        # partially constructed instance (__init__ raised before attributes
        # were set), since __del__ must never raise.
        if getattr(self, "cache", False) and self.cache_type == "ram":
            if hasattr(self, "imgs"):
                del self.imgs

    def read_img(self, index):
        """
        Given index, return the corresponding image

        Args:
            index (int): image index
        """
        raise NotImplementedError

    def cache_images(
        self,
        num_imgs=None,
        data_dir=None,
        cache_dir_name=None,
        path_filename=None,
    ):
        """Cache decoded images to RAM or disk.

        May fall back to uncached reads (``self.cache = False``) when there is
        not enough free RAM; returns early if a disk cache already exists.
        """
        assert num_imgs is not None, "num_imgs must be specified as the size of the dataset"
        if self.cache_type == "disk":
            # FIX: the original `(a and b and c) is not None` passed for falsy
            # non-None values (e.g. ""); check each argument explicitly.
            assert data_dir is not None and cache_dir_name is not None \
                and path_filename is not None, \
                "data_dir, cache_name and path_filename must be specified if cache_type is disk"
            self.path_filename = path_filename

        mem = psutil.virtual_memory()
        mem_required = self.cal_cache_occupy(num_imgs)
        gb = 1 << 30

        if self.cache_type == "ram":
            if mem_required > mem.available:
                # Not enough free RAM: disable caching and read on demand.
                self.cache = False
            else:
                logger.info(
                    f"{mem_required / gb:.1f}GB RAM required, "
                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
                    f"Since the first thing we do is cache, "
                    f"there is no guarantee that the remaining memory space is sufficient"
                )

        if self.cache and self.imgs is None:
            if self.cache_type == 'ram':
                self.imgs = [None] * num_imgs
                logger.info("You are using cached images in RAM to accelerate training!")
            else:   # 'disk'
                if not os.path.exists(self.cache_dir):
                    os.makedirs(self.cache_dir, exist_ok=True)
                    logger.warning(
                        f"\n*******************************************************************\n"
                        f"You are using cached images in DISK to accelerate training.\n"
                        f"This requires large DISK space.\n"
                        f"Make sure you have {mem_required / gb:.1f} "
                        f"available DISK space for training your dataset.\n"
                        f"*******************************************************************\n"
                    )
                else:
                    # Cache files already exist; nothing to do.
                    logger.info(f"Found disk cache at {self.cache_dir}")
                    return

            logger.info(
                "Caching images...\n"
                "This might take some time for your dataset"
            )

            # FIX: os.cpu_count() may return None; fall back to one worker.
            num_threads = min(8, max(1, (os.cpu_count() or 2) - 1))
            b = 0
            # FIX: context manager ensures the pool's threads are released.
            with ThreadPool(num_threads) as pool:
                load_imgs = pool.imap(
                    partial(self.read_img, use_cache=False),
                    range(num_imgs)
                )
                pbar = tqdm(enumerate(load_imgs), total=num_imgs)
                for i, x in pbar:   # x = self.read_img(self, i, use_cache=False)
                    if self.cache_type == 'ram':
                        self.imgs[i] = x
                    else:   # 'disk'
                        cache_filename = f'{self.path_filename[i].split(".")[0]}.npy'
                        cache_path_filename = os.path.join(self.cache_dir, cache_filename)
                        os.makedirs(os.path.dirname(cache_path_filename), exist_ok=True)
                        np.save(cache_path_filename, x)
                    b += x.nbytes
                    pbar.desc = \
                        f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache_type})'
                pbar.close()

    def cal_cache_occupy(self, num_imgs):
        """Estimate the total cache size in bytes by sampling up to 32 images."""
        cache_bytes = 0
        num_samples = min(num_imgs, 32)
        for _ in range(num_samples):
            img = self.read_img(index=random.randint(0, num_imgs - 1), use_cache=False)
            cache_bytes += img.nbytes
        # Extrapolate the sampled average to the whole dataset.
        mem_required = cache_bytes * num_imgs / num_samples
        return mem_required
def cache_read_img(use_cache=True):
    """Decorator factory for a dataset's ``read_img`` method.

    Args:
        use_cache (bool, optional): default for the decorated method's
            ``use_cache`` keyword. When True and the instance has caching
            enabled, images are served from the RAM/disk cache instead of
            being re-read. Defaults to True.
    """

    def decorator(read_img_fn):
        """
        Decorate the read_img function to cache the image

        Args:
            read_img_fn: read_img function
            use_cache (bool, optional): For the decorated read_img function,
                whether to read the image from cache.
                Defaults to True.
        """
        # FIX: @wraps preserves the decorated method's name/docstring,
        # matching the file's convention (wraps is imported at the top).
        @wraps(read_img_fn)
        def wrapper(self, index, use_cache=use_cache):
            cache = self.cache and use_cache
            if cache:
                if self.cache_type == "ram":
                    img = self.imgs[index]
                    # deepcopy so callers may mutate the result without
                    # corrupting the RAM cache.
                    img = copy.deepcopy(img)
                elif self.cache_type == "disk":
                    img = np.load(
                        os.path.join(
                            self.cache_dir, f"{self.path_filename[index].split('.')[0]}.npy"))
                else:
                    raise ValueError(f"Unknown cache type: {self.cache_type}")
            else:
                img = read_img_fn(self, index)
            return img
        return wrapper
    return decorator