import os

import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage

from diffusers import FluxPipeline
from diffusers.models import Transformer2DModel
from diffusers.models.unets import UNet2DConditionModel
from diffusers.models.transformers import SD3Transformer2DModel, FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock
from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    JointAttnProcessor2_0,
    FluxAttnProcessor2_0
)

# Patched attention-processor __call__ implementations, instrumented forward
# methods, and the shared `attn_maps` store used by the hooks below.
from modules import *

def cross_attn_init():
    """Patch the diffusers attention processor classes so they can record attention maps."""
    AttnProcessor.__call__ = attn_call
    AttnProcessor2_0.__call__ = attn_call2_0
    LoRAAttnProcessor.__call__ = lora_attn_call
    LoRAAttnProcessor2_0.__call__ = lora_attn_call2_0
    JointAttnProcessor2_0.__call__ = joint_attn_call2_0
    FluxAttnProcessor2_0.__call__ = flux_attn_call2_0

def hook_function(name, detach=True):
    """Build a forward hook that moves the map recorded by a processor into `attn_maps`."""
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            timestep = module.processor.timestep

            attn_maps[timestep] = attn_maps.get(timestep, dict())
            attn_maps[timestep][name] = module.processor.attn_map.cpu() if detach \
                else module.processor.attn_map

            del module.processor.attn_map

    return forward_hook

def register_cross_attention_hook(model, hook_function, target_name):
    """Enable attention-map storage on matching attention modules and attach forward hooks."""
    supported_processors = (
        AttnProcessor,
        AttnProcessor2_0,
        LoRAAttnProcessor,
        LoRAAttnProcessor2_0,
        JointAttnProcessor2_0,
        FluxAttnProcessor2_0,
    )

    for name, module in model.named_modules():
        if not name.endswith(target_name):
            continue

        if isinstance(module.processor, supported_processors):
            module.processor.store_attn_map = True

        module.register_forward_hook(hook_function(name))

    return model

def replace_call_method_for_unet(model):
    """Recursively swap in the instrumented forward methods for the UNet model tree."""
    if model.__class__.__name__ == 'UNet2DConditionModel':
        model.forward = UNet2DConditionModelForward.__get__(model, UNet2DConditionModel)

    for name, layer in model.named_children():
        if layer.__class__.__name__ == 'Transformer2DModel':
            layer.forward = Transformer2DModelForward.__get__(layer, Transformer2DModel)
        elif layer.__class__.__name__ == 'BasicTransformerBlock':
            layer.forward = BasicTransformerBlockForward.__get__(layer, BasicTransformerBlock)

        replace_call_method_for_unet(layer)

    return model

def replace_call_method_for_sd3(model):
    """Recursively swap in the instrumented forward methods for the SD3 transformer tree."""
    if model.__class__.__name__ == 'SD3Transformer2DModel':
        model.forward = SD3Transformer2DModelForward.__get__(model, SD3Transformer2DModel)

    for name, layer in model.named_children():
        if layer.__class__.__name__ == 'JointTransformerBlock':
            layer.forward = JointTransformerBlockForward.__get__(layer, JointTransformerBlock)

        replace_call_method_for_sd3(layer)

    return model

def replace_call_method_for_flux(model):
    """Recursively swap in the instrumented forward methods for the Flux transformer tree."""
    if model.__class__.__name__ == 'FluxTransformer2DModel':
        model.forward = FluxTransformer2DModelForward.__get__(model, FluxTransformer2DModel)

    for name, layer in model.named_children():
        if layer.__class__.__name__ == 'FluxTransformerBlock':
            layer.forward = FluxTransformerBlockForward.__get__(layer, FluxTransformerBlock)

        replace_call_method_for_flux(layer)

    return model

def init_pipeline(pipeline):
    """Hook the cross-attention layers of the pipeline's denoiser (transformer or UNet)."""
    if 'transformer' in vars(pipeline).keys():
        if pipeline.transformer.__class__.__name__ == 'SD3Transformer2DModel':
            pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
            pipeline.transformer = replace_call_method_for_sd3(pipeline.transformer)

        elif pipeline.transformer.__class__.__name__ == 'FluxTransformer2DModel':
            FluxPipeline.__call__ = FluxPipeline_call
            pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
            pipeline.transformer = replace_call_method_for_flux(pipeline.transformer)

    else:
        if pipeline.unet.__class__.__name__ == 'UNet2DConditionModel':
            pipeline.unet = register_cross_attention_hook(pipeline.unet, hook_function, 'attn2')
            pipeline.unet = replace_call_method_for_unet(pipeline.unet)

    return pipeline

def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unconditional=True):
    """Average the recorded maps over timesteps and layers; return one PIL image per prompt token."""
    to_pil = ToPILImage()

    token_ids = tokenizer(prompts)['input_ids']
    total_tokens = []
    for token_id in token_ids:
        total_tokens.append(tokenizer.convert_ids_to_tokens(token_id))

    if not os.path.exists(base_dir):
        os.mkdir(base_dir)

    # Use the first recorded map only to derive the shape of the accumulator.
    total_attn_map = list(list(attn_maps.values())[0].values())[0].sum(1)
    if unconditional:
        total_attn_map = total_attn_map.chunk(2)[1]  # (batch, height, width, attn_dim)
    total_attn_map = total_attn_map.permute(0, 3, 1, 2)
    total_attn_map = torch.zeros_like(total_attn_map)
    total_attn_map_shape = total_attn_map.shape[-2:]
    total_attn_map_number = 0

    for timestep, layers in attn_maps.items():
        timestep_dir = os.path.join(base_dir, f'{timestep}')
        if not os.path.exists(timestep_dir):
            os.mkdir(timestep_dir)

        for layer, attn_map in layers.items():
            layer_dir = os.path.join(timestep_dir, f'{layer}')
            if not os.path.exists(layer_dir):
                os.mkdir(layer_dir)

            attn_map = attn_map.sum(1).squeeze(1)
            attn_map = attn_map.permute(0, 3, 1, 2)

            if unconditional:
                # Keep only the conditional half of the classifier-free-guidance batch.
                attn_map = attn_map.chunk(2)[1]

            resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
            total_attn_map += resized_attn_map
            total_attn_map_number += 1

    total_attn_map /= total_attn_map_number

    final_attn_map = []
    for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
        batch_dir = os.path.join(base_dir, f'batch-{batch}')
        if not os.path.exists(batch_dir):
            os.mkdir(batch_dir)

        startofword = True
        for i, (token, a) in enumerate(zip(tokens, attn_map[:len(tokens)])):
            # Re-join BPE sub-word tokens into readable '<word>' / '<sub-word-' labels.
            if '</w>' in token:
                token = token.replace('</w>', '')
                if startofword:
                    token = '<' + token + '>'
                else:
                    token = '-' + token + '>'
                    startofword = True
            elif token != '<|startoftext|>' and token != '<|endoftext|>':
                if startofword:
                    token = '<' + token + '-'
                    startofword = False
                else:
                    token = '-' + token + '-'

            final_attn_map.append((to_pil(a.to(torch.float32)), f'{i}-{token}'))

    return final_attn_map
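

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wire the patched
# attention processors into a Stable Diffusion pipeline, run one generation,
# and collect per-token attention maps. The checkpoint id, prompt, and
# generation arguments below are illustrative assumptions; `attn_maps` is the
# shared store populated by the forward hooks (imported via `from modules import *`).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from diffusers import StableDiffusionPipeline

    cross_attn_init()  # patch the processor classes so attention maps can be recorded

    pipe = StableDiffusionPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5',  # illustrative checkpoint
        torch_dtype=torch.float16,
    ).to('cuda')
    pipe = init_pipeline(pipe)  # register hooks and swap in the instrumented forward methods

    prompt = 'a photo of a cat wearing a red scarf'
    image = pipe(prompt, num_inference_steps=15).images[0]
    image.save('output.png')

    # One (PIL image, '<index>-<token>') pair per prompt token, averaged over
    # timesteps and layers; unconditional=True drops the CFG negative branch.
    token_maps = save_attention_maps(attn_maps, pipe.tokenizer, [prompt],
                                     base_dir='attn_maps', unconditional=True)
    for pil_image, tag in token_maps:
        pil_image.save(os.path.join('attn_maps', f'{tag}.png'))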