#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import os
import random
import uuid

import numpy as np

import torch
from torch.utils.data.dataloader import DataLoader as torchDataLoader
from torch.utils.data.dataloader import default_collate

from .samplers import YoloBatchSampler


def get_yolox_datadir():
    """
    Get the dataset directory of YOLOX.

    If the environment variable ``YOLOX_DATADIR`` is set, its value is returned.
    Otherwise the ``datasets`` directory that sits next to the installed
    ``yolox`` package is used, i.e. ``<parent of the yolox package dir>/datasets``.

    Returns:
        str: path of the dataset directory.
    """
    yolox_datadir = os.getenv("YOLOX_DATADIR", None)
    if yolox_datadir is None:
        # Imported lazily so the package is only needed when the env var is unset.
        import yolox

        # yolox.__file__ is .../yolox/__init__.py; two dirname() calls give the
        # directory that contains the yolox package.
        yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
        yolox_datadir = os.path.join(yolox_path, "datasets")
    return yolox_datadir


class DataLoader(torchDataLoader):
    """
    Lightnet dataloader that enables on-the-fly resizing of the images.

    Behaves like :class:`torch.utils.data.DataLoader`. The difference: when the
    caller does not supply a ``batch_sampler``, batches are drawn through a
    :class:`YoloBatchSampler`, which is given ``input_dimension=dataset.input_dim``.
    The sampler is a random or sequential sampler, chosen by ``shuffle``, unless
    a ``sampler`` is given.

    See :class:`torch.utils.data.DataLoader` for more information on the arguments.
    Check more on the following website:
    https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE: this relies on name mangling. Because this class is also named
        # ``DataLoader``, ``self.__initialized`` becomes ``_DataLoader__initialized``.
        # That appears to be the same private flag torch's DataLoader checks before
        # allowing ``batch_sampler`` to be reassigned (torch-version dependent;
        # confirm when upgrading). Clearing it lets us swap in our own batch sampler
        # below. Do not rename this class.
        self.__initialized = False

        shuffle, sampler, batch_sampler = self._parse_sampling_args(args, kwargs)

        # Use custom BatchSampler
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
                else:
                    sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
            batch_sampler = YoloBatchSampler(
                sampler,
                self.batch_size,
                self.drop_last,
                # assumes the dataset exposes `input_dim` -- TODO confirm for all datasets
                input_dimension=self.dataset.input_dim,
            )

        self.batch_sampler = batch_sampler
        self.__initialized = True

    @staticmethod
    def _parse_sampling_args(args, kwargs):
        """
        Recover ``shuffle``, ``sampler`` and ``batch_sampler`` exactly as the caller passed them.

        Follows the positional order of :class:`torch.utils.data.DataLoader`,
        ``(dataset, batch_size, shuffle, sampler, batch_sampler, ...)``. Each value
        is taken from ``args`` when given positionally, otherwise from ``kwargs``.

        Args:
            args (tuple): positional arguments given to ``__init__``.
            kwargs (dict): keyword arguments given to ``__init__``.

        Returns:
            tuple: ``(shuffle, sampler, batch_sampler)``. Values not supplied default
            to ``False``, ``None`` and ``None`` respectively.
        """
        shuffle = args[2] if len(args) > 2 else kwargs.get("shuffle", False)
        sampler = args[3] if len(args) > 3 else kwargs.get("sampler", None)
        batch_sampler = args[4] if len(args) > 4 else kwargs.get("batch_sampler", None)
        return shuffle, sampler, batch_sampler

    def close_mosaic(self):
        """
        Set ``mosaic = False`` on the batch sampler.

        Presumably this turns off mosaic augmentation for later batches. It assumes
        the batch sampler honours a ``mosaic`` attribute, as a ``YoloBatchSampler``
        presumably does (TODO confirm). On a user-supplied sampler without that
        attribute, the assignment just creates it and has no effect.
        """
        self.batch_sampler.mosaic = False


def list_collate(batch):
    """
    Collate a batch field by field, keeping list/tuple fields as plain lists.

    Use this as the collate function in a DataLoader when some items should come
    out as lists rather than tensors (e.g. Brambox boxes). For each field position,
    the field's values become a plain Python list if the first sample's value is a
    list or tuple. Otherwise they are passed to ``default_collate``.

    Args:
        batch (sequence): samples, each an indexable sequence of fields.

    Returns:
        list: one entry per field position (empty if ``batch`` is empty).
    """
    fields = list(zip(*batch))
    return [
        list(values) if isinstance(values[0], (list, tuple)) else default_collate(values)
        for values in fields
    ]


def worker_init_reset_seed(worker_id):
    """
    Reseed Python's ``random``, torch and NumPy with a fresh seed.

    Meant to be passed as ``worker_init_fn`` to a DataLoader. ``worker_id`` is
    accepted for that signature but not used. Because the seed comes from
    ``uuid4()``, every call gets a different seed.

    Args:
        worker_id (int): index of the dataloader worker (unused).
    """
    # uuid4 gives fresh randomness per call; reduced to the 32-bit range that
    # np.random.seed accepts.
    seed = uuid.uuid4().int % 2**32
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)