LinB203
init
0c8d55e
import os
import json
from torch.utils.data import Dataset
class GenAIBench_Image(Dataset):
    """GenAI-Bench text-to-image evaluation set (527 prompts).

    Reads prompt metadata from ``<meta_dir>/genai_image.json`` and pairs each
    prompt with an image path of the form
    ``<root_dir>/<prompt_idx zero-padded to 9 digits>.jpg``.
    """

    def __init__(
        self,
        root_dir,
        meta_dir,
    ):
        """
        Args:
            root_dir: Directory holding the images, one file per prompt.
            meta_dir: Directory containing ``genai_image.json``, a JSON object
                mapping prompt indices (strings parseable by ``int``) to
                records that have at least a ``'prompt'`` key.
        """
        self.meta_dir = meta_dir
        self.root_dir = root_dir
        self.models = 'custom'
        # Context manager closes the metadata file handle deterministically;
        # JSON is UTF-8 by spec, so don't rely on the locale's default encoding.
        meta_path = os.path.join(self.meta_dir, "genai_image.json")
        with open(meta_path, 'r', encoding='utf-8') as f:
            self.dataset = json.load(f)
        print("Loaded dataset: genai_image.json")
        self.images = []  # one record per prompt, in JSON key order
        self.prompt_to_images = {}  # prompt_idx -> list of indices into self.images
        for prompt_idx, meta in self.dataset.items():
            self.images.append({
                'prompt_idx': prompt_idx,
                'prompt': meta['prompt'],
                'model': self.models,
                'image': os.path.join(self.root_dir, f"{int(prompt_idx):09d}.jpg"),
            })
            self.prompt_to_images.setdefault(prompt_idx, []).append(len(self.images) - 1)

    def __len__(self):
        """Return the number of prompt/image pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"images": [image_path], "texts": [prompt]}`` for item *idx*.

        The image is returned as a file path; it is not opened or decoded here.
        """
        item = self.images[idx]
        return {"images": [item['image']], "texts": [str(item['prompt'])]}