import math |
import numpy as np |
import torch |
import torch.nn as nn |
from diffusers.models.attention_processor import AttnProcessor |
from diffusers.utils.import_utils import is_xformers_available |
if is_xformers_available(): |
import xformers |
def remove_edlora_unet_attention_forward(unet):
    """Put the stock diffusers ``AttnProcessor`` back on every ``attn2`` layer.

    Recursively walks ``unet``. A child whose class is named ``Attention`` and whose
    attribute name is exactly ``'attn2'`` receives a fresh ``AttnProcessor()``; its own
    children are not visited. Every other child is descended into.

    Note: this always installs ``AttnProcessor``, whatever processor the layer used
    before it was patched.
    """

    def _restore(module):
        for child_name, child in module.named_children():
            is_target = type(child).__name__ == 'Attention' and child_name == 'attn2'
            if not is_target:
                _restore(child)
                continue
            child.set_processor(AttnProcessor())

    _restore(unet)
class EDLoRA_Control_AttnProcessor:
    r"""
    EDLoRA attention processor that also passes attention probabilities to a controller.

    Computation follows the usual diffusers attention flow, with two additions:

    * when ``encoder_hidden_states`` is 4-D, only the slice at ``cross_attention_idx``
      along dim 1 is used as this layer's context, so each registered layer can receive
      its own embedding;
    * on the explicit (non-xformers) path the attention probabilities are passed through
      ``controller(attention_probs, is_cross, place_in_unet)`` before being applied to
      the values.

    Self-attention (``encoder_hidden_states is None``) uses xformers when it is
    available, and in that case the controller is NOT called. Cross-attention always
    takes the explicit path.

    Args:
        cross_attention_idx: index into dim 1 of a 4-D ``encoder_hidden_states``.
        place_in_unet: tag forwarded to ``controller`` (the registration helper in this
            module passes 'down', 'mid' or 'up').
        controller: callable returning the (possibly modified) attention probabilities.
        attention_op: optional operator forwarded as ``op=`` to
            ``xformers.ops.memory_efficient_attention``.
    """

    def __init__(self, cross_attention_idx, place_in_unet, controller, attention_op=None):
        self.cross_attention_idx = cross_attention_idx
        self.place_in_unet = place_in_unet
        self.controller = controller
        self.attention_op = attention_op

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention for module ``attn`` and return the output hidden states.

        A 4-D ``hidden_states`` is flattened to (batch, height*width, channel) for the
        attention computation and reshaped back to (batch, channel, height, width)
        before returning.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            is_cross = False
            encoder_hidden_states = hidden_states
        else:
            is_cross = True
            if len(encoder_hidden_states.shape) == 4:
                # Per-layer context: select this layer's embedding along dim 1.
                encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]
            # norm_cross is not supported by this processor.
            assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states)).contiguous()
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states)).contiguous()
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states)).contiguous()

        if is_xformers_available() and not is_cross:
            if attention_mask is not None:
                # xformers does not broadcast a singleton query dimension in attn_bias,
                # so expand [batch*heads, 1, keys] -> [batch*heads, queries, keys]
                # (same handling as diffusers' XFormersAttnProcessor).
                _, query_tokens, _ = hidden_states.shape
                attention_mask = attention_mask.expand(-1, query_tokens, -1)
            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op
            )
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            # Give the controller a chance to inspect / edit the attention map.
            attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection followed by the second to_out module.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class EDLoRA_AttnProcessor:
    r"""
    Attention processor used by EDLoRA on cross-attention (``attn2``) layers.

    Computation follows the usual diffusers attention flow. The one addition: when
    ``encoder_hidden_states`` is 4-D, only the slice at ``cross_attention_idx`` along
    dim 1 is used as this layer's context, so every registered layer can receive its
    own embedding.

    Args:
        cross_attention_idx: index into dim 1 of a 4-D ``encoder_hidden_states``.
        attention_op: optional operator forwarded as ``op=`` to
            ``xformers.ops.memory_efficient_attention``.
    """

    def __init__(self, cross_attention_idx, attention_op=None):
        self.attention_op = attention_op
        self.cross_attention_idx = cross_attention_idx

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention for module ``attn`` and return the output hidden states.

        If xformers is available, attention is computed with
        ``xformers.ops.memory_efficient_attention``. Otherwise it is computed explicitly
        from ``attn.get_attention_scores`` followed by ``torch.bmm`` with the values.
        A 4-D ``hidden_states`` is flattened to (batch, height*width, channel) and
        reshaped back to (batch, channel, height, width) before returning.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            if len(encoder_hidden_states.shape) == 4:
                # Per-layer context: select this layer's embedding along dim 1.
                encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]
            # norm_cross is not supported by this processor.
            assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states)).contiguous()
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states)).contiguous()
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states)).contiguous()

        if is_xformers_available():
            if attention_mask is not None:
                # xformers does not broadcast a singleton query dimension in attn_bias,
                # so expand [batch*heads, 1, keys] -> [batch*heads, queries, keys]
                # (same handling as diffusers' XFormersAttnProcessor).
                _, query_tokens, _ = hidden_states.shape
                attention_mask = attention_mask.expand(-1, query_tokens, -1)
            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op
            )
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection followed by the second to_out module.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
def revise_edlora_unet_attention_forward(unet):
    """Install an ``EDLoRA_AttnProcessor`` on every cross-attention (``attn2``) layer.

    Any child module whose class is named ``Attention`` and whose attribute name
    contains ``'attn2'`` is given a processor. Layers are numbered in traversal order
    ``down_blocks`` -> ``mid_block`` -> ``up_blocks``, and that number becomes the
    processor's ``cross_attention_idx``. A section that is ``None`` (e.g. a UNet
    without a mid block) is skipped instead of raising ``AttributeError``.

    Returns:
        int: the number of attention layers registered (also printed).
    """

    def change_forward(module, count):
        for name, layer in module.named_children():
            if layer.__class__.__name__ == 'Attention' and 'attn2' in name:
                layer.set_processor(EDLoRA_AttnProcessor(count))
                count += 1
            else:
                count = change_forward(layer, count)
        return count

    cross_attention_idx = 0
    for blocks in (unet.down_blocks, unet.mid_block, unet.up_blocks):
        if blocks is not None:
            cross_attention_idx = change_forward(blocks, cross_attention_idx)
    print(f'Number of attention layer registered {cross_attention_idx}')
    return cross_attention_idx
def revise_edlora_unet_attention_controller_forward(unet, controller):
    """Install ``EDLoRA_Control_AttnProcessor`` on every ``attn2`` layer of ``unet``.

    Any child module whose class is named ``Attention`` and whose attribute name
    contains ``'attn2'`` is given a processor. Each processor receives:

    * a running index, assigned in ``down_blocks`` -> ``mid_block`` -> ``up_blocks``
      order;
    * the tag ``'down'``, ``'mid'`` or ``'up'`` for its section;
    * ``controller``.

    When ``controller`` is None, a pass-through stand-in is used. A section that is
    ``None`` (e.g. a UNet without a mid block) is skipped. After registration,
    ``controller.num_att_layers`` is set to the number of layers found.

    Returns:
        int: the number of attention layers registered (also printed).
    """

    class DummyController:
        # Pass-through controller: returns the attention probabilities unchanged.
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def change_forward(module, count, place_in_unet):
        for name, layer in module.named_children():
            if layer.__class__.__name__ == 'Attention' and 'attn2' in name:
                layer.set_processor(EDLoRA_Control_AttnProcessor(count, place_in_unet, controller))
                count += 1
            else:
                count = change_forward(layer, count, place_in_unet)
        return count

    sections = (
        (unet.down_blocks, 'down'),
        (unet.mid_block, 'mid'),
        (unet.up_blocks, 'up'),
    )
    cross_attention_idx = 0
    for blocks, place_in_unet in sections:
        if blocks is not None:
            cross_attention_idx = change_forward(blocks, cross_attention_idx, place_in_unet)
    print(f'Number of attention layer registered {cross_attention_idx}')
    controller.num_att_layers = cross_attention_idx
    return cross_attention_idx
class LoRALinearLayer(nn.Module):
    """Low-rank adapter attached to an existing ``nn.Linear`` or ``nn.Conv2d``.

    The down-projection is frozen. It is initialised from ``rank`` distinct randomly
    chosen rows of a matrix loaded from ``{orthogonal_dir}/{in_dim}.npy``, divided
    by 2. The file is presumably an orthogonal matrix of shape (in_dim, in_dim),
    judging by the directory name; TODO confirm. The up-projection is zero-initialised
    and left trainable, so the adapter starts as a no-op.

    On construction, ``original_module.forward`` is replaced by ``self.forward``, which
    returns ``original(x) + alpha * lora_up(lora_down(x))``.

    Args:
        name: identifier stored on the layer.
        original_module: the ``nn.Linear`` / ``nn.Conv2d`` to adapt. For Conv2d, the
            adapter uses 1x1 convolutions.
        rank: NOTE(review): currently ignored; the rank is hard-coded to 32 below.
        alpha: scale applied to the adapter branch (stored as a buffer).
        orthogonal_dir: directory containing the ``{in_dim}.npy`` matrices.
    """

    def __init__(self, name, original_module, rank=4, alpha=1, orthogonal_dir='orthogonal_mats'):
        super().__init__()

        self.name = name
        # NOTE(review): this overrides the ``rank`` argument, so every adapter is rank 32
        # whatever the caller passes. Kept to preserve existing parameter shapes;
        # confirm whether it is intended.
        rank = 32

        if original_module.__class__.__name__ == 'Conv2d':
            in_dim, out_dim = original_module.in_channels, original_module.out_channels
            self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
        else:
            in_dim, out_dim = original_module.in_features, original_module.out_features
            self.lora_down = nn.Linear(in_dim, rank, bias=False)
            self.lora_up = nn.Linear(rank, out_dim, bias=False)

        self.register_buffer('alpha', torch.tensor(alpha))

        # Frozen down-projection: `rank` distinct random rows of the stored matrix, halved.
        mat = np.load(f"{orthogonal_dir}/{in_dim}.npy")
        idxs = np.random.choice(in_dim, size=rank, replace=False)
        down_weight = torch.tensor(mat[idxs] / 2, dtype=self.lora_down.weight.dtype)
        if isinstance(self.lora_down, nn.Conv2d):
            # Conv2d weights are (out, in, kH, kW); add the two 1x1 kernel dims.
            down_weight = down_weight[:, :, None, None]
        self.lora_down.weight = torch.nn.Parameter(down_weight, requires_grad=False)
        torch.nn.init.zeros_(self.lora_up.weight)

        # Route calls to the wrapped module through this adapter.
        self.original_forward = original_module.forward
        original_module.forward = self.forward

    def forward(self, hidden_states):
        """Return ``original(x) + alpha * lora_up(lora_down(x))``."""
        return self.original_forward(hidden_states) + self.alpha * self.lora_up(self.lora_down(hidden_states))