import json
import os
import random

import albumentations
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class DalleTransformerPreprocessor(object):
    """Crop-and-resize preprocessor following DALL-E's augmentation scheme.

    In 'train' phase: a randomized near-center square crop, a random resize to
    [size, ~9/8*size], then an albumentations RandomCrop back to `size`.
    In any other phase: a shorter-side resize to `size` followed by a CenterCrop.

    Args:
        size: target square side length in pixels.
        phase: 'train' enables the randomized pipeline; anything else uses the
            deterministic validation pipeline.
        additional_targets: passed through to albumentations.Compose so extra
            image-like inputs (e.g. masks) receive the same crop.
    """

    def __init__(self, size=256, phase='train', additional_targets=None):
        self.size = size
        self.phase = phase
        # ddc: following dalle to use randomcrop
        self.train_preprocessor = albumentations.Compose(
            [albumentations.RandomCrop(height=size, width=size)],
            additional_targets=additional_targets)
        self.val_preprocessor = albumentations.Compose(
            [albumentations.CenterCrop(height=size, width=size)],
            additional_targets=additional_targets)

    def __call__(self, image, **kargs):
        """Preprocess one image.

        Args:
            image: PIL.Image, or an np.ndarray (converted to uint8 PIL first).

        Returns:
            The albumentations result dict; the processed HxWx3 uint8 array is
            under key 'image'.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype(np.uint8))

        w, h = image.size
        s_min = min(h, w)

        if self.phase == 'train':
            # Random square crop whose offset stays within the middle quarter
            # of the slack (3/8 .. 5/8 of h-s_min / w-s_min), i.e. near-center.
            # max(lo+1, hi) guards against an empty range when h == w.
            off_h = int(random.uniform(
                3 * (h - s_min) // 8,
                max(3 * (h - s_min) // 8 + 1, 5 * (h - s_min) // 8)))
            off_w = int(random.uniform(
                3 * (w - s_min) // 8,
                max(3 * (w - s_min) // 8 + 1, 5 * (w - s_min) // 8)))
            image = image.crop((off_w, off_h, off_w + s_min, off_h + s_min))

            # resize image: pick a random side in [size, min(s_min, 9/8*size)]
            # so RandomCrop below still has room to jitter.
            t_max = min(s_min, round(9 / 8 * self.size))
            t_max = max(t_max, self.size)
            t = int(random.uniform(self.size, t_max + 1))
            image = image.resize((t, t))
            image = np.array(image).astype(np.uint8)
            image = self.train_preprocessor(image=image)
        else:
            # Deterministic: scale the shorter side to `size`, keep aspect
            # ratio, then center-crop to a square.
            if w < h:
                w_ = self.size
                h_ = int(h * w_ / w)
            else:
                h_ = self.size
                w_ = int(w * h_ / h)
            image = image.resize((w_, h_))
            image = np.array(image).astype(np.uint8)
            image = self.val_preprocessor(image=image)
        return image


class CelebA(Dataset):
    """
    This Dataset can be used for:
    - image-only: setting 'conditions' = []
    - image and multi-modal 'conditions': setting conditions as the list of
      modalities you need

    To toggle between 256 and 512 image resolution, simply change the
    'image_folder'
    """

    # Default set of condition modalities (kept immutable at class level to
    # avoid the mutable-default-argument pitfall in __init__).
    DEFAULT_CONDITIONS = ('seg_mask', 'text', 'sketch')

    def __init__(
            self,
            phase='train',
            size=512,
            test_dataset_size=3000,
            conditions=None,
            image_folder='data/celeba/image/image_512_downsampled_from_hq_1024',
            text_file='data/celeba/text/captions_hq_beard_and_age_2022-08-19.json',
            mask_folder='data/celeba/mask/CelebAMask-HQ-mask-color-palette_32_nearest_downsampled_from_hq_512_one_hot_2d_tensor',
            sketch_folder='data/celeba/sketch/sketch_1x1024_tensor',
    ):
        """
        Args:
            phase: 'train' or 'test'; selects the augmentation pipeline and
                which slice of the caption keys this split covers.
            size: target image resolution (should match `image_folder`).
            test_dataset_size: number of trailing samples reserved for 'test'.
            conditions: list of modalities to load; defaults to
                ['seg_mask', 'text', 'sketch'] when None.
            image_folder / text_file / mask_folder / sketch_folder: data paths.
        """
        # None sentinel instead of a mutable list default.
        self.conditions = (list(self.DEFAULT_CONDITIONS)
                           if conditions is None else conditions)
        self.transform = DalleTransformerPreprocessor(size=size, phase=phase)
        self.image_folder = image_folder

        # conditions directory
        self.text_file = text_file
        with open(self.text_file, 'r') as f:
            self.text_file_content = json.load(f)
        if 'seg_mask' in self.conditions:
            self.mask_folder = mask_folder
        if 'sketch' in self.conditions:
            self.sketch_folder = sketch_folder

        # list of valid image names & train test split.  The caption JSON's
        # keys double as the list of image filenames.
        self.image_name_list = list(self.text_file_content.keys())

        # train test split: last `test_dataset_size` entries are the test set.
        if phase == 'train':
            self.image_name_list = self.image_name_list[:-test_dataset_size]
        elif phase == 'test':
            self.image_name_list = self.image_name_list[-test_dataset_size:]
        else:
            raise NotImplementedError
        self.num = len(self.image_name_list)

    def __len__(self):
        return self.num

    def __getitem__(self, index):
        """Return one sample.

        With a single condition the modality is stored flat on the entry;
        with several, they are nested under data['conditions'].
        """
        # ---------- (1) get image ----------
        image_name = self.image_name_list[index]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert('RGB')
        image = np.array(image).astype(np.uint8)
        image = self.transform(image=image)['image']
        # Scale uint8 [0, 255] to float32 [-1, 1].
        image = image.astype(np.float32) / 127.5 - 1.0

        # record into data entry
        if len(self.conditions) == 1:
            data = {
                'image': image,
            }
        else:
            data = {
                'image': image,
                'conditions': {}
            }

        # ---------- (2) get text ----------
        if 'text' in self.conditions:
            # NOTE(review): assumes every caption entry has a "Beard_and_Age"
            # field — verify against the caption JSON schema.
            text = self.text_file_content[image_name]["Beard_and_Age"].lower()
            # record into data entry
            if len(self.conditions) == 1:
                data['caption'] = text
            else:
                data['conditions']['text'] = text

        # ---------- (3) get mask ----------
        if 'seg_mask' in self.conditions:
            mask_idx = image_name.split('.')[0]
            mask_name = f'{mask_idx}.pt'
            mask_path = os.path.join(self.mask_folder, mask_name)
            mask_one_hot_tensor = torch.load(mask_path)
            # record into data entry
            if len(self.conditions) == 1:
                data['seg_mask'] = mask_one_hot_tensor
            else:
                data['conditions']['seg_mask'] = mask_one_hot_tensor

        # ---------- (4) get sketch ----------
        if 'sketch' in self.conditions:
            sketch_idx = image_name.split('.')[0]
            sketch_name = f'{sketch_idx}.pt'
            sketch_path = os.path.join(self.sketch_folder, sketch_name)
            sketch_one_hot_tensor = torch.load(sketch_path)
            # record into data entry
            if len(self.conditions) == 1:
                data['sketch'] = sketch_one_hot_tensor
            else:
                data['conditions']['sketch'] = sketch_one_hot_tensor

        data["image_name"] = image_name.split('.')[0]
        return data


if __name__ == '__main__':
    # The caption file only has 29999 captions:
    # https://github.com/ziqihuangg/CelebA-Dialog/issues/1

    # Testing for `phase`
    train_dataset = CelebA(phase="train")
    test_dataset = CelebA(phase="test")
    assert len(train_dataset) == 26999
    assert len(test_dataset) == 3000

    # Testing for `size`
    size_512 = CelebA(size=512)
    assert size_512[0]['image'].shape == (512, 512, 3)
    assert size_512[0]["conditions"]['seg_mask'].shape == (19, 1024)
    assert size_512[0]["conditions"]['sketch'].shape == (1, 1024)
    size_256 = CelebA(size=256)
    assert size_256[0]['image'].shape == (256, 256, 3)
    assert size_256[0]["conditions"]['seg_mask'].shape == (19, 1024)
    assert size_256[0]["conditions"]['sketch'].shape == (1, 1024)

    # Testing for `conditions`
    dataset = CelebA(conditions=['seg_mask', 'text', 'sketch'])
    image = dataset[0]["image"]
    seg_mask = dataset[0]["conditions"]['seg_mask']
    sketch = dataset[0]["conditions"]['sketch']
    text = dataset[0]["conditions"]['text']

    # show image, seg_mask, sketch in 3x3 grid, and text in title
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    # Show image
    ax[0].imshow((image + 1) / 2)
    ax[0].set_title('Image')
    ax[0].axis('off')

    # # Show segmentation mask
    seg_mask = torch.argmax(seg_mask, dim=0).reshape(32, 32).numpy().astype(np.uint8)
    # resize to 512x512 using nearest neighbor interpolation
    seg_mask = Image.fromarray(seg_mask).resize((512, 512), Image.NEAREST)
    seg_mask = np.array(seg_mask)
    ax[1].imshow(seg_mask, cmap='tab20')
    ax[1].set_title('Segmentation Mask')
    ax[1].axis('off')

    # # # Show sketch
    sketch = sketch.reshape(32, 32).numpy().astype(np.uint8)
    # resize to 512x512 using nearest neighbor interpolation
    sketch = Image.fromarray(sketch).resize((512, 512), Image.NEAREST)
    sketch = np.array(sketch)
    ax[2].imshow(sketch, cmap='gray')
    ax[2].set_title('Sketch')
    ax[2].axis('off')

    # Add title with text
    fig.suptitle(text, fontsize=16)

    plt.tight_layout()
    plt.savefig('celeba_sample.png')

    # save seg_mask with name such as "27000.png, 270001.png, ..., 279999.png"
    # of test dataset to
    # "/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/evaluation/CollDiff/real_mask"
    from tqdm import tqdm
    for data in tqdm(test_dataset):
        mask = torch.argmax(data["conditions"]['seg_mask'], dim=0).reshape(32, 32).numpy().astype(np.uint8)
        mask = Image.fromarray(mask).resize((512, 512), Image.NEAREST)
        mask.save(f"/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/evaluation/CollDiff/real_mask/{data['image_name']}.png")