import math
import time
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from einops import rearrange

from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.utils.torch_utils import randn_tensor
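

# PixelFlowPipeline drives a multi-stage, pixel-space flow sampler: generation starts at a
# reduced resolution, and each subsequent stage upsamples the running sample 2x, partially
# re-noises it, and denoises it again with the same transformer. The pipeline works directly
# on 3-channel images (no VAE), and supports either integer class labels or a text
# encoder/tokenizer pair for prompting.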
class PixelFlowPipeline:
    def __init__(
        self,
        scheduler,
        transformer,
        text_encoder=None,
        tokenizer=None,
        max_token_length=512,
    ):
        super().__init__()
        # Class-conditional mode is inferred from the absence of a text encoder/tokenizer
        self.class_cond = text_encoder is None or tokenizer is None
        self.scheduler = scheduler
        self.transformer = transformer
        self.patch_size = transformer.patch_size
        self.head_dim = transformer.attention_head_dim
        self.num_stages = scheduler.num_stages
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length
    @torch.autocast("cuda", enabled=False)
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Union[str, List[str]] = "",
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_attention_mask: Optional[torch.FloatTensor] = None,
        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
        use_attention_mask: bool = False,
        max_length: int = 512,
    ):
        # Determine the batch size and normalize the prompt input to a list
        if prompt is not None:
            if isinstance(prompt, str):
                prompt = [prompt]
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Tokenize and encode the prompt if embeddings were not provided
        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            prompt_attention_mask = text_inputs.attention_mask.to(device)
            prompt_embeds = self.text_encoder(
                text_input_ids,
                attention_mask=prompt_attention_mask if use_attention_mask else None,
            )[0]

        # Determine dtype from the available encoder
        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        # Move prompt embeddings to the desired dtype and device, then duplicate
        # them for each image generated per prompt
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
        # Handle classifier-free guidance for negative prompts
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Normalize the negative prompt to a list and validate its length
            if isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif isinstance(negative_prompt, list):
                if len(negative_prompt) != batch_size:
                    raise ValueError(
                        f"The negative prompt list must have the same length as the prompt list, "
                        f"but got {len(negative_prompt)} and {batch_size}"
                    )
                uncond_tokens = negative_prompt
            else:
                raise ValueError(
                    f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}"
                )

            # Tokenize and encode the negative prompts, padded to the positive prompt length
            uncond_inputs = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            negative_input_ids = uncond_inputs.input_ids.to(device)
            negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
            negative_prompt_embeds = self.text_encoder(
                negative_input_ids,
                attention_mask=negative_prompt_attention_mask if use_attention_mask else None,
            )[0]

        if do_classifier_free_guidance:
            # Duplicate negative prompt embeddings and attention mask for each generation
            seq_len_neg = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
            # Concatenate negative and positive embeddings (unconditional first, matching the
            # chunk(2) order used when guidance is applied) along the batch dimension
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None

        return prompt_embeds, prompt_attention_mask
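
    # Block-correlated noise for stage transitions: each non-overlapping 2x2 pixel block
    # receives a jointly Gaussian draw with unit marginal variance and pairwise covariance
    # gamma (taken from the scheduler), i.e. covariance (1 - gamma) * I + gamma * 11^T,
    # with a small eps added to the diagonal for numerical stability.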
    def sample_block_noise(self, bs, ch, height, width, eps=1e-6):
        gamma = self.scheduler.gamma
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(4),
            torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
        )
        block_number = bs * ch * (height // 2) * (width // 2)
        noise = dist.sample((block_number,))  # [block_number, 4], one draw per 2x2 block
        # Scatter each 4-vector back into its 2x2 spatial block
        noise = rearrange(
            noise, '(b c h w) (p q) -> b c (h p) (w q)',
            b=bs, c=ch, h=height // 2, w=width // 2, p=2, q=2,
        )
        return noise
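
    # Sampling proceeds coarse-to-fine: stage 0 starts from pure Gaussian noise at
    # 1 / 2^(num_stages - 1) of the target resolution; every later stage upsamples the
    # previous result 2x, re-noises it with block-correlated noise (alpha/beta below are
    # derived from the stage's starting timestep and gamma), and integrates the flow again,
    # either with Euler steps or, for class-conditional models, an optional dopri5 ODE solver.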
    @torch.no_grad()
    def __call__(
        self,
        prompt,
        height,
        width,
        num_inference_steps=30,
        guidance_scale=4.0,
        num_images_per_prompt=1,
        device=None,
        shift=1.0,
        use_ode_dopri5=False,
    ):
        # A single int specifies the same step count for every stage
        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * self.num_stages

        if use_ode_dopri5:
            assert self.class_cond, "ODE (dopri5) sampling is only supported for class-conditional models now"
            from pixelflow.solver_ode_wrapper import ODE
            sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample
        else:
            # default Euler
            sample_fn = None

        self._guidance_scale = guidance_scale
        batch_size = len(prompt)
        if self.class_cond:
            # Class-conditional: prompts are integer class labels, and the all-1000 tensor
            # serves as the unconditional (null-class) label for classifier-free guidance
            prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device)
            negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds)
            if self.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        else:
            prompt_embeds, prompt_attention_mask = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                guidance_scale > 1,
                "",
                prompt_embeds=None,
                negative_prompt_embeds=None,
                use_attention_mask=True,
                max_length=self.max_token_length,
            )

        # Start from pure noise at the lowest-resolution stage: 1 / 2^(num_stages - 1) of the target size
        init_factor = 2 ** (self.num_stages - 1)
        height, width = height // init_factor, width // init_factor
        shape = (batch_size * num_images_per_prompt, 3, height, width)
        latents = randn_tensor(shape, device=device, dtype=torch.float32)
        for stage_idx in range(self.num_stages):
            stage_start = time.time()

            # Set the timestep schedule for the current stage
            self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
            Timesteps = self.scheduler.Timesteps

            if stage_idx > 0:
                # Upsample the previous stage's output 2x and re-noise it to the
                # starting point of the current stage
                height, width = height * 2, width * 2
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                original_start_t = self.scheduler.original_start_t[stage_idx]
                gamma = self.scheduler.gamma
                alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
                beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
                noise = self.sample_block_noise(*latents.shape)
                noise = noise.to(device=device, dtype=latents.dtype)
                latents = alpha * latents + beta * noise

            # Rotary position embeddings for the current (square) token grid
            size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device)
            pos_embed = get_2d_rotary_pos_embed(
                embed_dim=self.head_dim,
                crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)),
                grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size),
            )
            rope_pos = torch.stack(pos_embed, -1)
            if sample_fn is not None:
                # dopri5 ODE solver (class-conditional only)
                model_kwargs = dict(
                    class_labels=prompt_embeds,
                    cfg_scale=self.guidance_scale(None, stage_idx),
                    latent_size=size_tensor,
                    pos_embed=rope_pos,
                )
                # Duplicate the batch once for CFG; one half of the doubled batch is kept after the final stage
                if stage_idx == 0:
                    latents = torch.cat([latents] * 2)
                stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item()
                stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item()
                latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1]
                if stage_idx == self.num_stages - 1:
                    latents = latents[:latents.shape[0] // 2]
            else:
                # Euler integration over the stage's timesteps
                for T in Timesteps:
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
                    if self.class_cond:
                        noise_pred = self.transformer(
                            latent_model_input,
                            timestep=timestep,
                            class_labels=prompt_embeds,
                            latent_size=size_tensor,
                            pos_embed=rope_pos,
                        )
                    else:
                        noise_pred = self.transformer(
                            latent_model_input,
                            encoder_hidden_states=prompt_embeds,
                            encoder_attention_mask=prompt_attention_mask,
                            timestep=timestep,
                            latent_size=size_tensor,
                            pos_embed=rope_pos,
                        )
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond)
                    latents = self.scheduler.step(model_output=noise_pred, sample=latents)
            stage_end = time.time()

        # Map samples from [-1, 1] to [0, 1] and return as HWC numpy arrays
        samples = (latents / 2 + 0.5).clamp(0, 1)
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        return samples
    @property
    def device(self):
        return next(self.transformer.parameters()).device

    @property
    def dtype(self):
        return next(self.transformer.parameters()).dtype

    def guidance_scale(self, step=None, stage_idx=None):
        if not self.class_cond:
            return self._guidance_scale
        # Class-conditional models ramp guidance up across stages (table assumes four stages)
        scale_dict = {0: 0, 1: 1 / 6, 2: 2 / 3, 3: 1}
        return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 0
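

# A minimal usage sketch. The pipeline only wires together caller-supplied components, so the
# scheduler, transformer, text encoder, and tokenizer below are placeholders for whatever the
# surrounding project actually loads; only the PixelFlowPipeline constructor and __call__
# signature come from the class above.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Placeholders: a scheduler exposing num_stages/gamma/set_timesteps/step, a transformer
    # exposing patch_size and attention_head_dim, and an optional text encoder/tokenizer pair.
    scheduler = ...
    transformer = ...
    text_encoder = ...
    tokenizer = ...

    pipeline = PixelFlowPipeline(
        scheduler,
        transformer,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        max_token_length=512,
    )

    # Returns a numpy array of shape [batch * num_images_per_prompt, height, width, 3]
    # with values in [0, 1].
    images = pipeline(
        ["a photo of a corgi wearing sunglasses"],
        height=1024,
        width=1024,
        num_inference_steps=[10, 10, 10, 10],  # one entry per resolution stage
        guidance_scale=4.0,
        device=device,
    )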