import collections
import os
from functools import partial

import torch
import torch.nn as nn
from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer, logging

from models.unet_2d_condition import UNet2DConditionModel
from utils.attention_utils import CrossAttentionLayers, SelfAttentionLayers

# suppress partial model loading warning
logging.set_verbosity_error()


class RegionDiffusion(nn.Module):
    """Stable Diffusion v1.5 wrapper for mask-composited (region-based) sampling.

    ``self.masks`` holds one mask per prompt. In ``produce_latents`` the last mask
    is paired with the last (base) prompt embedding and every preceding mask with
    the corresponding region prompt embedding. The class also provides
    forward-hook helpers to collect and inject UNet attention maps / features.
    """

    def __init__(self, device):
        """Load the Stable Diffusion v1.5 VAE, tokenizer, text encoder and UNet.

        Args:
            device: torch device (or device string) the models are moved to.
        """
        super().__init__()
        self.device = device
        self.num_train_timesteps = 1000
        self.clip_gradient = False

        print('[INFO] loading stable diffusion...')
        model_id = 'runwayml/stable-diffusion-v1-5'

        self.vae = AutoencoderKL.from_pretrained(
            model_id, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            model_id, subfolder='tokenizer')
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_id, subfolder='text_encoder').to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet").to(self.device)

        self.scheduler = PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
            num_train_timesteps=self.num_train_timesteps, skip_prk_steps=True,
            steps_offset=1)
        self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        # One mask per prompt; presumably filled in by the caller before sampling
        # (produce_latents asserts its length matches the prompt count).
        self.masks = []
        self.attention_maps = None
        self.selfattn_maps = None
        self.crossattn_maps = None
        self.n_maps = None
        self.color_loss = torch.nn.functional.mse_loss

        # Hook handles and captured features. Initialised here so that every
        # remove_* method is safe to call before its register_* counterpart.
        self.forward_hooks = []
        self.selfattn_forward_hooks = []
        self.forward_replacement_hooks = []
        self.self_attention_maps_cur = {}

        print('[INFO] loaded stable diffusion!')

    # ------------------------------------------------------------------
    # Text encoding
    # ------------------------------------------------------------------
    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts and negative prompts with the CLIP text encoder.

        Args:
            prompt: list of prompt strings.
            negative_prompt: list of negative prompt strings.

        Returns:
            ``torch.cat([uncond_embeddings, text_embeddings])``: the negative
            prompt embeddings come first along dim 0.
        """
        def encode(texts):
            # truncation=True for both prompt kinds: without it an over-long
            # negative prompt is not cut to model_max_length and its embeddings
            # no longer match the positive ones in the concatenation below.
            tokens = self.tokenizer(
                texts, padding='max_length', max_length=self.tokenizer.model_max_length,
                truncation=True, return_tensors='pt')
            with torch.no_grad():
                return self.text_encoder(tokens.input_ids.to(self.device))[0]

        text_embeddings = encode(prompt)
        uncond_embeddings = encode(negative_prompt)
        return torch.cat([uncond_embeddings, text_embeddings])

    def get_text_embeds_list(self, prompts):
        """Encode each prompt on its own.

        Args:
            prompts: iterable of prompt strings.

        Returns:
            A list with one text-encoder output tensor per prompt.
        """
        text_embeddings = []
        for prompt in prompts:
            text_input = self.tokenizer(
                [prompt], padding='max_length', max_length=self.tokenizer.model_max_length,
                truncation=True, return_tensors='pt')
            with torch.no_grad():
                text_embeddings.append(self.text_encoder(
                    text_input.input_ids.to(self.device))[0])
        return text_embeddings

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------
    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
                        latents=None, use_guidance=False, text_format_dict=None, inject_selfattn=0, bg_aug_end=1000):
        """Run mask-composited classifier-free-guidance sampling.

        Args:
            text_embeddings: output of ``get_text_embeds``. Row 0 is the negative
                prompt; rows ``1 .. len(self.masks) - 1`` are region prompts paired
                with ``self.masks[:-1]``; the last row is the base prompt paired
                with ``self.masks[-1]``.
            height, width: image size in pixels; latents are ``height // 8`` x ``width // 8``.
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance weight.
            latents: optional initial latents; drawn from ``torch.randn`` when None.
            use_guidance: enable the color-guidance gradient step for timesteps
                below ``text_format_dict['guidance_start_step']``.
            text_format_dict: passed to the UNet for the base prompt and read for
                the color-guidance keys. None means an empty dict.
            inject_selfattn: when > 0, a reference trajectory of the base prompt is
                run alongside; for ``t > (1 - inject_selfattn) * num_train_timesteps``
                its captured features are injected into the region-prompt passes.
            bg_aug_end: for ``t > bg_aug_end`` each region pass sees a random
                solid-color background outside its mask.

        Returns:
            The final latents.
        """
        if text_format_dict is None:
            text_format_dict = {}

        if latents is None:
            latents = torch.randn(
                (1, self.unet.in_channels, height // 8, width // 8), device=self.device)

        use_reference = inject_selfattn > 0
        if use_reference:
            latents_reference = latents.clone().detach()
        self.scheduler.set_timesteps(num_inference_steps)
        n_styles = text_embeddings.shape[0] - 1
        assert n_styles == len(self.masks)

        uncond_embeds = text_embeddings[:1]
        base_embeds = text_embeddings[-1:]
        with torch.autocast('cuda'):
            for t in self.scheduler.timesteps:
                feat_inject_step = t > (1 - inject_selfattn) * self.num_train_timesteps

                with torch.no_grad():
                    # Base prompt (paired with the last mask).
                    noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=uncond_embeds,
                                                      text_format_dict={})['sample']
                    noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=base_embeds,
                                                    text_format_dict=text_format_dict)['sample']
                    if use_reference:
                        noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=uncond_embeds,
                                                            text_format_dict={})['sample']
                        # Capture features of the reference base-prompt pass for injection.
                        self.register_selfattn_hooks(feat_inject_step)
                        noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=base_embeds,
                                                          text_format_dict={})['sample']
                        self.remove_selfattn_hooks()
                    noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
                    noise_pred_text = noise_pred_text_cur * self.masks[-1]

                    # Region prompts: each prediction is composited with its own mask.
                    for style_i, mask in enumerate(self.masks[:-1]):
                        if t > bg_aug_end:
                            masked_latent = self._augment_background(latents, mask, t, height, width)
                            noise_pred_uncond_cur = self.unet(masked_latent, t, encoder_hidden_states=uncond_embeds,
                                                              text_format_dict={})['sample']
                        else:
                            # noise_pred_uncond_cur still holds the unconditional
                            # prediction on the unmodified latents from above.
                            masked_latent = latents
                        self.register_replacement_hooks(feat_inject_step)
                        noise_pred_text_cur = self.unet(masked_latent, t,
                                                        encoder_hidden_states=text_embeddings[style_i + 1:style_i + 2],
                                                        text_format_dict={})['sample']
                        self.remove_replacement_hooks()
                        noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur * mask
                        noise_pred_text = noise_pred_text + noise_pred_text_cur * mask

                # Classifier-free guidance on the composited predictions.
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if use_reference:
                    noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
                        (noise_pred_text_refer - noise_pred_uncond_refer)
                    # Both trajectories go through a single batched scheduler.step
                    # call, so the scheduler is stepped once per timestep.
                    stepped = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
                                                  torch.cat([latents, latents_reference]))['prev_sample']
                    latents, latents_reference = torch.chunk(stepped, 2, dim=0)
                else:
                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

                if use_guidance and t < text_format_dict['guidance_start_step']:
                    latents = self._apply_color_guidance(latents, noise_pred, t, text_format_dict)

        return latents

    def _augment_background(self, latents, mask, t, height, width):
        """Replace latents outside ``mask`` with a noised random solid-color background.

        A random RGB color image is VAE-encoded and noised to timestep ``t`` with
        ``self.scheduler.add_noise``. Where ``mask > 0.001`` the original latents
        are kept; where ``mask < 0.001`` the noisy background is used.
        """
        # Drawn on CPU and then moved, matching the RNG stream of a CPU draw.
        rand_rgb = torch.rand([1, 3, 1, 1]).to(self.device)
        black_background = torch.ones([1, 3, height, width], device=self.device) * rand_rgb
        black_latent = self.encode_imgs(black_background)
        noise = torch.randn_like(black_latent)
        black_latent_noisy = self.scheduler.add_noise(black_latent, noise, t)
        # NOTE(review): mask values exactly equal to 0.001 select neither term.
        return (mask > 0.001) * latents + (mask < 0.001) * black_latent_noisy

    def _apply_color_guidance(self, latents, noise_pred, t, text_format_dict):
        """Take one gradient step pulling region average colors toward target RGB values.

        Decodes the predicted x0. For each pair from ``text_format_dict['color_obj_atten']``
        and ``text_format_dict['target_RGB']`` it computes the map-weighted average
        color and an MSE loss (x100). The summed loss gradient w.r.t. the latents,
        scaled by ``color_guidance_weight`` and multiplied by ``self.masks[0]``,
        is subtracted from the latents.

        Returns:
            Updated, detached latents (or ``latents`` unchanged when there are no
            color targets).
        """
        pairs = list(zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']))
        if not pairs:
            # Without targets the loss would stay a Python float and backward() would fail.
            return latents
        with torch.enable_grad():
            if not latents.requires_grad:
                latents.requires_grad = True
            latents_0 = self.predict_x0(latents, noise_pred, t)
            latents_inp = 1 / 0.18215 * latents_0
            imgs = self.vae.decode(latents_inp).sample
            imgs = (imgs / 2 + 0.5).clamp(0, 1)
            loss_total = 0.
            for attn_map, rgb_val in pairs:
                avg_rgb = (imgs * attn_map[:, 0]).sum(2).sum(2) / attn_map[:, 0].sum()
                loss = self.color_loss(avg_rgb, rgb_val[:, :, 0, 0]) * 100
                loss_total += loss
            loss_total.backward()
        return (latents - latents.grad * text_format_dict['color_guidance_weight'] * self.masks[0]).detach().clone()

    def predict_x0(self, x_t, eps_t, t):
        """Estimate x0 as ``(x_t - sqrt(1 - a_t) * eps_t) / sqrt(a_t)``.

        ``a_t`` is ``self.scheduler.alphas_cumprod[t]``.
        """
        alpha_t = self.scheduler.alphas_cumprod[t]
        return (x_t - eps_t * torch.sqrt(1 - alpha_t)) / torch.sqrt(alpha_t)

    def produce_attn_maps(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
                          guidance_scale=7.5, latents=None):
        """Run plain batched classifier-free-guidance sampling and return uint8 images.

        Removes any replacement hooks first. Attention maps are presumably
        collected by hooks the caller registered beforehand (e.g. via
        ``register_tokenmap_hooks``); this method does not register any itself.

        Returns:
            numpy uint8 array of shape (B, H, W, 3).
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeddings = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]
        if latents is None:
            latents = torch.randn(
                (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8),
                device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)
        self.remove_replacement_hooks()

        with torch.autocast('cuda'):
            for t in self.scheduler.timesteps:
                # Unconditional and conditional halves in a single UNet pass.
                latent_model_input = torch.cat([latents] * 2)
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

        return self._latents_to_uint8(latents)

    def _latents_to_uint8(self, latents):
        """Decode latents and convert them to a (B, H, W, 3) uint8 numpy array."""
        imgs = self.decode_latents(latents)
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        return (imgs * 255).round().astype('uint8')

    def decode_latents(self, latents):
        """Decode latents with the VAE (no grad) and map the output to [0, 1]."""
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            imgs = self.vae.decode(latents).sample
        return (imgs / 2 + 0.5).clamp(0, 1)

    def encode_imgs(self, imgs):
        """Encode images in [0, 1] (shape [B, 3, H, W]) to scaled VAE latents.

        The latents are sampled from the VAE posterior, so the result is random.
        """
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        return posterior.sample() * 0.18215

    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
                      guidance_scale=7.5, latents=None, text_format_dict=None, use_guidance=False,
                      inject_selfattn=0, bg_aug_end=1000):
        """Generate images for ``prompts`` with ``produce_latents``.

        The arguments are forwarded to ``produce_latents``; a None
        ``text_format_dict`` is treated as an empty dict.

        Returns:
            numpy uint8 array of shape (B, H, W, 3).
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]
        if text_format_dict is None:
            text_format_dict = {}

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
                                       num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                                       use_guidance=use_guidance, text_format_dict=text_format_dict,
                                       inject_selfattn=inject_selfattn, bg_aug_end=bg_aug_end)
        return self._latents_to_uint8(latents)

    # ------------------------------------------------------------------
    # Hooks
    # ------------------------------------------------------------------
    def reset_attention_maps(self):
        r"""Reset every entry of the token-map dictionaries to an empty list.

        Operates on ``self.selfattn_maps`` / ``self.crossattn_maps`` as set up by
        ``register_tokenmap_hooks``; both must be non-None.

        NOTE(review): the token-map hooks accumulate with ``+=``. Once an entry is
        ``[]``, ``[] += tensor`` extends the list with the tensor's dim-0 slices
        instead of summing tensors. Confirm this is intended.
        """
        for key in self.selfattn_maps:
            self.selfattn_maps[key] = []
        for key in self.crossattn_maps:
            self.crossattn_maps[key] = []

    def register_evaluation_hooks(self):
        r"""Register forward hooks on every UNet module whose leaf name contains ``attn``.

        For ``attn2`` (cross-attention) layers, ``out[1]`` is appended on CPU to
        ``self.attention_maps[layer_name]`` at every forward pass. Other attention
        layers are only shape-checked.
        """
        self.forward_hooks = []

        def save_activations(activations, name, module, inp, out):
            # Assumes the project's attention modules return (output, attention_probs).
            if 'attn2' in name:
                assert out[1].shape[-1] == 77  # 77 = CLIP token sequence length
                activations[name].append(out[1].detach().cpu())
            else:
                assert out[1].shape[-1] != 77

        attention_dict = collections.defaultdict(list)
        for name, module in self.unet.named_modules():
            if 'attn' in name.split('.')[-1]:
                self.forward_hooks.append(module.register_forward_hook(
                    partial(save_activations, attention_dict, name)))
        self.attention_maps = attention_dict

    def register_selfattn_hooks(self, feat_inject_step=False):
        r"""Register hooks that capture features for later injection.

        When ``feat_inject_step`` is true, ``self.self_attention_maps_cur`` receives
        ``out[1][1]`` of every non-``attn2`` attention layer and ``out[1]`` of
        ``up_blocks.1.resnets.1``. When false, no hooks are registered and the
        dictionary is replaced with an empty one.
        """
        self.selfattn_forward_hooks = []

        def save_activations(activations, name, module, inp, out):
            # out[1] is presumably the attention-probability output of the project's attention module.
            if 'attn2' in name:
                assert out[1][1].shape[-1] == 77
                # Cross-attention maps are not captured (no cross-attention injection).
            else:
                assert out[1][1].shape[-1] != 77
                activations[name] = out[1][1].detach()

        def save_resnet_activations(activations, name, module, inp, out):
            # out[1] is presumably the residual hidden feature of the project's resnet block.
            assert out[1].shape[-1] == 16
            activations[name] = out[1].detach()

        attention_dict = collections.defaultdict(list)
        if feat_inject_step:
            for name, module in self.unet.named_modules():
                if 'attn' in name.split('.')[-1]:
                    self.selfattn_forward_hooks.append(module.register_forward_hook(
                        partial(save_activations, attention_dict, name)))
                if name == 'up_blocks.1.resnets.1':
                    self.selfattn_forward_hooks.append(module.register_forward_hook(
                        partial(save_resnet_activations, attention_dict, name)))
        self.self_attention_maps_cur = attention_dict

    def register_replacement_hooks(self, feat_inject_step=False):
        r"""Register forward pre-hooks that inject the features captured by ``register_selfattn_hooks``.

        When ``feat_inject_step`` is true:

        * ``attn1`` layers get ``(args[0], self.self_attention_maps_cur[name])``
          as their positional arguments.
        * ``up_blocks.1.resnets.1`` gets ``(args[0], args[1], captured_feature)``.
        * Other attention layers (``attn2``) are left untouched.

        When false, nothing is registered.
        """
        self.forward_replacement_hooks = []

        def replace_activations(name, module, args):
            if 'attn1' in name:
                return (args[0], self.self_attention_maps_cur[name])
            # Returning None keeps the original arguments (no cross-attention injection).
            return None

        def replace_resnet_activations(name, module, args):
            return (args[0], args[1], self.self_attention_maps_cur[name])

        if not feat_inject_step:
            return
        for name, module in self.unet.named_modules():
            if 'attn' in name.split('.')[-1]:
                self.forward_replacement_hooks.append(module.register_forward_pre_hook(
                    partial(replace_activations, name)))
            if name == 'up_blocks.1.resnets.1':
                self.forward_replacement_hooks.append(module.register_forward_pre_hook(
                    partial(replace_resnet_activations, name)))

    def register_tokenmap_hooks(self):
        r"""Register hooks that accumulate attention maps for token-map extraction.

        Each attention layer counts its calls in ``self.n_maps``; the first 10 calls
        per layer are skipped. After that, ``out[1][0].detach().cpu()[1:2]`` is
        summed into ``self.crossattn_maps`` for layers in ``CrossAttentionLayers``
        and into ``self.selfattn_maps`` for layers in ``SelfAttentionLayers``.
        """
        self.forward_hooks = []

        def accumulate(maps, name, value):
            if name in maps:
                maps[name] += value
            else:
                maps[name] = value

        def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
            # out[1][0] is presumably the attention-probability tensor of the project's
            # attention module; [1:2] selects the second batch entry.
            n_maps[name] += 1
            attn = out[1][0]
            if 'attn2' in name:
                assert attn.shape[-1] == 77
                if name in CrossAttentionLayers and n_maps[name] > 10:
                    accumulate(crossattn_maps, name, attn.detach().cpu()[1:2])
            else:
                assert attn.shape[-1] != 77
                if name in SelfAttentionLayers and n_maps[name] > 10:
                    # Checks selfattn_maps itself, so self-attention maps are summed
                    # like the cross-attention ones rather than overwritten each call.
                    accumulate(selfattn_maps, name, attn.detach().cpu()[1:2])

        selfattn_maps = collections.defaultdict(list)
        crossattn_maps = collections.defaultdict(list)
        n_maps = collections.defaultdict(int)

        for name, module in self.unet.named_modules():
            if 'attn' in name.split('.')[-1]:
                self.forward_hooks.append(module.register_forward_hook(
                    partial(save_activations, selfattn_maps, crossattn_maps, n_maps, name)))

        self.selfattn_maps = selfattn_maps
        self.crossattn_maps = crossattn_maps
        self.n_maps = n_maps

    def remove_tokenmap_hooks(self):
        """Remove the hooks in ``self.forward_hooks`` and drop the collected token maps."""
        for hook in self.forward_hooks:
            hook.remove()
        self.forward_hooks = []
        self.selfattn_maps = None
        self.crossattn_maps = None
        self.n_maps = None

    def remove_evaluation_hooks(self):
        """Remove the hooks in ``self.forward_hooks`` and drop ``self.attention_maps``."""
        for hook in self.forward_hooks:
            hook.remove()
        self.forward_hooks = []
        self.attention_maps = None

    def remove_replacement_hooks(self):
        """Remove the feature-injection pre-hooks."""
        for hook in self.forward_replacement_hooks:
            hook.remove()
        self.forward_replacement_hooks = []

    def remove_selfattn_hooks(self):
        """Remove the feature-capture hooks (the captured features are kept)."""
        for hook in self.selfattn_forward_hooks:
            hook.remove()
        self.selfattn_forward_hooks = []