# GMC-IQA / utils/dataset/data_loader.py
# Zevin2023 — MoC-IQA (07e1105)
import torch
import torchvision
from utils.dataset import folders
from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip
class Data_Loader(object):
    """Build an IQA dataset from ``folders`` and wrap it in a PyTorch DataLoader.

    The dataset class is selected by ``config.dataset``. Every sample goes
    through ``Normalize(0.5, 0.5)`` then ``ToTensor()``; in training mode a
    ``RandHorizontalFlip(prob_aug=0.5)`` is inserted between the two.

    Args:
        config: Object providing ``batch_size`` and ``dataset`` attributes.
            An optional ``num_workers`` attribute overrides the default of 8
            DataLoader worker processes.
        path: Passed to the dataset class as ``root``.
        img_indx: Passed to the dataset class as ``index`` (presumably the
            image indices belonging to this split — confirm in ``folders``).
        istrain: Training mode; enables the random flip and shuffling.

    Raises:
        ValueError: If ``config.dataset`` is not a supported dataset name.
    """

    # config.dataset value -> class name in utils.dataset.folders. Storing the
    # names (not the classes) lets the name check run before ``folders`` is
    # touched, and keeps the list of supported datasets in one place.
    _DATASET_CLASSES = {
        'livec': 'LIVEC',
        'koniq10k': 'Koniq10k',
        'bid': 'BID',
        'spaq': 'SPAQ',
    }

    def __init__(self, config, path, img_indx, istrain=True):
        self.batch_size = config.batch_size
        self.istrain = istrain
        # Formerly hard-coded to 8 in get_data(); 8 remains the default.
        self.num_workers = getattr(config, 'num_workers', 8)

        dataset = config.dataset
        # Fail fast on an unknown name, before building transforms or datasets.
        if dataset not in self._DATASET_CLASSES:
            raise ValueError(
                f"Only support livec, koniq10k, bid, spaq (got {dataset!r})."
            )

        steps = [Normalize(0.5, 0.5)]
        if istrain:
            # Augmentation is applied only at training time.
            steps.append(RandHorizontalFlip(prob_aug=0.5))
        steps.append(ToTensor())
        transform = torchvision.transforms.Compose(steps)

        dataset_cls = getattr(folders, self._DATASET_CLASSES[dataset])
        self.data = dataset_cls(root=path, index=img_indx, transform=transform)

    def get_data(self):
        """Return a ``torch.utils.data.DataLoader`` over ``self.data``.

        Batches are shuffled only in training mode; ``self.num_workers``
        worker processes are used.
        """
        return torch.utils.data.DataLoader(
            self.data,
            batch_size=self.batch_size,
            shuffle=self.istrain,
            num_workers=self.num_workers,
        )