import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict, Any
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import logging, is_torch_version, deprecate
from diffusers.utils.torch_utils import fourier_filter
# UNet is a diffusers PeftAdapterMixin instance.
from diffusers.loaders.peft import PeftAdapterMixin
from peft import LoraConfig, get_peft_model
import peft.tuners.lora as peft_lora
from peft.tuners.lora.dora import DoraLinearLayer
from einops import rearrange
import math, re
import numpy as np
from peft.tuners.tuners_utils import BaseTunerLayer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def dummy_func(*args, **kwargs):
    pass

# Revised from RevGrad, by removing the grad negation.
class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        ctx.save_for_backward(alpha_, debug)
        output = input_
        if debug:
            print(f"input: {input_.abs().mean().item()}")
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # saved_tensors returns a tuple of tensors.
        alpha_, debug = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_output2 = grad_output * alpha_
            if debug:
                print(f"grad_output2: {grad_output2.abs().mean().item()}")
        else:
            grad_output2 = None
        return grad_output2, None, None

class GradientScaler(nn.Module):
    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        """
        A gradient scaling layer.
        This layer has no parameters, and simply scales the gradient in the backward pass.
        """
        super().__init__(*args, **kwargs)
        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        _debug = self._debug if hasattr(self, '_debug') else False
        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)

def gen_gradient_scaler(alpha, debug=False):
    if alpha == 1:
        return nn.Identity()
    if alpha > 0:
        return GradientScaler(alpha, debug=debug)
    else:
        assert alpha == 0
        # Don't use lambda function here, otherwise the object can't be pickled.
        return torch.detach

def split_indices_by_instance(indices, as_dict=False):
    indices_B, indices_N = indices
    unique_indices_B = torch.unique(indices_B)
    if not as_dict:
        indices_by_instance = [ (indices_B[indices_B == uib], indices_N[indices_B == uib])
                                for uib in unique_indices_B ]
    else:
        indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
    return indices_by_instance

# If do_sum, returned emb_attns is 3D. Otherwise 4D.
# indices are applied on the first 2 dims of attn_mat.
def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
    indices_by_instance = split_indices_by_instance(indices)

    # emb_attns[0]: [1, 9, 8, 64]
    # 8: 8 attention heads. Last dim 64: number of image tokens.
    emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
    if all_token_weights is not None:
        # all_token_weights: [4, 77].
        # token_weights_by_instance[0]: [1, 9, 1, 1].
        token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1)
                          for inst_indices in indices_by_instance ]
    else:
        token_weights = [ 1 ] * len(indices_by_instance)

    # Apply token weights.
    emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]

    # sum among K_subj_i subj embeddings -> [1, 8, 64]
    if do_sum:
        emb_attns = [ emb_attns[i].sum(dim=1)  for i in range(len(indices_by_instance)) ]
    elif do_mean:
        emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]

    emb_attns = torch.cat(emb_attns, dim=0)
    return emb_attns

# Slow implementation equivalent to F.scaled_dot_product_attention.
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                 shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
                                 is_causal=False, scale=None,
                                 enable_gqa=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, L, S = query.size(0), query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # 1: head (to be broadcasted). L: query length. S: key length.
    attn_bias = torch.zeros(B, 1, L, S, device=query.device, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(B, 1, L, S, device=query.device, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    if enable_gqa:
        key   = key.repeat_interleave(query.size(-3)   // key.size(-3),   -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if shrink_cross_attn:
        cross_attn_scale = cross_attn_shrink_factor
    else:
        cross_attn_scale = 1

    # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
    attn_weight += attn_bias
    attn_score  = attn_weight
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
    # But this is intended, as we want to scale down the impact of the subject embeddings
    # in the computed attention output tensors.
    attn_weight = attn_weight * cross_attn_scale
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    output = attn_weight @ value
    return output, attn_score, attn_weight

# All layers share the same attention processor instance.
class AttnProcessor_LoRA_Capture(nn.Module):
    r"""
    Revised from AttnProcessor2_0
    """

    # lora_proj_layers is a dict of lora_layer_name -> lora_proj_layer.
    def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
                 lora_uses_dora=True, lora_proj_layers=None,
                 lora_rank: int = 192, lora_alpha: float = 16,
                 cross_attn_shrink_factor: float = 0.5,
                 q_lora_updates_query=False, attn_proc_idx=-1):
        super().__init__()
        self.global_enable_lora = enable_lora
        self.attn_proc_idx = attn_proc_idx
        # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
        # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
        self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
        self.lora_rank  = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_scale = self.lora_alpha / self.lora_rank
        self.cross_attn_shrink_factor = cross_attn_shrink_factor
        self.q_lora_updates_query = q_lora_updates_query

        self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
        if self.global_enable_lora:
            for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
                if lora_layer_name == 'q':
                    self.to_q_lora   = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'k':
                    self.to_k_lora   = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'v':
                    self.to_v_lora   = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'out':
                    self.to_out_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)

    # LoRA layers can be enabled/disabled dynamically.
    def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
        self.capture_ca_activations = capture_ca_activations
        self.shrink_cross_attn      = shrink_cross_attn
        self.cached_activations     = {}
        # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
        self.enable_lora = enable_lora and self.global_enable_lora

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        img_mask: Optional[torch.Tensor] = None,
        subj_indices: Optional[Tuple[torch.IntTensor, torch.IntTensor]] = None,
        debug: bool = False,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        # hidden_states: [1, 4096, 320]
        residual = hidden_states
        # attn.spatial_norm is None.
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            # Collapse the spatial dimensions to a single token dimension.
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        # NOTE: there's an inconsistency between the q lora and the k, v loras.
        # k, v loras are directly applied to key and value (currently k, v loras are never enabled),
        # while q lora is applied to query2, and we keep the query unchanged.
        if self.enable_lora and self.to_q_lora is not None:
            # query2 will be used in ldm/util.py:calc_elastic_matching_loss() to get more accurate
            # cross attention scores between the latent images of the sc and mc instances.
            query2 = self.to_q_lora(hidden_states)
            # If not q_lora_updates_query, only query2 will be impacted by the LoRA layer.
            # The query, and thus the attention score and attn_out, will be the same
            # as the original ones.
            if self.q_lora_updates_query:
                query = query2
        else:
            query2 = query

        scale = 1 / math.sqrt(query.size(-1))

        is_cross_attn = (encoder_hidden_states is not None)
        if (not is_cross_attn) and (img_mask is not None):
            # NOTE: we assume the image is square; this will fail on non-square images.
            # hidden_states: [BS, 4096, 320]. img_mask: [BS, 1, 64, 64]
            # Scale the mask to the same size as hidden_states.
            mask_size = int(math.sqrt(hidden_states.shape[-2]))
            img_mask = F.interpolate(img_mask, size=(mask_size, mask_size), mode='nearest')
            if (img_mask.sum(dim=(2, 3)) == 0).any():
                img_mask = None
            else:
                # img_mask: [2, 1, 64, 64] -> [2, 4096]
                img_mask = rearrange(img_mask, 'b ... -> b (...)').contiguous()
                # max_neg_value = -torch.finfo(hidden_states.dtype).max
                # img_mask: [2, 4096] -> [2, 1, 1, 4096]
                img_mask = rearrange(img_mask.bool(), 'b j -> b () () j')
                # attn_score: [16, 4096, 4096]. img_mask will be broadcasted to [16, 4096, 4096].
                # So some rows in dim 1 (e.g. [0, :, 4095]) of attn_score will be masked out
                # (all elements in [0, :, 4095] are -inf).
                # But not all elements in [0, 4095, :] are -inf. Since the softmax is done along dim 2, this is fine.
                # attn_score.masked_fill_(~img_mask, max_neg_value)
                # NOTE: If there's an attention mask, it will be replaced by img_mask.
                attention_mask = img_mask

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.enable_lora and self.to_k_lora is not None:
            key = self.to_k_lora(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)

        if self.enable_lora and self.to_v_lora is not None:
            value = self.to_v_lora(encoder_hidden_states)
        else:
            value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query  = attn.norm_q(query)
            query2 = attn.norm_q(query2)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query  = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        query2 = query2.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key    = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value  = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if debug and self.attn_proc_idx >= 0:
            breakpoint()

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
            hidden_states, attn_score, attn_prob = \
                scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0,
                                             shrink_cross_attn=self.shrink_cross_attn,
                                             cross_attn_shrink_factor=self.cross_attn_shrink_factor)
        else:
            # Use the faster implementation of scaled_dot_product_attention
            # when not capturing the activations or suppressing the subject attention.
            hidden_states = \
                F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
                                               dropout_p=0.0, is_causal=False)
            attn_prob = attn_score = None

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        if self.enable_lora and self.to_out_lora is not None:
            hidden_states = self.to_out_lora(hidden_states)
        else:
            hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if is_cross_attn and self.capture_ca_activations:
            # cached q will be used in ddpm.py:calc_comp_fg_bg_preserve_loss(), in which two qs will multiply each other.
            # So sqrt(scale) will scale the product of two qs by scale.
            # ANCHOR[id=attention_caching]
            # query: [2, 8, 4096, 40] -> [2, 320, 4096]
            self.cached_activations['q'] = \
                rearrange(query,  'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['q2'] = \
                rearrange(query2, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['k'] = \
                rearrange(key,    'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['v'] = \
                rearrange(value,  'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            # attn_prob, attn_score: [2, 8, 4096, 77]
            self.cached_activations['attn']      = attn_prob
            self.cached_activations['attnscore'] = attn_score
            # attn_out: [b, n, h * d] -> [b, h * d, n]
            # [2, 4096, 320] -> [2, 320, 4096].
            self.cached_activations['attn_out'] = hidden_states.permute(0, 2, 1).contiguous()

        return hidden_states

def CrossAttnUpBlock2D_forward_capture(
    self,
    hidden_states: torch.Tensor,
    res_hidden_states_tuple: Tuple[torch.Tensor, ...],
    temb: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    upsample_size: Optional[int] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    self.cached_outfeats = {}
    res_hidden_states_gradscale = getattr(self, "res_hidden_states_gradscale", 1)
    capture_outfeats            = getattr(self, "capture_outfeats", False)
    layer_idx = 0
    res_grad_scaler = gen_gradient_scaler(res_hidden_states_gradscale)

    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        # Scale down the magnitudes of gradients to res_hidden_states
        # by res_hidden_states_gradscale=0.2, to match the scale of the cross-attn layer outputs.
        res_hidden_states = res_grad_scaler(res_hidden_states)
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
        else:
            # resnet: ResnetBlock2D instance.
            #LINK diffusers.models.resnet.ResnetBlock2D
            # up_blocks.3.resnets.2.conv_shortcut is a module within ResnetBlock2D,
            # it's not transforming the UNet shortcut features.
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if capture_outfeats:
            self.cached_outfeats[layer_idx] = hidden_states
            layer_idx += 1

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states

# Adapted from ConsistentIDPipeline:set_ip_adapter().
# attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
                           lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
                           q_lora_updates_query=False):
    attn_procs = {}
    attn_capture_procs = {}
    unet_modules = dict(unet.named_modules())
    attn_opt_modules = {}
    attn_proc_idx = 0

    for name, attn_proc in unet.attn_processors.items():
        # Only capture the activations of the last 3 CA layers.
        if not name.startswith("up_blocks.3"):
            # Not the last 3 CA layers. Don't enable LoRA or capture activations.
            # Then the layer falls back to the original attention mechanism.
            # We still use AttnProcessor_LoRA_Capture, as it can handle img_mask.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
            continue
        # cross_attention_dim: 768.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None:
            # Self attention. Don't enable LoRA or capture activations.
            # We replace the default attn_proc with AttnProcessor_LoRA_Capture,
            # so that it can incorporate img_mask into self-attention.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
            continue

        # block_id = 3
        # hidden_size: 320
        # hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor' ->
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q'
        lora_layer_dict = {}
        lora_layer_dict['q'] = unet_modules[name[:-9] + "to_q"]
        lora_layer_dict['k'] = unet_modules[name[:-9] + "to_k"]
        lora_layer_dict['v'] = unet_modules[name[:-9] + "to_v"]
        # to_out is a ModuleList(Linear, Dropout).
        lora_layer_dict['out'] = unet_modules[name[:-9] + "to_out"][0]

        lora_proj_layers = {}
        # Only apply LoRA to the specified layers.
        for lora_layer_name in attn_lora_layer_names:
            lora_proj_layers[lora_layer_name] = lora_layer_dict[lora_layer_name]

        attn_capture_proc = AttnProcessor_LoRA_Capture(
            capture_ca_activations=True, enable_lora=use_attn_lora,
            lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
            # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
            lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
            cross_attn_shrink_factor=cross_attn_shrink_factor,
            q_lora_updates_query=q_lora_updates_query,
            attn_proc_idx=attn_proc_idx)
        attn_proc_idx += 1
        # attn_procs has to use the original names.
        attn_procs[name] = attn_capture_proc
        # ModuleDict doesn't allow "." in the key.
        name = name.replace(".", "_")
        attn_capture_procs[name] = attn_capture_proc

        if use_attn_lora:
            for subname, module in attn_capture_proc.named_modules():
                if isinstance(module, peft_lora.LoraLayer):
                    # ModuleDict doesn't allow "." in the key.
                    lora_path = name + "_" + subname.replace(".", "_")
                    attn_opt_modules[lora_path + "_lora_A"] = module.lora_A
                    attn_opt_modules[lora_path + "_lora_B"] = module.lora_B
                    # lora_uses_dora is always True, so we don't check it here.
                    attn_opt_modules[lora_path + "_lora_magnitude_vector"] = module.lora_magnitude_vector

                    # We will manage attn adapters directly. By default, LoraLayer is an instance of BaseTunerLayer,
                    # so according to the code logic in diffusers/loaders/peft.py,
                    # they will be managed by the diffusers PeftAdapterMixin instance, through the
                    # enable_adapters(), and set_adapter() methods.
                    # Therefore, we disable these calls on module.
                    # disable_adapters() is a property and changing it will cause exceptions.
                    module.enable_adapters = dummy_func
                    module.set_adapter     = dummy_func

    unet.set_attn_processor(attn_procs)
    print(f"Set up {len(attn_capture_procs)} CrossAttn processors on {attn_capture_procs.keys()}.")
    print(f"Set up {len(attn_opt_modules)} attn LoRA params: {attn_opt_modules.keys()}.")
    return attn_capture_procs, attn_opt_modules

# NOTE: cross-attn layers are included in the returned lora_modules.
def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
    # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
    # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
    # Cannot set to conv.+ as it will match added adapter module names, including
    # up_blocks.3.resnets.1.conv1.base_layer, up_blocks.3.resnets.1.conv1.lora_dropout
    if target_modules_pat is not None:
        peft_config = LoraConfig(use_dora=lora_uses_dora, inference_mode=False, r=lora_rank,
                                 lora_alpha=lora_alpha, lora_dropout=0.1,
                                 target_modules=target_modules_pat)
        # UNet is a diffusers PeftAdapterMixin instance. Using get_peft_model on it will
        # cause weird errors. Instead, we directly use diffusers peft adapter methods.
        unet.add_adapter(peft_config, "recon_loss")
        unet.add_adapter(peft_config, "unet_distill")
        unet.add_adapter(peft_config, "comp_distill")
        unet.enable_adapters()

    # lora_layers contain both the LoRA A and B matrices, as well as the original layers.
    # lora_layers are used to set the flag, not used for optimization.
    # lora_modules contain only the LoRA A and B matrices, so they are used for optimization.
    # NOTE: lora_modules contain both ffn and cross-attn lora modules.
    ffn_lora_layers = {}
    ffn_opt_modules = {}
    for name, module in unet.named_modules():
        if isinstance(module, peft_lora.LoraLayer):
            # We don't want to include cross-attn layers in ffn_lora_layers.
            if target_modules_pat is not None and re.search(target_modules_pat, name):
                ffn_lora_layers[name] = module

            # ModuleDict doesn't allow "." in the key.
            name = name.replace(".", "_")
            # Since ModuleDict doesn't allow "." in the key, we manually collect
            # the LoRA matrices in each module.
            # NOTE: We cannot put every sub-module of module into lora_modules,
            # as base_layer is also a sub-module of module, which we shouldn't optimize.
            # Each value in ffn_opt_modules is a ModuleDict:
            '''
            (Pdb) ffn_opt_modules['up_blocks_3_resnets_1_conv1_lora_A']
            ModuleDict(
                (unet_distill): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (recon_loss):   Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
            '''
            ffn_opt_modules[name + "_lora_A"] = module.lora_A
            ffn_opt_modules[name + "_lora_B"] = module.lora_B
            if lora_uses_dora:
                ffn_opt_modules[name + "_lora_magnitude_vector"] = module.lora_magnitude_vector

    print(f"Set up {len(ffn_lora_layers)} FFN LoRA layers: {ffn_lora_layers.keys()}.")
    print(f"Set up {len(ffn_opt_modules)} FFN LoRA params: {ffn_opt_modules.keys()}.")

    return ffn_lora_layers, ffn_opt_modules

def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
                               outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
                               use_attn_lora, use_ffn_lora, ffn_lora_adapter_name,
                               capture_ca_activations, shrink_cross_attn, res_hidden_states_gradscale):
    # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
    for attn_capture_proc in attn_capture_procs:
        attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn,
                                                     enable_lora=use_attn_lora)

    # outfeat_capture_blocks only contains the last up block, up_blocks[3].
    # It contains 3 FFN layers. We want to capture their output features.
    for block in outfeat_capture_blocks:
        block.capture_outfeats = capture_ca_activations

    for block in res_hidden_states_gradscale_blocks:
        block.res_hidden_states_gradscale = res_hidden_states_gradscale

    if not use_ffn_lora:
        unet.disable_adapters()
    else:
        # ffn_lora_adapter_name: 'recon_loss', 'unet_distill', 'comp_distill'.
        if ffn_lora_adapter_name is not None:
            unet.set_adapter(ffn_lora_adapter_name)
            # NOTE: Don't forget to enable_adapters().
            # The adapters are not enabled by default after set_adapter().
            unet.enable_adapters()
        else:
            breakpoint()

    # During training, disable_adapters() and set_adapter() set requires_grad=False on all / the inactive adapters,
    # which might cause issues during DDP training.
    # So we restore them to requires_grad=True.
    # During test, unet_lora_modules will be passed as None, so this block will be skipped.
    if unet_lora_modules is not None:
        for param in unet_lora_modules.parameters():
            param.requires_grad = True

def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat_capture_blocks,
                             captured_layer_indices=[22, 23, 24], out_dtype=torch.float32):
    captured_activations = { k: {} for k in ('outfeat', 'attn', 'attnscore', 'q', 'q2', 'k', 'v', 'attn_out') }
    if not capture_ca_activations:
        return captured_activations

    all_cached_outfeats = []
    for block in outfeat_capture_blocks:
        all_cached_outfeats.append(block.cached_outfeats)
        # Clear the capture flag and cached outfeats.
        block.cached_outfeats  = {}
        block.capture_outfeats = False

    for layer_idx in captured_layer_indices:
        # Subtract 22 from layer_idx to match the layer index in up_blocks[3].cached_outfeats.
        # 23, 24 -> 1, 2 (!! not 0, 1 !!)
        internal_idx = layer_idx - 22
        for k in captured_activations.keys():
            if k == 'outfeat':
                # Currently we only capture one block, up_blocks.3. So we hard-code the index 0.
                captured_activations['outfeat'][layer_idx] = all_cached_outfeats[0][internal_idx].to(out_dtype)
            else:
                # internal_idx is the index of layers in up_blocks.3.
                # Layers 22, 23 and 24 map to 0, 1 and 2.
                cached_activations = attn_capture_procs[internal_idx].cached_activations
                captured_activations[k][layer_idx] = cached_activations[k].to(out_dtype)

    return captured_activations
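
# A minimal sketch of what gen_gradient_scaler() returns: the forward pass is an identity,
# while gradients are multiplied by alpha on the way back. The helper name below is
# hypothetical and only serves as an illustration.
def _demo_gradient_scaler():
    x = torch.randn(4, requires_grad=True)
    scaler = gen_gradient_scaler(0.2)   # returns GradientScaler(alpha=0.2)
    y = scaler(x)                       # forward: y equals x
    y.sum().backward()
    # Each element of x.grad is 0.2, i.e. alpha * d(sum)/dx.
    print(x.grad)
    # For the corner cases: gen_gradient_scaler(1) returns nn.Identity(), and
    # gen_gradient_scaler(0) returns torch.detach, which blocks gradients entirely.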
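
# A minimal end-to-end usage sketch of the setup/capture functions above, under a few assumptions:
# it loads an SD-1.5 UNet ("runwayml/stable-diffusion-v1-5"), binds CrossAttnUpBlock2D_forward_capture
# to up_blocks[3] explicitly (the repo does this binding elsewhere), and uses dummy latents and text
# embeddings. The helper name is hypothetical; shapes, adapter names and flags follow the comments above.
def _demo_setup_unet_loras_and_capture():
    import types
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32)

    # 1. Install capture/LoRA attn processors on all layers (LoRA only on up_blocks.3 cross-attn).
    attn_capture_procs, attn_opt_modules = set_up_attn_processors(
        unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
        lora_rank=192, lora_scale_down=8)

    # 2. Add FFN LoRA adapters on the conv layers of the last two resnets of up_blocks.3.
    ffn_lora_layers, ffn_opt_modules = set_up_ffn_loras(
        unet, target_modules_pat='up_blocks.3.resnets.[12].conv[a-z0-9_]+')

    # 3. Route up_blocks[3] through the capturing forward (assumed to be patched here for the sketch).
    unet.up_blocks[3].forward = types.MethodType(CrossAttnUpBlock2D_forward_capture, unet.up_blocks[3])

    # All optimizable LoRA modules, keyed without "." so they fit in a ModuleDict.
    unet_lora_modules = nn.ModuleDict({**attn_opt_modules, **ffn_opt_modules})

    # 4. Turn on activation capture and select the 'recon_loss' FFN adapter for this pass.
    set_lora_and_capture_flags(
        unet, unet_lora_modules, list(attn_capture_procs.values()),
        outfeat_capture_blocks=[unet.up_blocks[3]],
        res_hidden_states_gradscale_blocks=[unet.up_blocks[3]],
        use_attn_lora=True, use_ffn_lora=True, ffn_lora_adapter_name='recon_loss',
        capture_ca_activations=True, shrink_cross_attn=False,
        res_hidden_states_gradscale=0.2)

    # 5. A dummy denoising step: 64x64 latents and 77 text embeddings of dim 768.
    sample    = torch.randn(1, 4, 64, 64)
    timesteps = torch.tensor([10])
    text_embs = torch.randn(1, 77, 768)
    with torch.no_grad():
        unet(sample, timesteps, encoder_hidden_states=text_embs)

    # 6. Collect activations of layers 22-24 (the 3 cross-attn layers of up_blocks.3).
    captured = get_captured_activations(
        capture_ca_activations=True,
        attn_capture_procs=list(attn_capture_procs.values()),
        outfeat_capture_blocks=[unet.up_blocks[3]])
    # e.g. captured['attn'][24]: [1, 8, 4096, 77]; captured['outfeat'][24]: [1, 320, 64, 64].
    return captured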