import numpy as np
import torch
import torch.nn as nn
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    import xformers


def remove_edlora_unet_attention_forward(unet):
    """Restore the default diffusers AttnProcessor on every cross-attention ('attn2') layer of the UNet."""

    def change_forward(unet):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention' and name == 'attn2':
                layer.set_processor(AttnProcessor())
            else:
                change_forward(layer)

    change_forward(unet)


class EDLoRA_Control_AttnProcessor:
    r"""
    EDLoRA attention processor that, for cross-attention layers, selects the text embedding
    belonging to this layer and routes the attention maps through an external controller
    (e.g. an attention store used for editing or visualization).
    """

    def __init__(self, cross_attention_idx, place_in_unet, controller, attention_op=None):
        self.cross_attention_idx = cross_attention_idx
        self.place_in_unet = place_in_unet
        self.controller = controller
        self.attention_op = attention_op

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            is_cross = False
            encoder_hidden_states = hidden_states
        else:
            is_cross = True
            # EDLoRA stacks one text embedding per cross-attention layer; pick the one for this layer.
            if encoder_hidden_states.ndim == 4:
                encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]

        assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if is_xformers_available() and not is_cross:
            # Self-attention: use memory-efficient attention; the controller only sees the
            # explicitly materialized attention maps computed in the branch below.
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear projection
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class EDLoRA_AttnProcessor:
    r"""
    EDLoRA attention processor: same as the controller variant above, but without the controller
    hook, and it uses xformers for both self- and cross-attention when available.
    """

    def __init__(self, cross_attention_idx, attention_op=None):
        self.attention_op = attention_op
        self.cross_attention_idx = cross_attention_idx

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif encoder_hidden_states.ndim == 4:
            # EDLoRA stacks one text embedding per cross-attention layer; pick the one for this layer.
            encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]

        assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if is_xformers_available():
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear projection
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def revise_edlora_unet_attention_forward(unet):
    """Register an EDLoRA_AttnProcessor on every cross-attention ('attn2') layer of the UNet."""

    def change_forward(unet, count):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention' and 'attn2' in name:
                layer.set_processor(EDLoRA_AttnProcessor(count))
                count += 1
            else:
                count = change_forward(layer, count)
        return count

    # Traverse in the same order the per-layer text embeddings are stacked: down -> mid -> up.
    cross_attention_idx = change_forward(unet.down_blocks, 0)
    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx)
    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx)
    print(f'Number of attention layers registered: {cross_attention_idx}')
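

# Usage sketch (an assumption, not part of this module): registering the EDLoRA processors on a
# standard diffusers UNet2DConditionModel. The processors expect `encoder_hidden_states` stacked
# per cross-attention layer, i.e. shaped (batch, num_cross_attn_layers, seq_len, dim), in the same
# down -> mid -> up order used above (e.g. seq_len=77, dim=768 for Stable Diffusion 1.5).
#
#     from diffusers import UNet2DConditionModel
#
#     unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')
#     revise_edlora_unet_attention_forward(unet)      # install EDLoRA_AttnProcessor on attn2 layers
#     # ... build layer-wise text embeddings of shape (B, num_layers, 77, 768) and pass them as
#     # `encoder_hidden_states` when calling the UNet ...
#     remove_edlora_unet_attention_forward(unet)      # restore the default processors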


def revise_edlora_unet_attention_controller_forward(unet, controller):
    """Register an EDLoRA_Control_AttnProcessor (with an attention controller) on every 'attn2' layer."""

    class DummyController:
        def __init__(self):
            self.num_att_layers = 0

        def __call__(self, *args):
            return args[0]

    if controller is None:
        controller = DummyController()

    def change_forward(unet, count, place_in_unet):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention' and 'attn2' in name:
                layer.set_processor(EDLoRA_Control_AttnProcessor(count, place_in_unet, controller))
                count += 1
            else:
                count = change_forward(layer, count, place_in_unet)
        return count

    # Traverse in the same order the per-layer text embeddings are stacked: down -> mid -> up.
    cross_attention_idx = change_forward(unet.down_blocks, 0, 'down')
    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, 'mid')
    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, 'up')
    print(f'Number of attention layers registered: {cross_attention_idx}')
    controller.num_att_layers = cross_attention_idx
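

# Controller interface sketch (an assumption, mirroring DummyController above): `controller` only
# needs to be callable as (attention_probs, is_cross, place_in_unet) -> attention_probs and to
# accept a `num_att_layers` attribute. A minimal attention store could look like:
#
#     class AttnStoreController:
#         def __init__(self):
#             self.num_att_layers = 0
#             self.maps = []
#
#         def __call__(self, attention_probs, is_cross, place_in_unet):
#             if is_cross:
#                 self.maps.append((place_in_unet, attention_probs.detach().cpu()))
#             return attention_probs
#
#     revise_edlora_unet_attention_controller_forward(unet, AttnStoreController())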


class LoRALinearLayer(nn.Module):
    """LoRA adapter that wraps an existing Linear (or 1x1 Conv2d) module and patches its forward."""

    def __init__(self, name, original_module, rank=4, alpha=1):
        super().__init__()

        self.name = name

        # NOTE: the rank argument is overridden here; every adapter is built with rank 32.
        rank = 32

        if original_module.__class__.__name__ == 'Conv2d':
            in_channels, out_channels = original_module.in_channels, original_module.out_channels
            self.lora_down = torch.nn.Conv2d(in_channels, rank, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(rank, out_channels, (1, 1), bias=False)
        else:
            in_features, out_features = original_module.in_features, original_module.out_features
            self.lora_down = nn.Linear(in_features, rank, bias=False)
            self.lora_up = nn.Linear(rank, out_features, bias=False)

        self.register_buffer('alpha', torch.tensor(alpha))

        # Initialize lora_down with a random subset of rows from a precomputed orthogonal matrix,
        # scaled by 1/2, and keep it frozen. NOTE: this assumes a Linear module (it relies on
        # in_features); the Conv2d branch above does not define in_features.
        m = np.load(f'orthogonal_mats/{in_features}.npy')
        idxs = np.random.choice(in_features, size=rank, replace=False)
        m = m[idxs] / 2
        with torch.no_grad():
            self.lora_down.weight = torch.nn.Parameter(torch.tensor(m, dtype=self.lora_down.weight.dtype))

        # lora_up starts at zero so the adapter initially leaves the wrapped module's output unchanged.
        torch.nn.init.zeros_(self.lora_up.weight)

        for param in self.lora_down.parameters():
            param.requires_grad = False

        # Patch the wrapped module so its forward now includes the LoRA branch.
        self.original_forward = original_module.forward
        original_module.forward = self.forward

    def forward(self, hidden_states):
        hidden_states = self.original_forward(hidden_states) + self.alpha * self.lora_up(self.lora_down(hidden_states))
        return hidden_states
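

# Usage sketch (an assumption, not part of this module): wrapping a Linear layer with a LoRA adapter.
# It requires a precomputed orthogonal matrix saved at orthogonal_mats/<in_features>.npy with shape
# (in_features, in_features), e.g. generated with scipy.stats.ortho_group.
#
#     layer = nn.Linear(768, 320)
#     lora = LoRALinearLayer('example.to_k', layer, rank=4, alpha=1)  # rank is overridden to 32 internally
#     out = layer(torch.randn(2, 77, 768))  # output now includes the (initially zero) LoRA contribution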