from einops import rearrange
import math
from typing import List, Optional, Union
import time
import torch
import torch.nn.functional as F

from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.embeddings import get_2d_rotary_pos_embed


class PixelFlowPipeline:
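    """Multi-stage (coarse-to-fine) sampling pipeline for a PixelFlow transformer.

    Runs class-conditional sampling when no text encoder/tokenizer is provided,
    and text-conditional sampling otherwise.
    """
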
    def __init__(
        self,
        scheduler,
        transformer,
        text_encoder=None,
        tokenizer=None,
        max_token_length=512,
    ):
        super().__init__()
        self.class_cond = text_encoder is None or tokenizer is None
        self.scheduler = scheduler
        self.transformer = transformer
        self.patch_size = transformer.patch_size
        self.head_dim = transformer.attention_head_dim
        self.num_stages = scheduler.num_stages

        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length

    @torch.autocast("cuda", enabled=False)
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Union[str, List[str]] = "",
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_attention_mask: Optional[torch.FloatTensor] = None,
        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
        use_attention_mask: bool = False,
        max_length: int = 512,
    ):
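        """Tokenize and encode `prompt` (and, for classifier-free guidance, `negative_prompt`).

        Returns `(prompt_embeds, prompt_attention_mask)`; when guidance is enabled, the
        negative embeddings and masks are concatenated before the positive ones along
        the batch dimension.
        """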
        # Determine the batch size and normalize prompt input to a list
        if prompt is not None:
            if isinstance(prompt, str):
                prompt = [prompt]
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Process prompt embeddings if not provided
        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            prompt_attention_mask = text_inputs.attention_mask.to(device)
            prompt_embeds = self.text_encoder(
                text_input_ids,
                attention_mask=prompt_attention_mask if use_attention_mask else None
            )[0]

        # Determine dtype from available encoder
        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        # Move prompt embeddings to desired dtype and device
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)

        # Handle classifier-free guidance for negative prompts
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Normalize negative prompt to list and validate length
            if isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif isinstance(negative_prompt, list):
                if len(negative_prompt) != batch_size:
                    raise ValueError(f"The negative prompt list must have the same length as the prompt list, but got {len(negative_prompt)} and {batch_size}")
                uncond_tokens = negative_prompt
            else:
                raise ValueError(f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}")

            # Tokenize and encode negative prompts
            uncond_inputs = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            negative_input_ids = uncond_inputs.input_ids.to(device)
            negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
            negative_prompt_embeds = self.text_encoder(
                negative_input_ids,
                attention_mask=negative_prompt_attention_mask if use_attention_mask else None
            )[0]

        if do_classifier_free_guidance:
            # Duplicate negative prompt embeddings and attention mask for each generation
            seq_len_neg = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None

        # For classifier-free guidance, concatenate negative and positive embeddings and their masks
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        return prompt_embeds, prompt_attention_mask

    def sample_block_noise(self, bs, ch, height, width, eps=1e-6):
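        """Draw noise that is correlated within each 2x2 spatial block.

        Each block is one sample from a 4-d Gaussian with covariance
        (1 - gamma) * I + gamma * J (J = all-ones), plus eps * I for numerical
        stability, where gamma comes from the scheduler.
        """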
        gamma = self.scheduler.gamma
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(4),
            torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
        )
        block_number = bs * ch * (height // 2) * (width // 2)
        noise = dist.sample((block_number,))  # [block_number, 4], one draw per 2x2 block
        noise = rearrange(noise, '(b c h w) (p q) -> b c (h p) (w q)', b=bs, c=ch, h=height // 2, w=width // 2, p=2, q=2)
        return noise

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        height,
        width,
        num_inference_steps=30,
        guidance_scale=4.0,
        num_images_per_prompt=1,
        device=None,
        shift=1.0,
        use_ode_dopri5=False,
    ):
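        """Run multi-stage sampling and return images as a float numpy array in [0, 1],
        with shape (batch * num_images_per_prompt, height, width, 3).

        `num_inference_steps` may be an int (shared across stages) or a per-stage list;
        `height` and `width` are the final output resolution.
        """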
        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * self.num_stages

        if use_ode_dopri5:
            assert self.class_cond, "ODE (dopri5) sampling is only supported for class-conditional models now"
            from pixelflow.solver_ode_wrapper import ODE
            sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample
        else:
            # default Euler
            sample_fn = None

        self._guidance_scale = guidance_scale
        batch_size = len(prompt)
        if self.class_cond:
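            # Class-conditional path: `prompt` is a list of integer class ids; 1000 * ones
            # is used here as the unconditional (null) label for classifier-free guidance.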
            prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device)
            negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds)
            if self.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        else:
            prompt_embeds, prompt_attention_mask = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                "",
                prompt_embeds=None,
                negative_prompt_embeds=None,
                use_attention_mask=True,
                max_length=self.max_token_length,
            )

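        # Start from the coarsest stage: the target resolution is halved once per
        # additional stage, then doubled again at each stage transition below.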
        init_factor = 2 ** (self.num_stages - 1)
        height, width = height // init_factor, width // init_factor
        shape = (batch_size * num_images_per_prompt, 3, height, width)
        latents = randn_tensor(shape, device=device, dtype=torch.float32)

        for stage_idx in range(self.num_stages):
            stage_start = time.time()
            # Set the number of inference steps for the current stage
            self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
            Timesteps = self.scheduler.Timesteps

            if stage_idx > 0:
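                # Upsample the previous stage's result 2x and re-noise it: alpha/beta are
                # derived from the scheduler's original_start_t and gamma (gamma must be
                # negative for the square roots below to be real).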
                height, width = height * 2, width * 2
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                original_start_t = self.scheduler.original_start_t[stage_idx]
                gamma = self.scheduler.gamma
                alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
                beta = alpha * (1 - original_start_t) / math.sqrt(- gamma)

                # bs, ch, height, width = latents.shape
                noise = self.sample_block_noise(*latents.shape)
                noise = noise.to(device=device, dtype=latents.dtype)
                latents = alpha * latents + beta * noise

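            # Build 2D rotary position embeddings for the current resolution; the grid uses
            # the last spatial dim for both axes, so square latents are assumed here.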
            size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device)
            pos_embed = get_2d_rotary_pos_embed(
                embed_dim=self.head_dim,
                crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)),
                grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size),
            )
            rope_pos = torch.stack(pos_embed, -1)

            if sample_fn is not None:
                # dopri5
                model_kwargs = dict(class_labels=prompt_embeds, cfg_scale=self.guidance_scale(None, stage_idx), latent_size=size_tensor, pos_embed=rope_pos)
                if stage_idx == 0:
                    latents = torch.cat([latents] * 2)
                stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item()
                stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item()
                latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1]
                if stage_idx == self.num_stages - 1:
                    latents = latents[:latents.shape[0] // 2]
            else:
                # euler
                for T in Timesteps:
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
                    if self.class_cond:
                        noise_pred = self.transformer(latent_model_input, timestep=timestep, class_labels=prompt_embeds, latent_size=size_tensor, pos_embed=rope_pos)
                    else:
                        encoder_hidden_states = prompt_embeds
                        encoder_attention_mask = prompt_attention_mask

                        noise_pred = self.transformer(
                            latent_model_input,
                            encoder_hidden_states=encoder_hidden_states,
                            encoder_attention_mask=encoder_attention_mask,
                            timestep=timestep,
                            latent_size=size_tensor,
                            pos_embed=rope_pos,
                        )

                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond)

                    latents = self.scheduler.step(model_output=noise_pred, sample=latents)
            stage_end = time.time()

        samples = (latents / 2 + 0.5).clamp(0, 1)
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        return samples

    @property
    def device(self):
        return next(self.transformer.parameters()).device

    @property
    def dtype(self):
        return next(self.transformer.parameters()).dtype

    def guidance_scale(self, step=None, stage_idx=None):
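        # Text-conditional models use the provided scale directly. Class-conditional models
        # ramp guidance up across stages (scale 1, i.e. effectively no guidance, at stage 0
        # and the full scale at the last stage); the schedule below assumes 4 stages.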
        if not self.class_cond:
            return self._guidance_scale
        scale_dict = {0: 0, 1: 1/6, 2: 2/3, 3: 1}
        return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 0