File size: 1,367 Bytes
0c8d55e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import json
from torch.utils.data import Dataset


class GenAIBench_Image(Dataset):
    """GenAI-Bench image split: one pre-generated image file per prompt.

    Original note: GenAIBench with 527 prompts.

    Reads ``genai_image.json`` from ``meta_dir``. Each top-level key of that
    JSON object is treated as a prompt index, and each value must contain a
    ``'prompt'`` entry. The image for key ``k`` is expected at
    ``root_dir/{int(k):09d}.jpg``, so keys must be integer-like strings
    (``int()`` raises ``ValueError`` otherwise).

    Args:
        root_dir: Directory holding the ``.jpg`` images.
        meta_dir: Directory holding ``genai_image.json``.

    Attributes:
        dataset: The parsed JSON object.
        images: One record dict per prompt (``prompt_idx``, ``prompt``,
            ``model``, ``image``), in JSON key order.
        prompt_to_images: Maps each prompt key to the list of positions in
            ``images`` that belong to it.
    """

    def __init__(
            self,
            root_dir,
            meta_dir,
            ):
        self.meta_dir = meta_dir
        self.root_dir = root_dir
        self.models = 'custom'  # single model label attached to every record
        meta_path = os.path.join(self.meta_dir, "genai_image.json")
        # Use a context manager so the handle is closed deterministically;
        # JSON text is UTF-8, so don't depend on the platform's locale encoding.
        with open(meta_path, 'r', encoding='utf-8') as meta_file:
            self.dataset = json.load(meta_file)
        print("Loaded dataset: genai_image.json")

        self.images = []  # list of image records, one per prompt
        self.prompt_to_images = {}  # prompt key -> indices into self.images
        for prompt_idx, entry in self.dataset.items():
            self.images.append({
                'prompt_idx': prompt_idx,
                'prompt': entry['prompt'],
                'model': self.models,
                'image': os.path.join(self.root_dir, f"{int(prompt_idx):09d}.jpg"),
            })
            self.prompt_to_images.setdefault(prompt_idx, []).append(len(self.images) - 1)

    def __len__(self):
        """Return the number of image records."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{"images": [path], "texts": [prompt]}``.

        The image is returned as a file path string, not decoded pixels;
        the prompt is coerced to ``str``.
        """
        record = self.images[idx]
        return {"images": [record['image']], "texts": [str(record['prompt'])]}