# NOTE: the following provenance lines were copied from the Hugging Face file
# page header (non-code residue), kept here as a comment for reference:
#   author: ShoufaChen — commit 137645c (verified), message "init", 7.41 kB
# ImageNet-1K Dataset and DataLoader
from einops import rearrange
import torch
import torch.nn.functional as F
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from PIL import Image
import math
from functools import partial
import numpy as np
import random
from diffusers.models.embeddings import get_2d_rotary_pos_embed
# https://github.com/facebookresearch/DiT/blob/main/train.py#L85
def center_crop_arr(pil_image, image_size):
    """Center-crop ``pil_image`` to an ``image_size`` x ``image_size`` square.

    Follows the ADM/DiT preprocessing recipe: repeatedly halve the image
    with box filtering while its shorter side is at least twice the target,
    then bicubically resize so the shorter side equals ``image_size``, and
    finally take a centered square crop.
    """
    # Progressive 2x box downsampling avoids aliasing from one large resize.
    while min(*pil_image.size) >= 2 * image_size:
        half = tuple(d // 2 for d in pil_image.size)
        pil_image = pil_image.resize(half, resample=Image.BOX)

    # Bring the shorter side exactly to the target resolution.
    ratio = image_size / min(*pil_image.size)
    new_size = tuple(round(d * ratio) for d in pil_image.size)
    pil_image = pil_image.resize(new_size, resample=Image.BICUBIC)

    # Centered square window in pixel coordinates (rows, cols).
    arr = np.array(pil_image)
    top = (arr.shape[0] - image_size) // 2
    left = (arr.shape[1] - image_size) // 2
    window = arr[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
def collate_fn(examples, config, noise_scheduler_copy):
    """Collate ImageFolder samples into a multi-resolution flow-matching batch.

    Every example is randomly assigned to one of
    ``config.scheduler.num_stages`` resolution stages (round-robin stage ids,
    truncated to the batch size and shuffled, keeping stages balanced within
    a batch).  For each stage, the clean image is downsampled to the stage
    resolution and two endpoints of a linear trajectory segment are built;
    the model input ``xt`` is a point on that segment and the regression
    target is the segment's constant velocity ``end - start``.

    Args:
        examples: list of ``(image_tensor, label)`` pairs as produced by
            ``torchvision.datasets.ImageFolder`` after the train transform.
        config: experiment config; this function reads
            ``model.params.patch_size``, ``model.params.attention_head_dim``,
            ``scheduler.num_stages`` and ``scheduler.num_train_timesteps``.
        noise_scheduler_copy: project scheduler exposing per-stage lookup
            tables ``Timesteps_per_stage``, ``t_window_per_stage``,
            ``start_t`` and ``end_t``.  Their exact semantics are assumed
            from usage here — verify against the scheduler implementation.

    Returns:
        dict of sequence-packed tensors: ``pixel_values``/``target_values``
        are patch tokens concatenated over all stages; ``cumsum_q_len``,
        ``seqlen_list_q`` and ``batch_latent_size`` describe the varlen
        packing; ``pos_embed`` holds per-token rotary tables; ``timesteps``
        the per-example scheduler timesteps; ``input_ids`` the labels
        reordered to match the packing.
    """
    patch_size = config.model.params.patch_size
    # Stack transformed images into (B, C, H, W), forced float + contiguous.
    pixel_values = torch.stack([eg[0] for eg in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = [eg[1] for eg in examples]  # class labels from ImageFolder
    batch_size = len(examples)
    # Balanced random stage assignment: tile 0..num_stages-1 over the batch,
    # truncate, then shuffle.
    stage_indices = list(range(config.scheduler.num_stages)) * (batch_size // config.scheduler.num_stages + 1)
    stage_indices = stage_indices[:batch_size]
    random.shuffle(stage_indices)
    stage_indices = torch.tensor(stage_indices, dtype=torch.int32)
    orig_height, orig_width = pixel_values.shape[-2:]
    # Raw per-example timestep index; mapped through the per-stage tables below.
    timesteps = torch.randint(0, config.scheduler.num_train_timesteps, (batch_size,))
    sample_list, input_ids_list, pos_embed_list, seq_len_list, target_list, timestep_list = [], [], [], [], [], []
    for stage_idx in range(config.scheduler.num_stages):
        # stage_idx counts 2x downsamplings (0 = full resolution);
        # corrected_stage_idx indexes the scheduler tables in reverse order.
        corrected_stage_idx = config.scheduler.num_stages - stage_idx - 1
        stage_select_indices = timesteps[stage_indices == corrected_stage_idx]
        Timesteps = noise_scheduler_copy.Timesteps_per_stage[corrected_stage_idx][stage_select_indices].float()
        batch_size_select = Timesteps.shape[0]  # examples assigned to this stage
        pixel_values_select = pixel_values[stage_indices == corrected_stage_idx]
        input_ids_select = [input_ids[i] for i in range(batch_size) if stage_indices[i] == corrected_stage_idx]
        end_height, end_width = orig_height // (2 ** stage_idx), orig_width // (2 ** stage_idx)
        ################ build model input ################
        # [start_t, end_t] is presumably this stage's sub-interval of the
        # global flow schedule — TODO confirm against the scheduler.
        start_t, end_t = noise_scheduler_copy.start_t[corrected_stage_idx], noise_scheduler_copy.end_t[corrected_stage_idx]
        pixel_values_end = pixel_values_select
        pixel_values_start = pixel_values_select
        if stage_idx > 0:
            # pixel_values_end: clean image at this stage's resolution
            # (downsampled stage_idx times; guard is redundant since the
            # range below is empty for stage_idx == 0, but kept as written).
            for downsample_idx in range(1, stage_idx + 1):
                pixel_values_end = F.interpolate(pixel_values_end, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
        # pixel_values_start: image downsampled one level COARSER than this
        # stage (stage_idx + 1 halvings)...
        for downsample_idx in range(1, stage_idx + 2):
            pixel_values_start = F.interpolate(pixel_values_start, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
        # ...then nearest-upsampled back to the stage resolution, so the start
        # endpoint carries only the coarser stage's information.
        pixel_values_start = F.interpolate(pixel_values_start, (end_height, end_width), mode="nearest")
        # Noise both endpoints with the SAME noise sample so the velocity
        # target below depends only on the signal difference.
        noise = torch.randn_like(pixel_values_end)
        pixel_values_end = end_t * pixel_values_end + (1.0 - end_t) * noise
        pixel_values_start = start_t * pixel_values_start + (1.0 - start_t) * noise
        # Constant-velocity regression target for the linear segment.
        target = pixel_values_end - pixel_values_start
        # Per-example interpolation position within this stage's window.
        t_select = noise_scheduler_copy.t_window_per_stage[corrected_stage_idx][stage_select_indices].flatten()
        # Broadcast t to (B, 1, 1, 1) for elementwise mixing.
        while len(t_select.shape) < pixel_values_start.ndim:
            t_select = t_select.unsqueeze(-1)
        xt = t_select.float() * pixel_values_end + (1.0 - t_select.float()) * pixel_values_start
        # Patchify: (B, C, H, W) -> (B*h*w, C*ph*pw) flat token sequence.
        target = rearrange(target, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
        xt = rearrange(xt, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
        # Rotary tables for this stage's token grid (diffusers returns a
        # (cos, sin) pair — stacked on a new last axis after the loop).
        pos_embed = get_2d_rotary_pos_embed(
            embed_dim=config.model.params.attention_head_dim,
            crops_coords=((0, 0), (end_height // patch_size, end_width // patch_size)),
            grid_size=(end_height // patch_size, end_width // patch_size),
        )
        seq_len = (end_height // patch_size) * (end_width // patch_size)
        assert end_height == end_width, f"only support square image, got {seq_len}; TODO: latent_size_list"
        sample_list.append(xt)
        target_list.append(target)
        pos_embed_list.extend([pos_embed] * batch_size_select)
        seq_len_list.extend([seq_len] * batch_size_select)
        timestep_list.append(Timesteps)
        input_ids_list.extend(input_ids_select)
    # Pack all stages into flat token streams; cumulative lengths let the
    # downstream (varlen) attention recover per-example boundaries.
    pixel_values = torch.cat(sample_list, dim=0).to(memory_format=torch.contiguous_format)
    target_values = torch.cat(target_list, dim=0).to(memory_format=torch.contiguous_format)
    pos_embed = torch.cat([torch.stack(one_pos_emb, -1) for one_pos_emb in pos_embed_list], dim=0).float()
    cumsum_q_len = torch.cumsum(torch.tensor([0] + seq_len_list), 0).to(torch.int32)
    # Square-grid side length per example (valid given the assert above).
    latent_size_list = torch.tensor([int(math.sqrt(seq_len)) for seq_len in seq_len_list], dtype=torch.int32)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids_list,
        "pos_embed": pos_embed,
        "cumsum_q_len": cumsum_q_len,
        "batch_latent_size": latent_size_list,
        "seqlen_list_q": seq_len_list,
        "cumsum_kv_len": None,
        "batch_kv_len": None,
        "timesteps": torch.cat(timestep_list, dim=0),
        "target_values": target_values,
    }
def build_imagenet_loader(config, noise_scheduler_copy):
    """Build the distributed ImageNet-1K training DataLoader.

    The transform pipeline either center-crops (DiT recipe) or random-crops
    an expanded LANCZOS resize, then horizontally flips, tensorizes, and
    normalizes to [-1, 1].  Batches are assembled by ``collate_fn`` using
    the supplied scheduler copy.
    """
    resolution = config.data.resolution
    if config.data.center_crop:
        head = [transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, resolution))]
    else:
        head = [
            transforms.Resize(round(resolution * config.data.expand_ratio), interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.RandomCrop(resolution),
        ]
    transform = transforms.Compose(head + [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1] per channel.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    dataset = ImageFolder(config.data.root, transform=transform)
    # Shuffling is owned by the sampler, hence shuffle=False on the loader.
    sampler = DistributedSampler(
        dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
        shuffle=True,
        seed=config.seed,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.data.batch_size,
        collate_fn=partial(collate_fn, config=config, noise_scheduler_copy=noise_scheduler_copy),
        shuffle=False,
        sampler=sampler,
        num_workers=config.data.num_workers,
        drop_last=True,
    )