chendl's picture
Add application file
0b7b08a
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import os
import random
import uuid
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader as torchDataLoader
from torch.utils.data.dataloader import default_collate
from .samplers import YoloBatchSampler
def get_yolox_datadir():
    """
    Return the dataset directory used by YOLOX.

    If the environment variable ``YOLOX_DATADIR`` is set (even to an empty
    string), its value is returned unchanged. Otherwise the path
    ``<parent of the installed yolox package directory>/datasets`` is returned.
    """
    env_dir = os.getenv("YOLOX_DATADIR")
    if env_dir is not None:
        return env_dir

    # Imported lazily so the package is only needed when no override is set.
    import yolox

    package_root = os.path.dirname(os.path.dirname(yolox.__file__))
    return os.path.join(package_root, "datasets")
class DataLoader(torchDataLoader):
    """
    Lightnet dataloader that enables on the fly resizing of the images.

    After :class:`torch.utils.data.DataLoader` has been constructed, its batch
    sampler is replaced by a :class:`YoloBatchSampler` built from the resolved
    sampler, ``batch_size``, ``drop_last`` and the dataset's ``input_dim``.
    A caller-supplied ``batch_sampler`` (positional or keyword) is kept as-is.

    See :class:`torch.utils.data.DataLoader` for more information on the arguments.
    Check more on the following website:
    https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py
    """

    @staticmethod
    def _get_arg(args, kwargs, pos, name, default):
        """Return argument ``name`` whether passed positionally at index ``pos`` or by keyword."""
        if len(args) > pos:
            return args[pos]
        return kwargs.get(name, default)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Because this class is also named ``DataLoader``, the attribute below is
        # name-mangled to ``_DataLoader__initialized``. That is the same flag
        # torch's DataLoader.__setattr__ checks before refusing to reassign
        # ``batch_sampler``. Clearing it lets us swap in our own batch sampler.
        self.__initialized = False

        # Positions follow torch's signature:
        # (dataset, batch_size, shuffle, sampler, batch_sampler, ...).
        shuffle = self._get_arg(args, kwargs, 2, "shuffle", False)
        sampler = self._get_arg(args, kwargs, 3, "sampler", None)
        batch_sampler = self._get_arg(args, kwargs, 4, "batch_sampler", None)

        # Use custom BatchSampler unless the caller provided one.
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
                else:
                    sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
            # NOTE(review): requires the dataset to expose ``input_dim``.
            batch_sampler = YoloBatchSampler(
                sampler,
                self.batch_size,
                self.drop_last,
                input_dimension=self.dataset.input_dim,
            )

        self.batch_sampler = batch_sampler
        self.__initialized = True

    def close_mosaic(self):
        """Set ``mosaic = False`` on the active batch sampler.

        Presumably consumed by :class:`YoloBatchSampler` to turn off mosaic
        augmentation; a caller-supplied sampler simply gains the attribute.
        """
        self.batch_sampler.mosaic = False
def list_collate(batch):
    """
    Collate a batch column by column, keeping list/tuple columns as plain lists.

    The batch is transposed with ``zip``. For each resulting column, if its
    first element is a list or tuple the column is returned as a Python list
    (useful for items such as Brambox boxes). Otherwise the column is passed
    to ``default_collate``. Use this as the ``collate_fn`` of a Dataloader.

    Returns:
        list: one collated entry per field of the batch items.
    """
    return [
        list(column) if isinstance(column[0], (list, tuple)) else default_collate(column)
        for column in zip(*batch)
    ]
def worker_init_reset_seed(worker_id):
    """
    Reseed Python's ``random``, torch and NumPy with one fresh random 32-bit seed.

    The seed is derived from ``uuid.uuid4()``, so it differs on every call.
    ``worker_id`` is accepted to match the DataLoader ``worker_init_fn``
    signature but is not used.
    """
    # Keep only the low 32 bits; np.random.seed requires a value below 2**32.
    seed = uuid.uuid4().int & 0xFFFFFFFF
    random.seed(seed)
    # torch.manual_seed already resets the default CPU generator, so no extra
    # set_rng_state call is needed.
    torch.manual_seed(seed)
    np.random.seed(seed)