File size: 2,245 Bytes
c583015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import warnings

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image, ImageReadMode


def denorm_img(
    img: torch.Tensor,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
) -> torch.Tensor:
    """Undo per-channel normalization and clip the result to ``[0, 1]``.

    Computes ``img * std + mean``, the inverse of ``(x - mean) / std``. The
    per-channel statistics are reshaped to ``(C, 1, 1)`` so they broadcast
    over the last two dimensions. ``img`` may therefore be ``(C, H, W)`` or
    batched ``(..., C, H, W)``.

    Args:
        img: Normalized image tensor whose channel axis is third from the end.
        mean: Per-channel means (sequence or tensor). Defaults to the standard
            ImageNet RGB means.
        std: Per-channel standard deviations (sequence or tensor). Defaults to
            the standard ImageNet RGB standard deviations.

    Returns:
        The de-normalized tensor on ``img``'s device, clipped to ``[0, 1]``.
    """
    # Create the statistics on img's device. CPU-only constants would raise a
    # device-mismatch error for CUDA inputs.
    std_t = torch.as_tensor(std, device=img.device).reshape(-1, 1, 1)
    mean_t = torch.as_tensor(mean, device=img.device).reshape(-1, 1, 1)
    return torch.clip(img * std_t + mean_t, min=0, max=1)


class StyleContentDataset(Dataset):
    """Paired dataset of style and content images, indexed in lockstep.

    Item ``idx`` is built from ``(style_imgs[idx], content_imgs[idx])``. Each
    image is loaded as a 3-channel RGB float tensor divided by 255. The
    optional ``normalize`` callable is applied first, then ``transform``. The
    dataset length is that of the shorter path list, so surplus entries in the
    longer list are never visited.

    Args:
        style_imgs: Indexable sequence of style image paths.
        content_imgs: Indexable sequence of content image paths.
        transform: Optional callable applied to each image after ``normalize``.
        normalize: Optional callable applied to each image before ``transform``.
    """

    def __init__(self, style_imgs, content_imgs, transform=None, normalize=None):
        self.style_imgs = style_imgs
        self.content_imgs = content_imgs
        self.transform = transform
        self.normalize = normalize

    @staticmethod
    def _load(path):
        """Read ``path`` as an RGB image and scale its values by 1/255 to floats."""
        return read_image(path, ImageReadMode.RGB).float() / 255.0

    def __len__(self):
        # Only indices valid for both lists can form a pair.
        return min(len(self.style_imgs), len(self.content_imgs))

    def __getitem__(self, idx):
        try:
            style = self._load(self.style_imgs[idx])
            content = self._load(self.content_imgs[idx])
        except RuntimeError as err:
            # Best-effort: one unreadable file should not abort a long run, so
            # pair 0 is substituted. This duplicates sample 0 in the epoch;
            # warn loudly so the bad files can be found and fixed.
            warnings.warn(
                f"Could not read pair {idx} "
                f"(style={self.style_imgs[idx]!r}, "
                f"content={self.content_imgs[idx]!r}): {err}; "
                "falling back to pair 0.",
                RuntimeWarning,
                stacklevel=2,
            )
            style = self._load(self.style_imgs[0])
            content = self._load(self.content_imgs[0])

        if self.normalize is not None:
            style = self.normalize(style)
            content = self.normalize(content)

        if self.transform is not None:
            style = self.transform(style)
            content = self.transform(content)

        return style, content
    

class DataStore:
    """Endless batch source over a dataset of ``(style, content)`` pairs.

    Wraps a ``DataLoader`` and restarts it whenever an epoch is exhausted, so
    ``get`` can be called indefinitely.

    Args:
        dataset: Dataset whose items are ``(style, content)`` pairs, e.g. a
            ``StyleContentDataset``.
        batch_size: Number of pairs per batch.
        shuffle: Whether the DataLoader shuffles samples, reshuffling on each
            restart.
        num_workers: Number of loader worker processes (default 2).
    """

    def __init__(self, dataset: Dataset, batch_size, shuffle=False, num_workers=2):
        self.dataset = dataset
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        self.iterator = iter(self.dataloader)

    def get(self):
        """Return the next ``(style, content)`` batch, starting a new epoch if needed.

        Raises:
            RuntimeError: If a freshly started epoch yields no batches, for
                example because the dataset is empty.
        """
        try:
            style, content = next(self.iterator)
        except StopIteration:
            # Epoch exhausted: start a new pass over the data.
            self.iterator = iter(self.dataloader)
            try:
                style, content = next(self.iterator)
            except StopIteration:
                # Otherwise a bare StopIteration escapes get(), which is
                # confusing for callers and becomes a RuntimeError inside
                # generators anyway.
                raise RuntimeError(
                    "DataLoader produced no batches; is the dataset empty?"
                ) from None
        return style, content