File size: 7,407 Bytes
137645c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# ImageNet-1K Dataset and DataLoader

from einops import rearrange
import torch
import torch.nn.functional as F
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from PIL import Image
import math
from functools import partial
import numpy as np
import random

from diffusers.models.embeddings import get_2d_rotary_pos_embed

# https://github.com/facebookresearch/DiT/blob/main/train.py#L85
def center_crop_arr(pil_image, image_size):
    """Center-crop *pil_image* to a square of side *image_size*.

    ADM/DiT-style preprocessing: repeatedly halve the image with BOX
    resampling while it is still at least twice the target size, then
    BICUBIC-resize so the short side lands exactly on *image_size*, and
    finally take the centered square crop.
    """
    # Coarse stage: integer halvings while comfortably above the target.
    while min(*pil_image.size) >= 2 * image_size:
        half = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(half, resample=Image.BOX)

    # Fine stage: one bicubic resize puts the short side at image_size.
    ratio = image_size / min(*pil_image.size)
    fine_size = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(fine_size, resample=Image.BICUBIC)

    # Crop the centered image_size x image_size window (numpy is row-major,
    # so axis 0 is height and axis 1 is width).
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)


def collate_fn(examples, config, noise_scheduler_copy):
    """Collate ImageFolder samples into one flattened multi-stage batch.

    Each sample is assigned to one of ``config.scheduler.num_stages``
    resolution stages (balanced round-robin, then shuffled). Per stage the
    images are downsampled to the stage resolution, linearly blended between
    a coarser "start" image and an "end" image according to the stage's time
    window, patchified, and all stages are concatenated into a single
    variable-length token sequence suitable for varlen attention.

    Args:
        examples: list of ``(image_tensor, label)`` pairs from
            ``torchvision.datasets.ImageFolder``.
        config: experiment config; reads ``model.params.patch_size``,
            ``model.params.attention_head_dim`` and ``scheduler.*`` fields.
        noise_scheduler_copy: scheduler exposing the per-stage lookup tables
            ``Timesteps_per_stage``, ``t_window_per_stage``, ``start_t`` and
            ``end_t`` (project type — semantics inferred from usage here;
            confirm against the scheduler definition).

    Returns:
        dict with flattened patch sequences (``pixel_values``), regression
        targets (``target_values``), rotary position embeddings, cumulative
        query lengths, per-sample latent sizes, timesteps and the untouched
        ``input_ids`` labels.
    """
    patch_size = config.model.params.patch_size
    # Stack images into a (B, C, H, W) float tensor; labels stay a list.
    pixel_values = torch.stack([eg[0] for eg in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = [eg[1] for eg in examples]

    batch_size = len(examples)
    # Balanced stage assignment: repeat 0..num_stages-1 until it covers the
    # batch, truncate, then shuffle so assignment is random per sample.
    stage_indices = list(range(config.scheduler.num_stages)) * (batch_size // config.scheduler.num_stages + 1)
    stage_indices = stage_indices[:batch_size]

    random.shuffle(stage_indices)
    stage_indices = torch.tensor(stage_indices, dtype=torch.int32)
    orig_height, orig_width = pixel_values.shape[-2:]
    # One random scheduler timestep index per sample.
    timesteps = torch.randint(0, config.scheduler.num_train_timesteps, (batch_size,))

    sample_list, input_ids_list, pos_embed_list, seq_len_list, target_list, timestep_list = [], [], [], [], [], []
    for stage_idx in range(config.scheduler.num_stages):
        # stage_idx counts downsampling steps (0 = full resolution);
        # corrected_stage_idx indexes the scheduler tables in reverse order.
        corrected_stage_idx = config.scheduler.num_stages - stage_idx - 1
        # Timesteps and images belonging to this stage's sub-batch.
        stage_select_indices = timesteps[stage_indices == corrected_stage_idx]
        Timesteps = noise_scheduler_copy.Timesteps_per_stage[corrected_stage_idx][stage_select_indices].float()
        batch_size_select = Timesteps.shape[0]
        pixel_values_select = pixel_values[stage_indices == corrected_stage_idx]
        input_ids_select = [input_ids[i] for i in range(batch_size) if stage_indices[i] == corrected_stage_idx]

        # Resolution for this stage: halved once per downsampling step.
        end_height, end_width = orig_height // (2 ** stage_idx), orig_width // (2 ** stage_idx)

        ################ build model input ################
        # Endpoints of this stage's time window on the noise schedule.
        start_t, end_t = noise_scheduler_copy.start_t[corrected_stage_idx], noise_scheduler_copy.end_t[corrected_stage_idx]

        pixel_values_end = pixel_values_select
        pixel_values_start = pixel_values_select
        if stage_idx > 0:
            # pixel_values_end
            # Progressive bilinear halving down to the stage resolution
            # (stepwise rather than one jump — presumably to match the
            # inference-time resolution pyramid; confirm with sampler code).
            for downsample_idx in range(1, stage_idx + 1):
                pixel_values_end = F.interpolate(pixel_values_end, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")

        # pixel_values_start
        # Start image is one level coarser than the end image...
        for downsample_idx in range(1, stage_idx + 2):
            pixel_values_start = F.interpolate(pixel_values_start, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
        # upsample pixel_values_start
        # ...then nearest-upsampled back to the stage resolution, so both
        # endpoints share the same spatial size but differ in detail level.
        pixel_values_start = F.interpolate(pixel_values_start, (end_height, end_width), mode="nearest")

        # Mix both endpoints with the SAME noise sample at their respective
        # signal levels; the regression target is the end-minus-start delta.
        noise = torch.randn_like(pixel_values_end)
        pixel_values_end = end_t * pixel_values_end + (1.0 - end_t) * noise
        pixel_values_start = start_t * pixel_values_start + (1.0 - start_t) * noise
        target = pixel_values_end - pixel_values_start

        # Per-sample interpolation weight inside the stage window; unsqueeze
        # to (B, 1, 1, 1) so it broadcasts over (B, C, H, W).
        t_select = noise_scheduler_copy.t_window_per_stage[corrected_stage_idx][stage_select_indices].flatten()
        while len(t_select.shape) < pixel_values_start.ndim:
            t_select = t_select.unsqueeze(-1)
        xt = t_select.float() * pixel_values_end + (1.0 - t_select.float()) * pixel_values_start
        
        # Patchify: (B, C, H, W) -> (B*h*w tokens, C*ph*pw features).
        target = rearrange(target, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
        xt = rearrange(xt, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)

        # 2D rotary embedding for this stage's token grid (one per sample;
        # returns a (cos, sin) pair — stacked along the last dim below).
        pos_embed = get_2d_rotary_pos_embed(
            embed_dim=config.model.params.attention_head_dim,
            crops_coords=((0, 0), (end_height // patch_size, end_width // patch_size)),
            grid_size=(end_height // patch_size, end_width // patch_size),
        )
        seq_len = (end_height // patch_size) * (end_width // patch_size)
        # NOTE(review): the message interpolates seq_len, not the two sizes.
        assert end_height == end_width, f"only support square image, got {seq_len}; TODO: latent_size_list"
        sample_list.append(xt)
        target_list.append(target)
        pos_embed_list.extend([pos_embed] * batch_size_select)
        seq_len_list.extend([seq_len] * batch_size_select)
        timestep_list.append(Timesteps)
        input_ids_list.extend(input_ids_select)

    # Concatenate all stages into one flat token sequence; cumsum_q_len gives
    # the varlen-attention boundaries between per-sample sub-sequences.
    pixel_values = torch.cat(sample_list, dim=0).to(memory_format=torch.contiguous_format)
    target_values = torch.cat(target_list, dim=0).to(memory_format=torch.contiguous_format)
    pos_embed = torch.cat([torch.stack(one_pos_emb, -1) for one_pos_emb in pos_embed_list], dim=0).float()
    cumsum_q_len = torch.cumsum(torch.tensor([0] + seq_len_list), 0).to(torch.int32)
    # Square images are asserted above, so sqrt(seq_len) is the grid side.
    latent_size_list = torch.tensor([int(math.sqrt(seq_len)) for seq_len in seq_len_list], dtype=torch.int32)

    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids_list,
        "pos_embed": pos_embed,
        "cumsum_q_len": cumsum_q_len,
        "batch_latent_size": latent_size_list,
        "seqlen_list_q": seq_len_list,
        # kv-side lengths are filled elsewhere (None here) — self-attention
        # presumably reuses the q lengths; confirm against the model code.
        "cumsum_kv_len": None,
        "batch_kv_len": None,
        "timesteps": torch.cat(timestep_list, dim=0),
        "target_values": target_values,
    }


def build_imagenet_loader(config, noise_scheduler_copy):
    """Create a distributed ImageNet DataLoader yielding multi-stage batches.

    The augmentation pipeline is chosen by ``config.data.center_crop``
    (deterministic DiT-style center crop vs. resize + random crop); the
    ``ImageFolder`` dataset is sharded with a ``DistributedSampler`` and the
    loader collates with ``collate_fn`` bound to *config* and
    *noise_scheduler_copy*.
    """
    if config.data.center_crop:
        crop_ops = [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, config.data.resolution)),
        ]
    else:
        crop_ops = [
            transforms.Resize(round(config.data.resolution * config.data.expand_ratio), interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.RandomCrop(config.data.resolution),
        ]
    # Shared tail: flip augmentation, tensor conversion, [-1, 1] scaling.
    transform = transforms.Compose(crop_ops + [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    dataset = ImageFolder(config.data.root, transform=transform)
    # Shard deterministically across ranks; requires an initialized
    # torch.distributed process group.
    sampler = DistributedSampler(
        dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
        shuffle=True,
        seed=config.seed,
    )

    # shuffle=False because the sampler already shuffles per epoch.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.data.batch_size,
        collate_fn=partial(collate_fn, config=config, noise_scheduler_copy=noise_scheduler_copy),
        shuffle=False,
        sampler=sampler,
        num_workers=config.data.num_workers,
        drop_last=True,
    )