File size: 1,436 Bytes
07e1105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torchvision

from utils.dataset import folders
from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip

class Data_Loader(object):
    """Build an IQA dataset and expose it through a torch ``DataLoader``.

    The dataset class is picked from ``utils.dataset.folders`` according to
    ``config.dataset``; the transform pipeline depends on ``istrain``.
    """

    # Supported ``config.dataset`` values mapped to the attribute name of the
    # matching dataset class in ``folders``. The class is resolved with
    # ``getattr`` only after the name has been validated, so an unsupported
    # dataset is rejected before anything else is constructed.
    _DATASET_CLASSES = {
        'livec': 'LIVEC',
        'koniq10k': 'Koniq10k',
        'bid': 'BID',
        'spaq': 'SPAQ',
    }

    def __init__(self, config, path, img_indx, istrain=True):
        """Create the underlying dataset.

        Args:
            config: Object providing ``batch_size`` and ``dataset``
                attributes. ``dataset`` must be one of ``'livec'``,
                ``'koniq10k'``, ``'bid'`` or ``'spaq'``.
            path: Forwarded to the dataset class as ``root``.
            img_indx: Forwarded to the dataset class as ``index``
                (presumably the image indices of this split -- confirm
                against ``utils.dataset.folders``).
            istrain: When true, ``RandHorizontalFlip(prob_aug=0.5)`` is added
                to the transforms and ``get_data`` shuffles.

        Raises:
            ValueError: If ``config.dataset`` is not a supported name.
                ``ValueError`` derives from ``Exception``, so callers that
                catch ``Exception`` keep working.
        """
        self.batch_size = config.batch_size
        self.istrain = istrain

        # Fail fast on an unknown dataset, before building any transforms.
        dataset = config.dataset
        folder_cls_name = (
            self._DATASET_CLASSES.get(dataset)
            if isinstance(dataset, str) else None
        )
        if folder_cls_name is None:
            raise ValueError(
                "Only support livec, koniq10k, bid, spaq. "
                "Got: {!r}".format(dataset)
            )

        # Both modes share Normalize -> ToTensor; training inserts the random
        # horizontal flip between them (same order as before).
        steps = [Normalize(0.5, 0.5)]
        if istrain:
            steps.append(RandHorizontalFlip(prob_aug=0.5))
        steps.append(ToTensor())
        transforms = torchvision.transforms.Compose(steps)

        folder_cls = getattr(folders, folder_cls_name)
        self.data = folder_cls(root=path, index=img_indx, transform=transforms)

    def get_data(self):
        """Return a new ``torch.utils.data.DataLoader`` over ``self.data``.

        Batches have ``self.batch_size`` items, shuffling is enabled only
        when ``self.istrain`` is true, and 8 worker processes are requested.
        """
        return torch.utils.data.DataLoader(
            self.data,
            batch_size=self.batch_size,
            shuffle=self.istrain,
            num_workers=8,
        )