# ImageNet-1K Dataset and DataLoader
import math
import random
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
from diffusers.models.embeddings import get_2d_rotary_pos_embed
# https://github.com/facebookresearch/DiT/blob/main/train.py#L85
def center_crop_arr(pil_image, image_size):
    """Center-crop `pil_image` to `image_size` x `image_size`.

    Repeatedly halves the image with box filtering while it is at least
    twice the target size, then does a final bicubic resize so the short
    side equals `image_size`, and crops out the center square.
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
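
# A quick sanity check for center_crop_arr (illustrative only; the sizes
# below are hypothetical and not part of the training pipeline). A 512x384
# input is below the halving threshold for a 256 target, so it is resized
# to (341, 256) and center-cropped to (256, 256):
#
#   img = Image.new("RGB", (512, 384))
#   assert center_crop_arr(img, 256).size == (256, 256)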
def collate_fn(examples, config, noise_scheduler_copy):
    """Collate ImageFolder samples into a multi-stage flow-matching batch.

    Each sample is randomly assigned to one pyramid stage; samples in stage
    `s` (counting from the full resolution) are patchified at resolution
    `orig / 2**s` and supervised on the displacement between the stage
    window's renoised end point and start point.
    """
    patch_size = config.model.params.patch_size
    pixel_values = torch.stack([eg[0] for eg in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = [eg[1] for eg in examples]
    batch_size = len(examples)
    # Assign samples to stages as evenly as possible, then shuffle.
    stage_indices = list(range(config.scheduler.num_stages)) * (batch_size // config.scheduler.num_stages + 1)
    stage_indices = stage_indices[:batch_size]
    random.shuffle(stage_indices)
    stage_indices = torch.tensor(stage_indices, dtype=torch.int32)
    orig_height, orig_width = pixel_values.shape[-2:]
    # One random per-stage timestep index per sample.
    timesteps = torch.randint(0, config.scheduler.num_train_timesteps, (batch_size,))
    sample_list, input_ids_list, pos_embed_list, seq_len_list, target_list, timestep_list = [], [], [], [], [], []
    # Iterate from the highest resolution (stage_idx = 0) down;
    # corrected_stage_idx is the scheduler's stage index (0 = lowest resolution).
    for stage_idx in range(config.scheduler.num_stages):
        corrected_stage_idx = config.scheduler.num_stages - stage_idx - 1
        stage_select_indices = timesteps[stage_indices == corrected_stage_idx]
        timesteps_select = noise_scheduler_copy.Timesteps_per_stage[corrected_stage_idx][stage_select_indices].float()
        batch_size_select = timesteps_select.shape[0]
        if batch_size_select == 0:
            # No samples landed in this stage (possible when batch_size < num_stages).
            continue
        pixel_values_select = pixel_values[stage_indices == corrected_stage_idx]
        input_ids_select = [input_ids[i] for i in range(batch_size) if stage_indices[i] == corrected_stage_idx]
        end_height, end_width = orig_height // (2 ** stage_idx), orig_width // (2 ** stage_idx)

        ################ build model input ################
        start_t, end_t = noise_scheduler_copy.start_t[corrected_stage_idx], noise_scheduler_copy.end_t[corrected_stage_idx]
        # At full resolution (stage_idx == 0) both endpoints are the clean image
        # and only the noise levels differ; lower-resolution stages build them
        # from down/upsampled copies.
        pixel_values_end = pixel_values_select
        pixel_values_start = pixel_values_select
        if stage_idx > 0:
            # End point: the clean image downsampled to this stage's resolution.
            for downsample_idx in range(1, stage_idx + 1):
                pixel_values_end = F.interpolate(pixel_values_end, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
            # Start point: downsample one level further than the end point ...
            for downsample_idx in range(1, stage_idx + 2):
                pixel_values_start = F.interpolate(pixel_values_start, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
            # ... then upsample back so both endpoints share a resolution.
            pixel_values_start = F.interpolate(pixel_values_start, (end_height, end_width), mode="nearest")
        # Renoise both endpoints with the same noise sample at the window
        # boundaries; the target is the displacement across the window.
        noise = torch.randn_like(pixel_values_end)
        pixel_values_end = end_t * pixel_values_end + (1.0 - end_t) * noise
        pixel_values_start = start_t * pixel_values_start + (1.0 - start_t) * noise
        target = pixel_values_end - pixel_values_start
        # Interpolate within the window at each sample's timestep.
        t_select = noise_scheduler_copy.t_window_per_stage[corrected_stage_idx][stage_select_indices].flatten()
        while len(t_select.shape) < pixel_values_start.ndim:
            t_select = t_select.unsqueeze(-1)
        xt = t_select.float() * pixel_values_end + (1.0 - t_select.float()) * pixel_values_start
        # Patchify into token sequences: (b, c, h, w) -> (b * h/p * w/p, c * p * p).
        target = rearrange(target, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
        xt = rearrange(xt, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
        pos_embed = get_2d_rotary_pos_embed(
            embed_dim=config.model.params.attention_head_dim,
            crops_coords=((0, 0), (end_height // patch_size, end_width // patch_size)),
            grid_size=(end_height // patch_size, end_width // patch_size),
        )
        seq_len = (end_height // patch_size) * (end_width // patch_size)
        assert end_height == end_width, f"only square images are supported (TODO: non-square latent_size_list), got {end_height}x{end_width}"
        sample_list.append(xt)
        target_list.append(target)
        pos_embed_list.extend([pos_embed] * batch_size_select)
        seq_len_list.extend([seq_len] * batch_size_select)
        timestep_list.append(timesteps_select)
        input_ids_list.extend(input_ids_select)
    pixel_values = torch.cat(sample_list, dim=0).to(memory_format=torch.contiguous_format)
    target_values = torch.cat(target_list, dim=0).to(memory_format=torch.contiguous_format)
    # Each pos_embed is a (cos, sin) pair; stack the pair on the last dim and
    # concatenate over samples to match the flattened token sequence.
    pos_embed = torch.cat([torch.stack(one_pos_emb, -1) for one_pos_emb in pos_embed_list], dim=0).float()
    # Cumulative per-sample sequence offsets for variable-length attention.
    cumsum_q_len = torch.cumsum(torch.tensor([0] + seq_len_list), 0).to(torch.int32)
    latent_size_list = torch.tensor([int(math.sqrt(seq_len)) for seq_len in seq_len_list], dtype=torch.int32)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids_list,
        "pos_embed": pos_embed,
        "cumsum_q_len": cumsum_q_len,
        "batch_latent_size": latent_size_list,
        "seqlen_list_q": seq_len_list,
        "cumsum_kv_len": None,
        "batch_kv_len": None,
        "timesteps": torch.cat(timestep_list, dim=0),
        "target_values": target_values,
    }
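
# The per-stage construction above follows rectified-flow algebra restricted
# to a stage window [start_t, end_t]: with shared noise eps and (for lower
# stages) a down/upsampled start image x',
#   x_end   = end_t   * x  + (1 - end_t)   * eps
#   x_start = start_t * x' + (1 - start_t) * eps
#   x_t     = t * x_end + (1 - t) * x_start,   target = x_end - x_start
# A tiny standalone check of the boundary behavior (hypothetical values,
# not used by the pipeline): the interpolant hits x_start at t = 0 and
# x_end at t = 1.
def _window_interp_check():
    x = torch.randn(1, 3, 8, 8)
    eps = torch.randn_like(x)
    start_t, end_t = 0.25, 0.75
    x_end = end_t * x + (1.0 - end_t) * eps
    x_start = start_t * x + (1.0 - start_t) * eps
    assert torch.allclose(0.0 * x_end + 1.0 * x_start, x_start)
    assert torch.allclose(1.0 * x_end + 0.0 * x_start, x_end)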
def build_imagenet_loader(config, noise_scheduler_copy):
    """Build a distributed ImageNet-1K DataLoader using the multi-stage collate_fn."""
    if config.data.center_crop:
        transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, config.data.resolution)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(round(config.data.resolution * config.data.expand_ratio), interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.RandomCrop(config.data.resolution),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
    dataset = ImageFolder(config.data.root, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
        shuffle=True,
        seed=config.seed,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.data.batch_size,
        collate_fn=partial(collate_fn, config=config, noise_scheduler_copy=noise_scheduler_copy),
        shuffle=False,  # the sampler handles shuffling
        sampler=sampler,
        num_workers=config.data.num_workers,
        drop_last=True,
    )
    return loader
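
# A minimal usage sketch. The config layout below is an assumption inferred
# from the attribute accesses in this file (config.data.*, config.scheduler.*,
# config.model.params.*); `noise_scheduler_copy` stands in for whatever
# multi-stage scheduler exposes Timesteps_per_stage / t_window_per_stage /
# start_t / end_t, and torch.distributed must already be initialized.
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.create({
#       "seed": 0,
#       "data": {"root": "/path/to/imagenet/train", "resolution": 256,
#                "center_crop": True, "expand_ratio": 1.25,
#                "batch_size": 32, "num_workers": 8},
#       "scheduler": {"num_stages": 3, "num_train_timesteps": 1000},
#       "model": {"params": {"patch_size": 2, "attention_head_dim": 64}},
#   })
#   loader = build_imagenet_loader(config, noise_scheduler_copy)
#   batch = next(iter(loader))
#   # batch["pixel_values"]: (total tokens across samples, C * patch * patch)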