# PixelFlow-Text2Image/pixelflow/pipeline_pixelflow.py
import math
import time
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.utils.torch_utils import randn_tensor


class PixelFlowPipeline:
def __init__(
self,
scheduler,
transformer,
text_encoder=None,
tokenizer=None,
max_token_length=512,
):
super().__init__()
self.class_cond = text_encoder is None or tokenizer is None
self.scheduler = scheduler
self.transformer = transformer
self.patch_size = transformer.patch_size
self.head_dim = transformer.attention_head_dim
self.num_stages = scheduler.num_stages
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.max_token_length = max_token_length
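
    # Two operating modes, selected in __init__:
    #   * class-conditional (text_encoder or tokenizer is None): `prompt` is a list of
    #     integer class labels and 1000 is used as the unconditional (null) label;
    #   * text-to-image: prompts are tokenized and encoded by encode_prompt() below.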
@torch.autocast("cuda", enabled=False)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Union[str, List[str]] = "",
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
use_attention_mask: bool = False,
max_length: int = 512,
):
# Determine the batch size and normalize prompt input to a list
if prompt is not None:
if isinstance(prompt, str):
prompt = [prompt]
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Process prompt embeddings if not provided
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=prompt_attention_mask if use_attention_mask else None
)[0]
# Determine dtype from available encoder
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
# Move prompt embeddings to desired dtype and device
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
# Handle classifier-free guidance for negative prompts
if do_classifier_free_guidance and negative_prompt_embeds is None:
# Normalize negative prompt to list and validate length
if isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
elif isinstance(negative_prompt, list):
if len(negative_prompt) != batch_size:
raise ValueError(f"The negative prompt list must have the same length as the prompt list, but got {len(negative_prompt)} and {batch_size}")
uncond_tokens = negative_prompt
else:
raise ValueError(f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}")
# Tokenize and encode negative prompts
uncond_inputs = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=prompt_embeds.shape[1],
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_input_ids = uncond_inputs.input_ids.to(device)
negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
negative_input_ids,
attention_mask=negative_prompt_attention_mask if use_attention_mask else None
)[0]
if do_classifier_free_guidance:
# Duplicate negative prompt embeddings and attention mask for each generation
seq_len_neg = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
# Concatenate negative and positive embeddings and their masks
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
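        # With CFG enabled, the returned embeddings/masks stack the negative prompts first and
        # the positive prompts second along the batch dimension, matching the chunk(2) split
        # (uncond, text) used in __call__.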
return prompt_embeds, prompt_attention_mask
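
    # Between stages the coarse latents are upsampled with nearest neighbour, so every 2x2
    # block holds one repeated value. The noise added on top is therefore sampled per 2x2
    # block from N(0, (1 - gamma) I + gamma 11^T). Gamma is expected to be negative (both
    # square roots in the stage transition below require it), which makes the components
    # inside a block anti-correlated: the noise perturbs within-block detail much more than
    # the block average.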
    def sample_block_noise(self, bs, ch, height, width, eps=1e-6):
        gamma = self.scheduler.gamma
        # Per-block covariance: (1 - gamma) * I + gamma * 11^T, plus eps * I for numerical stability.
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(4),
            torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
        )
        block_number = bs * ch * (height // 2) * (width // 2)
        noise = torch.stack([dist.sample() for _ in range(block_number)])  # [block_number, 4]
        noise = rearrange(noise, '(b c h w) (p q) -> b c (h p) (w q)', b=bs, c=ch, h=height // 2, w=width // 2, p=2, q=2)
        return noise

    @torch.no_grad()
def __call__(
self,
prompt,
height,
width,
num_inference_steps=30,
guidance_scale=4.0,
num_images_per_prompt=1,
device=None,
shift=1.0,
use_ode_dopri5=False,
):
if isinstance(num_inference_steps, int):
num_inference_steps = [num_inference_steps] * self.num_stages
if use_ode_dopri5:
assert self.class_cond, "ODE (dopri5) sampling is only supported for class-conditional models now"
from pixelflow.solver_ode_wrapper import ODE
sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample
else:
# default Euler
sample_fn = None
self._guidance_scale = guidance_scale
batch_size = len(prompt)
if self.class_cond:
prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device)
negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
else:
prompt_embeds, prompt_attention_mask = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
guidance_scale > 1,
"",
prompt_embeds=None,
negative_prompt_embeds=None,
use_attention_mask=True,
max_length=self.max_token_length,
)
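        # Cascaded sampling: start at the target resolution divided by 2^(num_stages - 1)
        # and double the spatial size at every stage. The initial latents are pure Gaussian
        # noise directly in pixel space (3 channels).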
init_factor = 2 ** (self.num_stages - 1)
height, width = height // init_factor, width // init_factor
shape = (batch_size * num_images_per_prompt, 3, height, width)
latents = randn_tensor(shape, device=device, dtype=torch.float32)
for stage_idx in range(self.num_stages):
stage_start = time.time()
# Set the number of inference steps for the current stage
self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
Timesteps = self.scheduler.Timesteps
if stage_idx > 0:
height, width = height * 2, width * 2
latents = F.interpolate(latents, size=(height, width), mode='nearest')
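                # Re-noise the upsampled latents: alpha rescales the upsampled sample and beta
                # scales freshly drawn block-correlated noise (both derived from original_start_t
                # and gamma), so the result is consistent with the scheduler's starting point for
                # this stage.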
original_start_t = self.scheduler.original_start_t[stage_idx]
gamma = self.scheduler.gamma
alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
beta = alpha * (1 - original_start_t) / math.sqrt(- gamma)
# bs, ch, height, width = latents.shape
noise = self.sample_block_noise(*latents.shape)
noise = noise.to(device=device, dtype=latents.dtype)
latents = alpha * latents + beta * noise
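            # The rotary position embedding and the token-grid size are rebuilt each stage
            # because the number of patches changes with resolution. Note that the grid uses
            # latents.shape[-1] for both axes, i.e. it assumes square images.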
size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device)
pos_embed = get_2d_rotary_pos_embed(
embed_dim=self.head_dim,
crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)),
grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size),
)
rope_pos = torch.stack(pos_embed, -1)
if sample_fn is not None:
# dopri5
model_kwargs = dict(class_labels=prompt_embeds, cfg_scale=self.guidance_scale(None, stage_idx), latent_size=size_tensor, pos_embed=rope_pos)
if stage_idx == 0:
latents = torch.cat([latents] * 2)
stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item()
stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item()
latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1]
if stage_idx == self.num_stages - 1:
latents = latents[:latents.shape[0] // 2]
else:
# euler
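                # Euler sampling: at every timestep, duplicate the batch when CFG is enabled,
                # run the transformer, mix the unconditional / conditional predictions with the
                # stage-dependent guidance scale, and advance the latents with scheduler.step().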
for T in Timesteps:
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
if self.class_cond:
noise_pred = self.transformer(latent_model_input, timestep=timestep, class_labels=prompt_embeds, latent_size=size_tensor, pos_embed=rope_pos)
else:
encoder_hidden_states = prompt_embeds
encoder_attention_mask = prompt_attention_mask
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
latent_size=size_tensor,
pos_embed=rope_pos,
)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(model_output=noise_pred, sample=latents)
stage_end = time.time()
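        # Map the outputs from [-1, 1] to [0, 1] and return a float32 numpy array in NHWC layout.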
samples = (latents / 2 + 0.5).clamp(0, 1)
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
return samples
    @property
    def device(self):
        return next(self.transformer.parameters()).device

    @property
    def dtype(self):
        return next(self.transformer.parameters()).dtype

    def guidance_scale(self, step=None, stage_idx=None):
        # Text-to-image: constant CFG scale. Class-conditional: the effective scale is ramped
        # up over the stages (the mapping below assumes four stages, from scale 1, i.e. no
        # extra guidance, at stage 0 up to the full scale at stage 3).
        if not self.class_cond:
            return self._guidance_scale
        scale_dict = {0: 0, 1: 1/6, 2: 2/3, 3: 1}
        return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 0
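

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the objects passed to the constructor must be
# built elsewhere in the PixelFlow codebase; the call below is an assumption about
# how the pipeline is wired up, not a documented API example.
#
#     pipeline = PixelFlowPipeline(scheduler, transformer, text_encoder, tokenizer)
#     images = pipeline(
#         prompt=["a photo of a corgi wearing sunglasses"],
#         height=1024, width=1024,
#         num_inference_steps=30,      # int is expanded to one value per stage
#         guidance_scale=4.0,
#         device=pipeline.device,
#     )  # numpy array, shape [N, H, W, 3], values in [0, 1]
#
# The block below is a self-contained sanity check of the block-noise statistics used
# in sample_block_noise(). The gamma value is an assumed example; in the pipeline it
# comes from scheduler.gamma.
if __name__ == "__main__":
    _gamma = -1.0 / 3.0  # assumed example value (must be negative, see sample_block_noise)
    _eps = 1e-6
    _dist = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(4),
        torch.eye(4) * (1 - _gamma) + torch.ones(4, 4) * _gamma + _eps * torch.eye(4),
    )
    _blocks = _dist.sample((100000,))  # [num_samples, 4], one row per 2x2 block
    # Each component should have (close to) unit variance ...
    print("per-pixel variance:", _blocks.var(dim=0))
    # ... while the variance of the block sum is 4 + 12 * gamma (+ O(eps)), i.e. ~0 for
    # gamma = -1/3, so the added noise barely changes each 2x2 block's average.
    print("block-sum variance:", _blocks.sum(dim=1).var())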