# CelebA multi-modal dataset: images with segmentation-mask, text, and sketch conditions.
import json
import os
import random

import albumentations
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
class DalleTransformerPreprocessor(object):
    """Crop-and-resize preprocessing following DALL-E's augmentation scheme.

    In the 'train' phase: a center-biased random square crop, a random
    up/down resize, then an albumentations RandomCrop to `size`.
    In any other phase: a shortest-side resize to `size` followed by a
    deterministic CenterCrop.
    """

    def __init__(self,
                 size=256,
                 phase='train',
                 additional_targets=None):
        self.size = size
        self.phase = phase
        # ddc: following dalle to use randomcrop
        self.train_preprocessor = albumentations.Compose(
            [albumentations.RandomCrop(height=size, width=size)],
            additional_targets=additional_targets)
        self.val_preprocessor = albumentations.Compose(
            [albumentations.CenterCrop(height=size, width=size)],
            additional_targets=additional_targets)

    def __call__(self, image, **kargs):
        """
        image: PIL.Image
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype(np.uint8))

        w, h = image.size
        s_min = min(h, w)

        if self.phase != 'train':
            # Resize so the shorter side equals self.size, keeping aspect
            # ratio, then center-crop to a square.
            if w < h:
                new_w = self.size
                new_h = int(h * new_w / w)
            else:
                new_h = self.size
                new_w = int(w * new_h / h)
            image = image.resize((new_w, new_h))
            image = np.array(image).astype(np.uint8)
            return self.val_preprocessor(image=image)

        # Square crop whose offset is jittered around the center
        # (drawn from roughly the middle quarter of the slack range).
        slack_h = h - s_min
        slack_w = w - s_min
        off_h = int(random.uniform(3 * slack_h // 8,
                                   max(3 * slack_h // 8 + 1, 5 * slack_h // 8)))
        off_w = int(random.uniform(3 * slack_w // 8,
                                   max(3 * slack_w // 8 + 1, 5 * slack_w // 8)))
        image = image.crop((off_w, off_h, off_w + s_min, off_h + s_min))

        # Random resize to a side length in [size, t_max] before the final
        # RandomCrop back down to `size`.
        t_max = max(min(s_min, round(9 / 8 * self.size)), self.size)
        t = int(random.uniform(self.size, t_max + 1))
        image = image.resize((t, t))
        image = np.array(image).astype(np.uint8)
        return self.train_preprocessor(image=image)
class CelebA(Dataset):
    """
    This Dataset can be used for:
    - image-only: setting 'conditions' = []
    - image and multi-modal 'conditions': setting conditions as the list of modalities you need
    To toggle between 256 and 512 image resolution, simply change the 'image_folder'

    Args:
        phase: 'train' or 'test'; selects which slice of the caption keys is used.
        size: target square resolution fed to the Dalle transform.
        test_dataset_size: number of trailing samples reserved for the test split.
        conditions: iterable drawn from {'seg_mask', 'text', 'sketch'}.
        image_folder / text_file / mask_folder / sketch_folder: data locations.

    Raises:
        NotImplementedError: if `phase` is neither 'train' nor 'test'.
    """
    def __init__(
        self,
        phase='train',
        size=512,
        test_dataset_size=3000,
        # Immutable default (was a mutable list, a classic shared-default pitfall).
        conditions=('seg_mask', 'text', 'sketch'),
        image_folder='data/celeba/image/image_512_downsampled_from_hq_1024',
        text_file='data/celeba/text/captions_hq_beard_and_age_2022-08-19.json',
        mask_folder='data/celeba/mask/CelebAMask-HQ-mask-color-palette_32_nearest_downsampled_from_hq_512_one_hot_2d_tensor',
        sketch_folder='data/celeba/sketch/sketch_1x1024_tensor',
    ):
        self.transform = DalleTransformerPreprocessor(size=size, phase=phase)
        # Copy so a caller-supplied list cannot be mutated from outside.
        self.conditions = list(conditions)
        self.image_folder = image_folder

        # conditions directory
        self.text_file = text_file
        with open(self.text_file, 'r') as f:
            self.text_file_content = json.load(f)
        if 'seg_mask' in self.conditions:
            self.mask_folder = mask_folder
        if 'sketch' in self.conditions:
            self.sketch_folder = sketch_folder

        # list of valid image names & train/test split (the captions dict
        # preserves insertion order, so the split is deterministic)
        self.image_name_list = list(self.text_file_content.keys())
        if phase == 'train':
            self.image_name_list = self.image_name_list[:-test_dataset_size]
        elif phase == 'test':
            self.image_name_list = self.image_name_list[-test_dataset_size:]
        else:
            raise NotImplementedError
        self.num = len(self.image_name_list)

    def __len__(self):
        return self.num

    def _record(self, data, flat_key, cond_key, value):
        # With exactly one requested condition, the value is stored at the top
        # level (under `flat_key`); otherwise it is nested under 'conditions'.
        if len(self.conditions) == 1:
            data[flat_key] = value
        else:
            data['conditions'][cond_key] = value

    def _load_pt(self, folder, image_name):
        # Condition tensors are stored as '<image stem>.pt' next to the image name.
        stem = image_name.split('.')[0]
        return torch.load(os.path.join(folder, f'{stem}.pt'))

    def __getitem__(self, index):
        # ---------- (1) get image ----------
        image_name = self.image_name_list[index]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert('RGB')
        image = np.array(image).astype(np.uint8)
        image = self.transform(image=image)['image']
        # scale uint8 [0, 255] -> float32 [-1, 1]
        image = image.astype(np.float32) / 127.5 - 1.0

        # record into data entry
        data = {'image': image}
        if len(self.conditions) != 1:
            data['conditions'] = {}

        # ---------- (2) get text ----------
        if 'text' in self.conditions:
            text = self.text_file_content[image_name]["Beard_and_Age"].lower()
            self._record(data, 'caption', 'text', text)

        # ---------- (3) get mask ----------
        if 'seg_mask' in self.conditions:
            self._record(data, 'seg_mask', 'seg_mask',
                         self._load_pt(self.mask_folder, image_name))

        # ---------- (4) get sketch ----------
        if 'sketch' in self.conditions:
            self._record(data, 'sketch', 'sketch',
                         self._load_pt(self.sketch_folder, image_name))

        data["image_name"] = image_name.split('.')[0]
        return data
if __name__ == '__main__':
    # The caption file only has 29999 captions: https://github.com/ziqihuangg/CelebA-Dialog/issues/1

    # Testing for `phase`
    train_dataset = CelebA(phase="train")
    test_dataset = CelebA(phase="test")
    assert len(train_dataset) == 26999
    assert len(test_dataset) == 3000

    # Testing for `size` (same access pattern for both resolutions)
    for target in (512, 256):
        ds = CelebA(size=target)
        assert ds[0]['image'].shape == (target, target, 3)
        assert ds[0]["conditions"]['seg_mask'].shape == (19, 1024)
        assert ds[0]["conditions"]['sketch'].shape == (1, 1024)

    # Testing for `conditions`
    dataset = CelebA(conditions = ['seg_mask', 'text', 'sketch'])
    image = dataset[0]["image"]
    seg_mask = dataset[0]["conditions"]['seg_mask']
    sketch = dataset[0]["conditions"]['sketch']
    text = dataset[0]["conditions"]['text']

    # Convert one-hot tensors to 2D label maps and upsample with
    # nearest-neighbor interpolation for display.
    seg_mask = torch.argmax(seg_mask, dim=0).reshape(32, 32).numpy().astype(np.uint8)
    seg_mask = np.array(Image.fromarray(seg_mask).resize((512, 512), Image.NEAREST))
    sketch = sketch.reshape(32, 32).numpy().astype(np.uint8)
    sketch = np.array(Image.fromarray(sketch).resize((512, 512), Image.NEAREST))

    # Show image, seg_mask, sketch side by side, with the caption as title.
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = [
        ((image + 1) / 2, 'Image', None),
        (seg_mask, 'Segmentation Mask', 'tab20'),
        (sketch, 'Sketch', 'gray'),
    ]
    for axis, (img, title, cmap) in zip(axes, panels):
        axis.imshow(img, cmap=cmap)
        axis.set_title(title)
        axis.axis('off')
    fig.suptitle(text, fontsize=16)
    plt.tight_layout()
    plt.savefig('celeba_sample.png')

    # save seg_mask with name such as "27000.png, 270001.png, ..., 279999.png" of test dataset to "/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/evaluation/CollDiff/real_mask"
    from tqdm import tqdm
    out_dir = "/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/evaluation/CollDiff/real_mask"
    for data in tqdm(test_dataset):
        mask = torch.argmax(data["conditions"]['seg_mask'], dim=0).reshape(32, 32).numpy().astype(np.uint8)
        mask = Image.fromarray(mask).resize((512, 512), Image.NEAREST)
        mask.save(f"{out_dir}/{data['image_name']}.png")