|
import torch |
|
import torchvision |
|
|
|
from utils.dataset import folders |
|
from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip |
|
|
|
class Data_Loader(object):
    """Dataset wrapper for IQA (image quality assessment) databases.

    Builds the dataset object for one of the supported databases with a
    transform pipeline chosen by ``istrain``, and exposes it through a
    ``torch.utils.data.DataLoader`` via :meth:`get_data`.
    """

    # Maps the ``config.dataset`` key to the class name in
    # ``utils.dataset.folders``. Names are resolved lazily (getattr) so an
    # unsupported key is rejected before any transform or dataset is built.
    _DATASETS = {
        'livec': 'LIVEC',
        'koniq10k': 'Koniq10k',
        'bid': 'BID',
        'spaq': 'SPAQ',
    }

    def __init__(self, config, path, img_indx, istrain=True):
        """Create the dataset selected by ``config.dataset``.

        Args:
            config: Object exposing ``batch_size`` and ``dataset`` attributes.
            path: Database root, passed as ``root`` to the dataset class.
            img_indx: Image indices for this split, passed as ``index`` to
                the dataset class.
            istrain: When True, add a random horizontal flip to the
                transforms and shuffle batches in :meth:`get_data`.

        Raises:
            ValueError: If ``config.dataset`` is not one of
                livec, koniq10k, bid, spaq.
        """
        self.batch_size = config.batch_size
        self.istrain = istrain
        dataset = config.dataset

        # Fail fast on an unknown key, before constructing any transforms.
        try:
            class_name = self._DATASETS[dataset]
        except KeyError:
            raise ValueError(
                f"Only support livec, koniq10k, bid, spaq; got {dataset!r}."
            ) from None

        # Same pipeline for train and eval; only the flip augmentation is
        # train-specific. Normalize(0.5, 0.5) presumably means mean/std of
        # 0.5 -- NOTE(review): confirm against utils.dataset.process.
        steps = [Normalize(0.5, 0.5)]
        if istrain:
            steps.append(RandHorizontalFlip(prob_aug=0.5))
        steps.append(ToTensor())
        transforms = torchvision.transforms.Compose(steps)

        dataset_cls = getattr(folders, class_name)
        self.data = dataset_cls(root=path, index=img_indx, transform=transforms)

    def get_data(self, num_workers=8):
        """Return a DataLoader over the dataset built in ``__init__``.

        Args:
            num_workers: Number of loader worker processes (default 8, the
                previously hard-coded value; use 0 to load in-process).

        Returns:
            A ``torch.utils.data.DataLoader`` that shuffles only in training
            mode.
        """
        return torch.utils.data.DataLoader(
            self.data,
            batch_size=self.batch_size,
            shuffle=self.istrain,
            num_workers=num_workers,
        )