import os
from typing import List, Optional, Tuple

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from diffusers.utils import load_image


class ImageFolderDataset(Dataset):
    """Dataset class for loading images and prompts from a folder and file path.

    Args:
        images_root (str):
            Path to the folder containing images.
        prompts_path (str):
            Path to the file containing prompts.
        image_size (Tuple[int, int]):
            Size of the images to be loaded.
        extensions (Tuple[str]):
            Tuple of valid image extensions.
    """

    def __init__(
        self,
        images_root: str,
        prompts_path: Optional[str] = None,
        image_size: Tuple[int, int] = (512, 512),
        extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
    ) -> None:
        super().__init__()
        self.image_size = image_size

        self.images_paths, self.prompts = self._make_dataset(
            images_root=images_root, extensions=extensions, prompts_path=prompts_path
        )

        self.to_tensor = transforms.ToTensor()

    def _make_dataset(
        self,
        images_root: str,
        extensions: Tuple[str, ...],
        prompts_path: Optional[str] = None,
    ) -> Tuple[List[str], Optional[List[str]]]:
        # Walk the folder in sorted order so that prompts, if provided,
        # line up with images by index.
        images_paths = []
        for root, _, fnames in sorted(os.walk(images_root)):
            for fname in sorted(fnames):
                if fname.lower().endswith(extensions):
                    images_paths.append(os.path.join(root, fname))

        if prompts_path is not None:
            # One prompt per line; strip trailing newlines.
            with open(prompts_path, "r") as f:
                prompts = [line.rstrip("\n") for line in f]
        else:
            prompts = None

        return images_paths, prompts

    def __len__(self) -> int:
        return len(self.images_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Optional[str]]:
        # load_image returns a PIL image; resize it, then convert to a CHW float tensor.
        image = load_image(self.images_paths[idx]).resize(self.image_size)
        prompt = self.prompts[idx] if self.prompts is not None else None

        return self.to_tensor(image), prompt
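

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the paths below are hypothetical
    # and should point to a real image folder and a prompt file with one line per image.
    dataset = ImageFolderDataset(
        images_root="data/images",
        prompts_path="data/prompts.txt",
        image_size=(512, 512),
    )
    # Each item is a (C, H, W) float tensor plus its prompt string (or None).
    image_tensor, prompt = dataset[0]
    print(image_tensor.shape, prompt)

    # The dataset plugs directly into a standard DataLoader.
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)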