import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict, Any
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import logging, is_torch_version, deprecate
from diffusers.utils.torch_utils import fourier_filter
# UNet is a diffusers PeftAdapterMixin instance.
from diffusers.loaders.peft import PeftAdapterMixin
from peft import LoraConfig, get_peft_model
import peft.tuners.lora as peft_lora
from peft.tuners.lora.dora import DoraLinearLayer
from einops import rearrange
import math, re
import numpy as np
from peft.tuners.tuners_utils import BaseTunerLayer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def dummy_func(*args, **kwargs):
    pass

# Revised from RevGrad, by removing the grad negation.
class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        # alpha_ and debug are expected to be tensors, since save_for_backward()
        # only accepts tensors (or None).
        ctx.save_for_backward(alpha_, debug)
        output = input_
        if debug:
            print(f"input: {input_.abs().mean().item()}")
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # saved_tensors returns a tuple of tensors.
        alpha_, debug = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_output2 = grad_output * alpha_
            if debug:
                print(f"grad_output2: {grad_output2.abs().mean().item()}")
        else:
            grad_output2 = None
        # One gradient per forward() input: input_, alpha_, debug.
        return grad_output2, None, None
class GradientScaler(nn.Module):
    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        """
        A gradient scaling layer.
        This layer has no parameters, and simply scales the gradient in the backward pass.
        """
        super().__init__(*args, **kwargs)
        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        _debug = self._debug if hasattr(self, '_debug') else False
        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)

def gen_gradient_scaler(alpha, debug=False):
    if alpha == 1:
        return nn.Identity()
    if alpha > 0:
        return GradientScaler(alpha, debug=debug)
    else:
        assert alpha == 0
        # Don't use lambda function here, otherwise the object can't be pickled.
        return torch.detach
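
# Illustrative sketch (not part of the original module): a minimal check of what
# gen_gradient_scaler() does. The helper name _demo_gradient_scaler is hypothetical.
# With alpha=0.1 the forward output is unchanged, but the gradient reaching x is
# scaled by 0.1; with alpha=0, torch.detach blocks the gradient entirely.
def _demo_gradient_scaler():
    x = torch.ones(3, requires_grad=True)
    scaler = gen_gradient_scaler(0.1)
    y = scaler(x).sum()
    y.backward()
    # x.grad is ~0.1 everywhere; without the scaler it would be 1.0.
    return x.grad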
def split_indices_by_instance(indices, as_dict=False):
    indices_B, indices_N = indices
    unique_indices_B = torch.unique(indices_B)
    if not as_dict:
        indices_by_instance = [ (indices_B[indices_B == uib], indices_N[indices_B == uib]) for uib in unique_indices_B ]
    else:
        indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
    return indices_by_instance
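
# Illustrative sketch (not part of the original module): split_indices_by_instance()
# groups flat (batch_index, token_index) pairs by batch instance. The helper name
# _demo_split_indices_by_instance is hypothetical.
def _demo_split_indices_by_instance():
    # Instance 0 owns tokens 4 and 5; instance 1 owns token 7.
    indices = (torch.tensor([0, 0, 1]), torch.tensor([4, 5, 7]))
    per_instance = split_indices_by_instance(indices)
    # [(tensor([0, 0]), tensor([4, 5])), (tensor([1]), tensor([7]))]
    as_dict = split_indices_by_instance(indices, as_dict=True)
    # {0: tensor([4, 5]), 1: tensor([7])}
    return per_instance, as_dict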
# If do_sum, returned emb_attns is 3D. Otherwise 4D.
# indices are applied on the first 2 dims of attn_mat.
def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
    indices_by_instance = split_indices_by_instance(indices)

    # emb_attns[0]: [1, 9, 8, 64]
    # 8: 8 attention heads. Last dim 64: number of image tokens.
    emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
    if all_token_weights is not None:
        # all_token_weights: [4, 77].
        # token_weights_by_instance[0]: [1, 9, 1, 1].
        token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
    else:
        token_weights = [ 1 ] * len(indices_by_instance)

    # Apply token weights.
    emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]

    # sum among K_subj_i subj embeddings -> [1, 8, 64]
    if do_sum:
        emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
    elif do_mean:
        emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]

    emb_attns = torch.cat(emb_attns, dim=0)
    return emb_attns
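
# Illustrative sketch (not part of the original module): selecting and summing the
# attention received by each instance's subject tokens. Shapes follow the comments
# above; the helper name _demo_sel_emb_attns_by_indices is hypothetical.
def _demo_sel_emb_attns_by_indices():
    # attn_mat: [2 instances, 77 text tokens, 8 heads, 64 image tokens].
    attn_mat = torch.randn(2, 77, 8, 64)
    # Subject tokens: token 4 of instance 0; tokens 4 and 5 of instance 1.
    indices = (torch.tensor([0, 1, 1]), torch.tensor([4, 4, 5]))
    emb_attns = sel_emb_attns_by_indices(attn_mat, indices, do_sum=True)
    # emb_attns: [2, 8, 64], one summed attention map per instance.
    return emb_attns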
# Slow implementation equivalent to F.scaled_dot_product_attention.
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                 shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
                                 is_causal=False, scale=None, enable_gqa=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, L, S = query.size(0), query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # 1: head (to be broadcasted). L: query length. S: key length.
    attn_bias = torch.zeros(B, 1, L, S, device=query.device, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(B, 1, L, S, device=query.device, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    if enable_gqa:
        key   = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if shrink_cross_attn:
        cross_attn_scale = cross_attn_shrink_factor
    else:
        cross_attn_scale = 1

    # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
    attn_weight += attn_bias
    attn_score = attn_weight
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
    # But this is intended, as we want to scale down the impact of the subject embeddings
    # in the computed attention output tensors.
    attn_weight = attn_weight * cross_attn_scale
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    output = attn_weight @ value
    return output, attn_score, attn_weight
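
# Illustrative sketch (not part of the original module): with shrink_cross_attn=False
# and dropout_p=0.0, the slow implementation above should match
# F.scaled_dot_product_attention up to numerical precision. The helper name
# _check_slow_sdpa_matches_fast is hypothetical.
def _check_slow_sdpa_matches_fast():
    q = torch.randn(2, 8, 16, 40)
    k = torch.randn(2, 8, 77, 40)
    v = torch.randn(2, 8, 77, 40)
    out_slow, _, _ = scaled_dot_product_attention(q, k, v)
    out_fast = F.scaled_dot_product_attention(q, k, v)
    return torch.allclose(out_slow, out_fast, atol=1e-5)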
# NOTE: Each attention layer gets its own processor instance (see set_up_attn_processors() below).
class AttnProcessor_LoRA_Capture(nn.Module):
    r"""
    Revised from AttnProcessor2_0
    """
    # lora_proj_layers is a dict of lora_layer_name -> lora_proj_layer.
    def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
                 lora_uses_dora=True, lora_proj_layers=None,
                 lora_rank: int = 192, lora_alpha: float = 16,
                 cross_attn_shrink_factor: float = 0.5,
                 q_lora_updates_query=False, attn_proc_idx=-1):
        super().__init__()

        self.global_enable_lora = enable_lora
        self.attn_proc_idx = attn_proc_idx
        # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
        # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
        self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_scale = self.lora_alpha / self.lora_rank
        self.cross_attn_shrink_factor = cross_attn_shrink_factor
        self.q_lora_updates_query = q_lora_updates_query

        self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
        if self.global_enable_lora:
            for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
                if lora_layer_name == 'q':
                    self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'k':
                    self.to_k_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'v':
                    self.to_v_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'out':
                    self.to_out_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)

    # LoRA layers can be enabled/disabled dynamically.
    def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
        self.capture_ca_activations = capture_ca_activations
        self.shrink_cross_attn = shrink_cross_attn
        self.cached_activations = {}
        # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
        self.enable_lora = enable_lora and self.global_enable_lora
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        img_mask: Optional[torch.Tensor] = None,
        subj_indices: Optional[Tuple[torch.IntTensor, torch.IntTensor]] = None,
        debug: bool = False,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        # hidden_states: [1, 4096, 320]
        residual = hidden_states

        # attn.spatial_norm is None.
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            # Collapse the spatial dimensions to a single token dimension.
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        # NOTE: there's an inconsistency between the q LoRA and the k, v LoRAs.
        # The k, v LoRAs are directly applied to key and value (currently the k, v LoRAs are never enabled),
        # while the q LoRA is applied to query2, and we keep the original query unchanged.
        if self.enable_lora and self.to_q_lora is not None:
            # query2 will be used in ldm/util.py:calc_elastic_matching_loss() to get more accurate
            # cross attention scores between the latent images of the sc and mc instances.
            query2 = self.to_q_lora(hidden_states)
            # If not q_lora_updates_query, only query2 will be impacted by the LoRA layer.
            # The query, and thus the attention score and attn_out, will be the same
            # as the original ones.
            if self.q_lora_updates_query:
                query = query2
        else:
            query2 = query

        scale = 1 / math.sqrt(query.size(-1))
        is_cross_attn = (encoder_hidden_states is not None)
        if (not is_cross_attn) and (img_mask is not None):
            # NOTE: we assume the latent feature map is square; if it isn't, the resizing below will be wrong.
            # hidden_states: [BS, 4096, 320]. img_mask: [BS, 1, 64, 64]
            # Scale the mask to the same spatial size as hidden_states.
            mask_size = int(math.sqrt(hidden_states.shape[-2]))
            img_mask = F.interpolate(img_mask, size=(mask_size, mask_size), mode='nearest')
            if (img_mask.sum(dim=(2, 3)) == 0).any():
                img_mask = None
            else:
                # img_mask: [2, 1, 64, 64] -> [2, 4096]
                img_mask = rearrange(img_mask, 'b ... -> b (...)').contiguous()
                # max_neg_value = -torch.finfo(hidden_states.dtype).max
                # img_mask: [2, 4096] -> [2, 1, 1, 4096]
                img_mask = rearrange(img_mask.bool(), 'b j -> b () () j')
                # attn_score: [16, 4096, 4096]. img_mask will be broadcasted to [16, 4096, 4096].
                # So some rows in dim 1 (e.g. [0, :, 4095]) of attn_score will be masked out (all elements in [0, :, 4095] are -inf).
                # But not all elements in [0, 4095, :] are -inf. Since the softmax is done along dim 2, this is fine.
                # attn_score.masked_fill_(~img_mask, max_neg_value)
                # NOTE: If there's an attention mask, it will be replaced by img_mask.
                attention_mask = img_mask
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.enable_lora and self.to_k_lora is not None:
            key = self.to_k_lora(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)

        if self.enable_lora and self.to_v_lora is not None:
            value = self.to_v_lora(encoder_hidden_states)
        else:
            value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
            query2 = attn.norm_q(query2)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        query2 = query2.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if debug and self.attn_proc_idx >= 0:
            breakpoint()

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
            hidden_states, attn_score, attn_prob = \
                scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
                                             dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
                                             cross_attn_shrink_factor=self.cross_attn_shrink_factor)
        else:
            # Use the faster implementation of scaled_dot_product_attention
            # when not capturing the activations or suppressing the subject attention.
            hidden_states = \
                F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
            attn_prob = attn_score = None

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        if self.enable_lora and self.to_out_lora is not None:
            hidden_states = self.to_out_lora(hidden_states)
        else:
            hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if is_cross_attn and self.capture_ca_activations:
            # cached q will be used in ddpm.py:calc_comp_fg_bg_preserve_loss(), in which two qs will multiply each other.
            # So sqrt(scale) will scale the product of two qs by scale.
            # ANCHOR[id=attention_caching]
            # query: [2, 8, 4096, 40] -> [2, 320, 4096]
            self.cached_activations['q'] = \
                rearrange(query, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['q2'] = \
                rearrange(query2, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['k'] = \
                rearrange(key, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['v'] = \
                rearrange(value, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            # attn_prob, attn_score: [2, 8, 4096, 77]
            self.cached_activations['attn'] = attn_prob
            self.cached_activations['attnscore'] = attn_score
            # attn_out: [b, n, h * d] -> [b, h * d, n]
            # [2, 4096, 320] -> [2, 320, 4096].
            self.cached_activations['attn_out'] = hidden_states.permute(0, 2, 1).contiguous()

        return hidden_states
def CrossAttnUpBlock2D_forward_capture(
    self,
    hidden_states: torch.Tensor,
    res_hidden_states_tuple: Tuple[torch.Tensor, ...],
    temb: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    upsample_size: Optional[int] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    self.cached_outfeats = {}
    res_hidden_states_gradscale = getattr(self, "res_hidden_states_gradscale", 1)
    capture_outfeats = getattr(self, "capture_outfeats", False)
    layer_idx = 0
    res_grad_scaler = gen_gradient_scaler(res_hidden_states_gradscale)

    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        # Scale down the magnitudes of gradients to res_hidden_states
        # by res_hidden_states_gradscale=0.2, to match the scale of the cross-attn layer outputs.
        res_hidden_states = res_grad_scaler(res_hidden_states)
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:
            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
        else:
            # resnet: ResnetBlock2D instance.
            #LINK diffusers.models.resnet.ResnetBlock2D
            # up_blocks.3.resnets.2.conv_shortcut is a module within ResnetBlock2D,
            # it's not transforming the UNet shortcut features.
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if capture_outfeats:
            self.cached_outfeats[layer_idx] = hidden_states
            layer_idx += 1

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
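
# Illustrative sketch (not part of the original module): CrossAttnUpBlock2D_forward_capture()
# is written with `self` as its first argument, so it is presumably meant to replace the
# forward() of the last CrossAttnUpBlock2D. A minimal way to bind it, assuming a standard
# SD 1.x UNet where up_blocks[3] is the last cross-attn up block; the helper name
# _example_patch_up_block is hypothetical.
def _example_patch_up_block(unet):
    up_block = unet.up_blocks[3]
    # Bind the plain function as a method of this block instance.
    up_block.forward = CrossAttnUpBlock2D_forward_capture.__get__(up_block)
    # Flags read via getattr() inside the patched forward.
    up_block.capture_outfeats = False
    up_block.res_hidden_states_gradscale = 0.2
    return up_block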
# Adapted from ConsistentIDPipeline:set_ip_adapter().
# attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
                           lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
                           q_lora_updates_query=False):
    attn_procs = {}
    attn_capture_procs = {}
    unet_modules = dict(unet.named_modules())
    attn_opt_modules = {}
    attn_proc_idx = 0

    for name, attn_proc in unet.attn_processors.items():
        # Only capture the activations of the last 3 CA layers.
        if not name.startswith("up_blocks.3"):
            # Not the last 3 CA layers. Don't enable LoRA or capture activations.
            # Then the layer falls back to the original attention mechanism.
            # We still use AttnProcessor_LoRA_Capture, as it can handle img_mask.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
            continue
        # cross_attention_dim: 768.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None:
            # Self attention. Don't enable LoRA or capture activations.
            # We replace the default attn_proc with AttnProcessor_LoRA_Capture,
            # so that it can incorporate img_mask into self-attention.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
            continue

        # block_id = 3
        # hidden_size: 320
        # hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor' ->
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q'
        lora_layer_dict = {}
        lora_layer_dict['q'] = unet_modules[name[:-9] + "to_q"]
        lora_layer_dict['k'] = unet_modules[name[:-9] + "to_k"]
        lora_layer_dict['v'] = unet_modules[name[:-9] + "to_v"]
        # to_out is a ModuleList(Linear, Dropout).
        lora_layer_dict['out'] = unet_modules[name[:-9] + "to_out"][0]

        lora_proj_layers = {}
        # Only apply LoRA to the specified layers.
        for lora_layer_name in attn_lora_layer_names:
            lora_proj_layers[lora_layer_name] = lora_layer_dict[lora_layer_name]

        attn_capture_proc = AttnProcessor_LoRA_Capture(
            capture_ca_activations=True, enable_lora=use_attn_lora,
            lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
            # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
            lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
            cross_attn_shrink_factor=cross_attn_shrink_factor,
            q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
        attn_proc_idx += 1
        # attn_procs has to use the original names.
        attn_procs[name] = attn_capture_proc
        # ModuleDict doesn't allow "." in the key.
        name = name.replace(".", "_")
        attn_capture_procs[name] = attn_capture_proc

        if use_attn_lora:
            for subname, module in attn_capture_proc.named_modules():
                if isinstance(module, peft_lora.LoraLayer):
                    # ModuleDict doesn't allow "." in the key.
                    lora_path = name + "_" + subname.replace(".", "_")
                    attn_opt_modules[lora_path + "_lora_A"] = module.lora_A
                    attn_opt_modules[lora_path + "_lora_B"] = module.lora_B
                    # lora_uses_dora is always True, so we don't check it here.
                    attn_opt_modules[lora_path + "_lora_magnitude_vector"] = module.lora_magnitude_vector
                    # We will manage the attn adapters directly. By default, LoraLayer is an instance of BaseTunerLayer,
                    # so according to the code logic in diffusers/loaders/peft.py,
                    # it would be managed by the diffusers PeftAdapterMixin instance, through its
                    # enable_adapters() and set_adapter() methods.
                    # Therefore, we disable these calls on the module.
                    # (disable_adapters() is a property; overriding it would raise an exception.)
                    module.enable_adapters = dummy_func
                    module.set_adapter = dummy_func

    unet.set_attn_processor(attn_procs)
    print(f"Set up {len(attn_capture_procs)} CrossAttn processors on {attn_capture_procs.keys()}.")
    print(f"Set up {len(attn_opt_modules)} attn LoRA params: {attn_opt_modules.keys()}.")
    return attn_capture_procs, attn_opt_modules
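
# Illustrative sketch (not part of the original module): typical wiring of the attention
# processors on a Stable Diffusion 1.x UNet. The checkpoint name and the helper name
# _example_set_up_attn_processors are assumptions for illustration only.
def _example_set_up_attn_processors():
    from diffusers import UNet2DConditionModel
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet")
    attn_capture_procs, attn_opt_modules = set_up_attn_processors(
        unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
        lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
        q_lora_updates_query=False)
    # attn_capture_procs: the 3 capture processors on up_blocks.3's cross-attn layers.
    # attn_opt_modules:   their LoRA A/B matrices and DoRA magnitude vectors.
    return unet, attn_capture_procs, attn_opt_modules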
# NOTE: cross-attn LoRA layers are also included in the returned ffn_opt_modules.
def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
    # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
    # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
    # Cannot set to conv.+ as it will match added adapter module names, including
    # up_blocks.3.resnets.1.conv1.base_layer, up_blocks.3.resnets.1.conv1.lora_dropout
    if target_modules_pat is not None:
        peft_config = LoraConfig(use_dora=lora_uses_dora, inference_mode=False, r=lora_rank,
                                 lora_alpha=lora_alpha, lora_dropout=0.1,
                                 target_modules=target_modules_pat)
        # UNet is a diffusers PeftAdapterMixin instance. Using get_peft_model on it will
        # cause weird errors. Instead, we directly use the diffusers peft adapter methods.
        unet.add_adapter(peft_config, "recon_loss")
        unet.add_adapter(peft_config, "unet_distill")
        unet.add_adapter(peft_config, "comp_distill")
        unet.enable_adapters()

    # ffn_lora_layers contain both the LoRA A and B matrices, as well as the original layers.
    # ffn_lora_layers are used to set the flags, not used for optimization.
    # ffn_opt_modules contain only the LoRA A and B matrices, so they are used for optimization.
    # NOTE: ffn_opt_modules contain both FFN and cross-attn LoRA modules.
    ffn_lora_layers = {}
    ffn_opt_modules = {}
    for name, module in unet.named_modules():
        if isinstance(module, peft_lora.LoraLayer):
            # We don't want to include cross-attn layers in ffn_lora_layers.
            if target_modules_pat is not None and re.search(target_modules_pat, name):
                ffn_lora_layers[name] = module
            # ModuleDict doesn't allow "." in the key.
            name = name.replace(".", "_")
            # Since ModuleDict doesn't allow "." in the key, we manually collect
            # the LoRA matrices in each module.
            # NOTE: We cannot put every sub-module of module into ffn_opt_modules,
            # as base_layer is also a sub-module of module, which we shouldn't optimize.
            # Each value in ffn_opt_modules is a ModuleDict:
            '''
            (Pdb) ffn_opt_modules['up_blocks_3_resnets_1_conv1_lora_A']
            ModuleDict(
              (unet_distill): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (recon_loss): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
            '''
            ffn_opt_modules[name + "_lora_A"] = module.lora_A
            ffn_opt_modules[name + "_lora_B"] = module.lora_B
            if lora_uses_dora:
                ffn_opt_modules[name + "_lora_magnitude_vector"] = module.lora_magnitude_vector

    print(f"Set up {len(ffn_lora_layers)} FFN LoRA layers: {ffn_lora_layers.keys()}.")
    print(f"Set up {len(ffn_opt_modules)} FFN LoRA params: {ffn_opt_modules.keys()}.")
    return ffn_lora_layers, ffn_opt_modules
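
# Illustrative sketch (not part of the original module): attaching the FFN LoRAs with the
# regex pattern quoted in the comment inside set_up_ffn_loras(). The helper name
# _example_set_up_ffn_loras is hypothetical; the adapter names are the three added above.
def _example_set_up_ffn_loras(unet):
    ffn_lora_layers, ffn_opt_modules = set_up_ffn_loras(
        unet, target_modules_pat='up_blocks.3.resnets.[12].conv[a-z0-9_]+',
        lora_uses_dora=False, lora_rank=192, lora_alpha=16)
    return ffn_lora_layers, ffn_opt_modules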
def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
                               outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
                               use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
                               shrink_cross_attn, res_hidden_states_gradscale):
    # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
    for attn_capture_proc in attn_capture_procs:
        attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
    # outfeat_capture_blocks only contains the last up block, up_blocks[3].
    # It contains 3 FFN layers. We want to capture their output features.
    for block in outfeat_capture_blocks:
        block.capture_outfeats = capture_ca_activations
    for block in res_hidden_states_gradscale_blocks:
        block.res_hidden_states_gradscale = res_hidden_states_gradscale

    if not use_ffn_lora:
        unet.disable_adapters()
    else:
        # ffn_lora_adapter_name: 'recon_loss', 'unet_distill', 'comp_distill'.
        if ffn_lora_adapter_name is not None:
            unet.set_adapter(ffn_lora_adapter_name)
            # NOTE: Don't forget to enable_adapters().
            # The adapters are not enabled by default after set_adapter().
            unet.enable_adapters()
        else:
            breakpoint()

    # During training, disable_adapters() and set_adapter() will set all/inactive adapters with requires_grad=False,
    # which might cause issues during DDP training.
    # So we restore them to requires_grad=True.
    # During test, unet_lora_modules will be passed as None, so this block will be skipped.
    if unet_lora_modules is not None:
        for param in unet_lora_modules.parameters():
            param.requires_grad = True
def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat_capture_blocks,
                             captured_layer_indices=[22, 23, 24], out_dtype=torch.float32):
    captured_activations = { k: {} for k in ('outfeat', 'attn', 'attnscore',
                                             'q', 'q2', 'k', 'v', 'attn_out') }
    if not capture_ca_activations:
        return captured_activations

    all_cached_outfeats = []
    for block in outfeat_capture_blocks:
        all_cached_outfeats.append(block.cached_outfeats)
        # Clear the capture flag and cached outfeats.
        block.cached_outfeats = {}
        block.capture_outfeats = False

    for layer_idx in captured_layer_indices:
        # Subtract 22 from layer_idx to match the layer index in up_blocks[3].cached_outfeats.
        # 23, 24 -> 1, 2 (!! not 0, 1 !!)
        internal_idx = layer_idx - 22
        for k in captured_activations.keys():
            if k == 'outfeat':
                # Currently we only capture one block, up_blocks.3. So we hard-code the index 0.
                captured_activations['outfeat'][layer_idx] = all_cached_outfeats[0][internal_idx].to(out_dtype)
            else:
                # internal_idx is the index of layers in up_blocks.3.
                # Layers 22, 23 and 24 map to 0, 1 and 2.
                cached_activations = attn_capture_procs[internal_idx].cached_activations
                captured_activations[k][layer_idx] = cached_activations[k].to(out_dtype)

    return captured_activations
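
# Illustrative sketch (not part of the original module): one capture cycle around a UNet
# forward pass. `noisy_latents`, `timesteps` and `text_embs` are hypothetical tensors from
# the caller's training loop. It assumes up_blocks[3].forward has been patched with
# CrossAttnUpBlock2D_forward_capture, and that attn_capture_procs is the *list* of the three
# capture processors (e.g. list(attn_capture_procs.values())), since it is indexed by
# position in get_captured_activations().
def _example_capture_cycle(unet, unet_lora_modules, attn_capture_procs,
                           noisy_latents, timesteps, text_embs):
    outfeat_capture_blocks = [unet.up_blocks[3]]
    set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
                               outfeat_capture_blocks, outfeat_capture_blocks,
                               use_attn_lora=True, use_ffn_lora=True,
                               ffn_lora_adapter_name='recon_loss',
                               capture_ca_activations=True, shrink_cross_attn=False,
                               res_hidden_states_gradscale=0.2)
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embs).sample
    captured = get_captured_activations(True, attn_capture_procs, outfeat_capture_blocks,
                                        captured_layer_indices=[22, 23, 24])
    return noise_pred, captured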