File size: 1,436 Bytes
07e1105 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
import torchvision
from utils.dataset import folders
from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip
class Data_Loader(object):
    """Build the dataset and ``DataLoader`` for one IQA database split.

    The database is chosen by ``config.dataset``; the matching dataset class
    is looked up by name in ``utils.dataset.folders`` (see ``_DATASETS``).
    """

    # ``config.dataset`` key -> class name inside ``folders``. Class *names*
    # are stored (not the classes) so building this table never touches
    # ``folders`` at class-definition time.
    _DATASETS = {
        'livec': 'LIVEC',
        'koniq10k': 'Koniq10k',
        'bid': 'BID',
        'spaq': 'SPAQ',
    }

    def __init__(self, config, path, img_indx, istrain=True):
        """Create the dataset for the selected database.

        Args:
            config: object exposing ``batch_size`` and ``dataset`` attributes.
            path: passed to the dataset class as ``root``.
            img_indx: passed to the dataset class as ``index``.
            istrain: if True, add random horizontal flipping to the transform
                pipeline and shuffle batches in ``get_data``.

        Raises:
            ValueError: if ``config.dataset`` is not a key of ``_DATASETS``.
        """
        self.batch_size = config.batch_size
        self.istrain = istrain
        dataset = config.dataset

        # Validate first so an unsupported name fails fast, before any
        # transform objects are built.
        if dataset not in self._DATASETS:
            raise ValueError(
                "Only support %s (got %r)."
                % (", ".join(self._DATASETS), dataset)
            )

        # Same pipeline for both splits; the flip augmentation is inserted
        # between Normalize and ToTensor only for the training split.
        steps = [Normalize(0.5, 0.5)]
        if istrain:
            steps.append(RandHorizontalFlip(prob_aug=0.5))
        steps.append(ToTensor())
        transforms = torchvision.transforms.Compose(steps)

        dataset_cls = getattr(folders, self._DATASETS[dataset])
        self.data = dataset_cls(root=path, index=img_indx, transform=transforms)

    def get_data(self, num_workers=8):
        """Return a ``torch.utils.data.DataLoader`` over ``self.data``.

        Batches contain ``self.batch_size`` samples and are shuffled only
        when this loader was built for training (``istrain=True``).

        Args:
            num_workers: number of loader worker processes; defaults to 8,
                the value previously hard-coded here.
        """
        return torch.utils.data.DataLoader(
            self.data,
            batch_size=self.batch_size,
            shuffle=self.istrain,
            num_workers=num_workers,
        )