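# Generation utilities: latent backward (cross-attention) guidance, SD / GLIGEN generation with
# scheduled sampling, VAE encode/decode, DDIM inversion, and partially-frozen generation.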
import torch
from tqdm import tqdm
from utils import guidance, schedule, boxdiff
import utils
from PIL import Image
import gc
import numpy as np
from .attention import GatedSelfAttentionDense
from .models import process_input_embeddings, torch_device
import warnings
# All keys: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
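# A hedged example: to guide on more blocks, one could extend the keys, e.g.
#   guidance_attn_keys=DEFAULT_GUIDANCE_ATTN_KEYS + [("up", 2, 0, 0), ("up", 2, 1, 0), ("up", 2, 2, 0)]
# and pass them through `semantic_guidance_kwargs` to the generation functions below.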
def latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, loss_scale = 30, loss_threshold = 0.2, max_iter = 5, max_index_step = 10, cross_attention_kwargs=None, ref_ca_saved_attns=None, guidance_attn_keys=None, verbose=False, clear_cache=False, **kwargs):
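    """
    Perform backward guidance on the latents at denoising step `index`.

    While the (de-scaled) loss exceeds `loss_threshold`, for up to `max_iter` iterations and only within
    the first `max_index_step` steps: run the UNet with the conditional embeddings, save the cross-attention
    maps at `guidance_attn_keys`, recompute the loss, and move the latents along the negative gradient.
    """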
iteration = 0
if index < max_index_step:
if isinstance(max_iter, list):
if len(max_iter) > index:
max_iter = max_iter[index]
else:
max_iter = max_iter[-1]
if verbose:
print(f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}")
while (loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step):
saved_attn = {}
full_cross_attention_kwargs = {
'save_attn_to_dict': saved_attn,
'save_keys': guidance_attn_keys,
}
if cross_attention_kwargs is not None:
full_cross_attention_kwargs.update(cross_attention_kwargs)
latents.requires_grad_(True)
latent_model_input = latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
unet(latent_model_input, t, encoder_hidden_states=cond_embeddings, return_cross_attention_probs=False, cross_attention_kwargs=full_cross_attention_kwargs)
# TODO: could return the attention maps for the required blocks only and not necessarily the final output
# update latents with guidance
loss = guidance.compute_ca_lossv3(saved_attn=saved_attn, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys, ref_ca_saved_attns=ref_ca_saved_attns, index=index, verbose=verbose, **kwargs) * loss_scale
if torch.isnan(loss):
print("**Loss is NaN**")
del full_cross_attention_kwargs, saved_attn
            # Calling gc.collect() here may release some memory
grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
latents.requires_grad_(False)
if hasattr(scheduler, 'sigmas'):
latents = latents - grad_cond * scheduler.sigmas[index] ** 2
elif hasattr(scheduler, 'alphas_cumprod'):
warnings.warn("Using guidance scaled with alphas_cumprod")
# Scaling with classifier guidance
alpha_prod_t = scheduler.alphas_cumprod[t]
# Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
# DDIM: https://arxiv.org/pdf/2010.02502.pdf
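                # With the epsilon-parameterization, classifier guidance replaces the noise prediction with
                #   eps_hat = eps - sqrt(1 - alpha_bar_t) * grad_x log p(y | x_t),
                # hence the sqrt(1 - alpha_prod_t) factor applied to the gradient below.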
scale = (1 - alpha_prod_t) ** (0.5)
latents = latents - scale * grad_cond
else:
# NOTE: no scaling is performed
warnings.warn("No scaling in guidance is performed")
latents = latents - grad_cond
iteration += 1
if clear_cache:
utils.free_memory()
if verbose:
print(f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}")
return latents, loss
@torch.no_grad()
def encode(model_dict, image, generator):
"""
    `image` should be a PIL Image or a uint8 numpy array with values in [0, 255]
"""
vae, dtype = model_dict.vae, model_dict.dtype
if isinstance(image, Image.Image):
w, h = image.size
        assert w % 8 == 0 and h % 8 == 0, f"w ({w}) and h ({h}) should be multiples of 8"
# w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
# image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :]
image = np.array(image)
if isinstance(image, np.ndarray):
assert image.dtype == np.uint8, f"Should have dtype uint8 (dtype: {image.dtype})"
image = image.astype(np.float32) / 255.0
image = image[None, ...]
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
assert isinstance(image, torch.Tensor), f"type of image: {type(image)}"
image = image.to(device=torch_device, dtype=dtype)
latents = vae.encode(image).latent_dist.sample(generator)
latents = vae.config.scaling_factor * latents
return latents
@torch.no_grad()
def decode(vae, latents):
    # Scale and decode the image latents with the VAE (0.18215 is the SD v1.x scaling factor; `encode` uses `vae.config.scaling_factor`)
scaled_latents = 1 / 0.18215 * latents
with torch.no_grad():
image = vae.decode(scaled_latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
return images
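# A minimal round-trip sketch (hypothetical file name and seed; the image size must be a multiple of 8):
#   img = Image.open("example.png").convert("RGB")
#   latents = encode(model_dict, img, generator=torch.Generator(device=torch_device).manual_seed(0))
#   images = decode(model_dict.vae, latents)  # uint8 array of shape (1, H, W, 3)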
def generate_semantic_guidance(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, object_positions, guidance_scale = 7.5, semantic_guidance_kwargs=None,
return_cross_attn=False, return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None, offload_guidance_cross_attn_to_cpu=False,
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True, return_box_vis=False, show_progress=True, save_all_latents=False,
dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2, use_boxdiff=False):
"""
object_positions: object indices in text tokens
return_cross_attn: should be deprecated. Use `return_saved_cross_attn` and the new format.
"""
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
    # Clone just in case there are in-place ops downstream
latents = latents.clone()
if save_all_latents:
# offload to cpu to save space
if offload_latents_to_cpu:
latents_all = [latents.cpu()]
else:
latents_all = [latents]
scheduler.set_timesteps(num_inference_steps)
if fast_after_steps is not None:
scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
if dynamic_num_inference_steps:
original_num_inference_steps = scheduler.num_inference_steps
cross_attention_probs_down = []
cross_attention_probs_mid = []
cross_attention_probs_up = []
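    # Initialize the loss to a large sentinel so that backward guidance runs at the first step.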
loss = torch.tensor(10000.)
    # TODO: we could also save only the necessary tokens to save memory.
    # offload_guidance_cross_attn_to_cpu does not save much since we only store the attention maps for one timestep at a time.
guidance_cross_attention_kwargs = {
'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
'enable_flash_attn': False
}
if return_saved_cross_attn:
saved_attns = []
main_cross_attention_kwargs = {
'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
'return_cond_ca_only': return_cond_ca_only,
'return_token_ca_only': return_token_ca_only,
'save_keys': saved_cross_attn_keys,
}
    # Repeating keys in `guidance_attn_keys` gives those keys more weight, so uniqueness is not enforced here.
# assert len(set(semantic_guidance_kwargs['guidance_attn_keys'])) == len(semantic_guidance_kwargs['guidance_attn_keys']), f"guidance_attn_keys not unique: {semantic_guidance_kwargs['guidance_attn_keys']}"
for index, t in enumerate(tqdm(scheduler.timesteps, disable=not show_progress)):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
if bboxes:
if use_boxdiff:
latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
else:
                # If `guidance_attn_keys` is None here, make sure it is included in `semantic_guidance_kwargs`; the default value has been removed.
latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
# predict the noise residual
with torch.no_grad():
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
main_cross_attention_kwargs['save_attn_to_dict'] = {}
unet_output = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, return_cross_attention_probs=return_cross_attn, cross_attention_kwargs=main_cross_attention_kwargs)
noise_pred = unet_output.sample
if return_cross_attn:
cross_attention_probs_down.append(unet_output.cross_attention_probs_down)
cross_attention_probs_mid.append(unet_output.cross_attention_probs_mid)
cross_attention_probs_up.append(unet_output.cross_attention_probs_up)
if return_saved_cross_attn:
saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
del main_cross_attention_kwargs['save_attn_to_dict']
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if dynamic_num_inference_steps:
schedule.dynamically_adjust_inference_steps(scheduler, index, t)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
if save_all_latents:
if offload_latents_to_cpu:
latents_all.append(latents.cpu())
else:
latents_all.append(latents)
if dynamic_num_inference_steps:
        # Restore num_inference_steps so that a subsequent non-dynamic generation is not affected
scheduler.num_inference_steps = original_num_inference_steps
images = decode(vae, latents)
ret = [latents, images]
if return_cross_attn:
ret.append((cross_attention_probs_down, cross_attention_probs_mid, cross_attention_probs_up))
if return_saved_cross_attn:
ret.append(saved_attns)
if return_box_vis:
pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
ret.append(pil_images)
if save_all_latents:
latents_all = torch.stack(latents_all, dim=0)
ret.append(latents_all)
return tuple(ret)
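# A usage sketch (hypothetical boxes, phrases, and token positions; the box format follows `utils.guidance`/`utils.draw_box`):
#   latents, images = generate_semantic_guidance(
#       model_dict, latents, input_embeddings, num_inference_steps=50,
#       bboxes=[[0.1, 0.4, 0.5, 0.9]], phrases=["a dog"], object_positions=[[2]],
#       semantic_guidance_kwargs=dict(guidance_attn_keys=DEFAULT_GUIDANCE_ATTN_KEYS),
#   )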
@torch.no_grad()
def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False, scheduler_key='dpm_scheduler'):
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
if not no_set_timesteps:
scheduler.set_timesteps(num_inference_steps)
for t in tqdm(scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
# predict the noise residual
with torch.no_grad():
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
images = decode(vae, latents)
ret = [latents, images]
return tuple(ret)
def gligen_enable_fuser(unet, enabled=True):
for module in unet.modules():
if isinstance(module, GatedSelfAttentionDense):
module.enabled = enabled
def prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt):
batch_size = len(bboxes)
assert len(phrases) == len(bboxes)
max_objs = 30
n_objs = min(max([len(bboxes_item) for bboxes_item in bboxes]), max_objs)
boxes = torch.zeros((batch_size, max_objs, 4), device=torch_device, dtype=dtype)
phrase_embeddings = torch.zeros((batch_size, max_objs, 768), device=torch_device, dtype=dtype)
    # `masks` decides which of the `max_objs` entries are enabled for each batch item
masks = torch.zeros((batch_size, max_objs), device=torch_device, dtype=dtype)
if n_objs > 0:
for idx, (bboxes_item, phrases_item) in enumerate(zip(bboxes, phrases)):
            # `bboxes_item` may be shorter than `n_objs`, since `n_objs` is the max length over batch items
bboxes_item = torch.tensor(bboxes_item[:n_objs])
boxes[idx, :bboxes_item.shape[0]] = bboxes_item
tokenizer_inputs = tokenizer(phrases_item[:n_objs], padding=True, return_tensors="pt").to(torch_device)
_phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
phrase_embeddings[idx, :_phrase_embeddings.shape[0]] = _phrase_embeddings
assert bboxes_item.shape[0] == _phrase_embeddings.shape[0], f"{bboxes_item.shape[0]} != {_phrase_embeddings.shape[0]}"
masks[idx, :bboxes_item.shape[0]] = 1
    # Classifier-free guidance: repeat the conditions for the unconditional and conditional halves; grounding is disabled (mask = 0) for the unconditional half
repeat_times = num_images_per_prompt * 2
condition_len = batch_size * repeat_times
boxes = boxes.repeat(repeat_times, 1, 1)
phrase_embeddings = phrase_embeddings.repeat(repeat_times, 1, 1)
masks = masks.repeat(repeat_times, 1)
masks[:condition_len // 2] = 0
# print("shapes:", boxes.shape, phrase_embeddings.shape, masks.shape)
return boxes, phrase_embeddings, masks, condition_len
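# For example, with batch_size=1, num_images_per_prompt=1 and two phrases, this returns boxes of shape
# (2, 30, 4), phrase_embeddings of shape (2, 30, 768), masks of shape (2, 30) (first half zeroed for the
# unconditional branch), and condition_len=2.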
@torch.no_grad()
def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
frozen_steps=20, frozen_mask=None,
return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
semantic_guidance=False, semantic_guidance_bboxes=None, semantic_guidance_object_positions=None, semantic_guidance_kwargs=None,
return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
"""
The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
batched:
Enabled: bboxes and phrases should be a list (batch dimension) of items (specify the bboxes/phrases of each image in the batch).
Disabled: bboxes and phrases should be a list of bboxes and phrases specifying the bboxes/phrases of one image (no batch dimension).
"""
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
text_embeddings, _, cond_embeddings = process_input_embeddings(input_embeddings)
if latents.dim() == 5:
# latents_all from the input side, different from the latents_all to be saved
latents_all_input = latents
latents = latents[0]
else:
latents_all_input = None
    # Clone just in case there are in-place ops downstream
latents = latents.clone()
if save_all_latents:
# offload to cpu to save space
if offload_latents_to_cpu:
latents_all = [latents.cpu()]
else:
latents_all = [latents]
scheduler.set_timesteps(num_inference_steps)
if fast_after_steps is not None:
scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
if dynamic_num_inference_steps:
original_num_inference_steps = scheduler.num_inference_steps
if frozen_mask is not None:
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
# 5.1 Prepare GLIGEN variables
if not batched_condition:
# Add batch dimension to bboxes and phrases
bboxes, phrases = [bboxes], [phrases]
boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
if semantic_guidance_bboxes and semantic_guidance:
loss = torch.tensor(10000.)
        # TODO: we could also save only the necessary tokens to save memory.
        # offload_guidance_cross_attn_to_cpu does not save much since we only store the attention maps for one timestep at a time.
guidance_cross_attention_kwargs = {
'offload_cross_attn_to_cpu': False,
'enable_flash_attn': False,
'gligen': {
'boxes': boxes[:condition_len // 2],
'positive_embeddings': phrase_embeddings[:condition_len // 2],
'masks': masks[:condition_len // 2],
'fuser_attn_kwargs': {
'enable_flash_attn': False,
}
}
}
if return_saved_cross_attn:
saved_attns = []
main_cross_attention_kwargs = {
'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
'return_cond_ca_only': return_cond_ca_only,
'return_token_ca_only': return_token_ca_only,
'save_keys': saved_cross_attn_keys,
'gligen': {
'boxes': boxes,
'positive_embeddings': phrase_embeddings,
'masks': masks
}
}
timesteps = scheduler.timesteps
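    # GLIGEN scheduled sampling: the gated self-attention fuser is enabled only for the first
    # `gligen_scheduled_sampling_beta` fraction of the timesteps, after which generation falls back to vanilla SD.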
num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
gligen_enable_fuser(unet, True)
for index, t in enumerate(tqdm(timesteps, disable=not show_progress)):
# Scheduled sampling
if index == num_grounding_steps:
gligen_enable_fuser(unet, False)
if semantic_guidance_bboxes and semantic_guidance:
with torch.enable_grad():
latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, semantic_guidance_bboxes, semantic_guidance_object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
main_cross_attention_kwargs['save_attn_to_dict'] = {}
# predict the noise residual
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
cross_attention_kwargs=main_cross_attention_kwargs).sample
if return_saved_cross_attn:
saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
del main_cross_attention_kwargs['save_attn_to_dict']
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if dynamic_num_inference_steps:
schedule.dynamically_adjust_inference_steps(scheduler, index, t)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
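        # For the first `frozen_steps` steps, keep the region where `frozen_mask` is 1 fixed to the
        # provided input latents and only denoise the rest.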
if frozen_mask is not None and index < frozen_steps:
latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
# Do not save the latents in the fast steps
if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
if offload_latents_to_cpu:
latents_all.append(latents.cpu())
else:
latents_all.append(latents)
if dynamic_num_inference_steps:
        # Restore num_inference_steps so that a subsequent non-dynamic generation is not affected
scheduler.num_inference_steps = original_num_inference_steps
# Turn off fuser for typical SD
gligen_enable_fuser(unet, False)
images = decode(vae, latents)
ret = [latents, images]
if return_saved_cross_attn:
ret.append(saved_attns)
if return_box_vis:
pil_images = [utils.draw_box(Image.fromarray(image), bboxes_item, phrases_item) for image, bboxes_item, phrases_item in zip(images, bboxes, phrases)]
ret.append(pil_images)
if save_all_latents:
latents_all = torch.stack(latents_all, dim=0)
ret.append(latents_all)
return tuple(ret)
def get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
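    # e.g. num_inference_steps=50, strength=0.6 -> init_timestep=30, t_start=20, so the last 20 of the
    # 50 inverse timesteps are dropped and 30 steps are returned.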
    # If t_start == 0, `timesteps[:-t_start]` would be empty, so return the full schedule
if t_start == 0:
return inverse_scheduler.timesteps, num_inference_steps
timesteps = inverse_scheduler.timesteps[:-t_start]
return timesteps, num_inference_steps - t_start
@torch.no_grad()
def invert(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5):
"""
latents: encoded from the image, should not have noise (t = 0)
returns inverted_latents for all time steps
"""
vae, tokenizer, text_encoder, unet, scheduler, inverse_scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.inverse_scheduler, model_dict.dtype
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
inverse_scheduler.set_timesteps(num_inference_steps, device=latents.device)
# We need to invert all steps because we need them to generate the background.
timesteps, num_inference_steps = get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength=1.0)
inverted_latents = [latents.cpu()]
for t in tqdm(timesteps[:-1]):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
if guidance_scale > 0.:
latent_model_input = torch.cat([latents] * 2)
latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)
# predict the noise residual
with torch.no_grad():
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
latent_model_input = latents
latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)
# predict the noise residual
with torch.no_grad():
noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample
# perform guidance
noise_pred = noise_pred_uncond
# compute the previous noisy sample x_t -> x_t-1
latents = inverse_scheduler.step(noise_pred, t, latents).prev_sample
inverted_latents.append(latents.cpu())
assert len(inverted_latents) == len(timesteps)
    # Timestep is the first dimension; index 0 is the most-noised latent and the last entry is the input latents
inverted_latents = torch.stack(list(reversed(inverted_latents)), dim=0)
return inverted_latents
def generate_partial_frozen(model_dict, latents_all, frozen_mask, input_embeddings, num_inference_steps, frozen_steps, guidance_scale = 7.5, bboxes=None, phrases=None, object_positions=None, semantic_guidance_kwargs=None, offload_guidance_cross_attn_to_cpu=False, use_boxdiff=False):
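    """
    Denoise starting from `latents_all[0]`, keeping the region where `frozen_mask` is 1 fixed to the
    precomputed `latents_all` for the first `frozen_steps` steps. If `bboxes` is given, semantic
    (or BoxDiff) backward guidance is applied at each step.
    """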
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
scheduler.set_timesteps(num_inference_steps)
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
latents = latents_all[0]
if bboxes:
# With semantic guidance
loss = torch.tensor(10000.)
        # offload_guidance_cross_attn_to_cpu does not save much since we only store the attention maps for one timestep at a time.
guidance_cross_attention_kwargs = {
'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
            # Flash attention is disabled: we get an invalid-argument error on backward, probably due to insufficient shared memory
'enable_flash_attn': False
}
for index, t in enumerate(tqdm(scheduler.timesteps)):
if bboxes:
# With semantic guidance, `guidance_attn_keys` should be in `semantic_guidance_kwargs`
if use_boxdiff:
latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
else:
latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
with torch.no_grad():
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
# predict the noise residual
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
if index < frozen_steps:
latents = latents_all[index+1] * frozen_mask + latents * (1. - frozen_mask)
    # Decode the final latents with the VAE (same as `decode` above)
    images = decode(vae, latents)
ret = [latents, images]
return tuple(ret)