"""
Here we reproduce DAAM (Diffusion Attentive Attribution Maps), but for Flux DiT models. This is
effectively a visualization of the cross-attention layers of a Flux model.
"""
from torch import nn
import torch
import einops
from concept_attention.image_generator import FluxGenerator
from concept_attention.segmentation import SegmentationAbstractClass


class DAAM(nn.Module):

    def __init__(
        self,
        model_name: str = "flux-schnell",
        device: str = "cuda",
        offload: bool = True,
    ):
        """
        Initialize the DAAM model.
        """
        super(DAAM, self).__init__()
        # Load up the flux generator
        self.generator = FluxGenerator(
            model_name=model_name,
            device=device,
            offload=offload,
        )
        # Unpack the tokenizer
        self.tokenizer = self.generator.t5.tokenizer

    def __call__(
        self,
        prompt,
        seed=4,
        num_steps=4,
        timesteps=None,
        layers=None,
    ):
        """
        Generate cross-attention heatmap visualizations.
        Args:
        - prompt: str, the prompt to generate the visualizations for
        - seed: int, the random seed to use for image generation
        - num_steps: int, the number of denoising steps to run
        - timesteps: list[int] or None, the timesteps to average over (defaults to all steps)
        - layers: list[int] or None, the double-block layers to average over (defaults to all 19)
        Returns:
        - attention_maps: torch.Tensor, the attention maps for each token in the prompt
        - tokens: list[str], the tokens in the prompt
        - image: torch.Tensor, the image generated by the model
        """
        if timesteps is None:
            timesteps = list(range(num_steps))
        if layers is None:
            layers = list(range(19))
        # Run the tokenizer and get the list of token strings
        token_strings = self.tokenizer.tokenize(prompt)
        # Run the image generator
        image = self.generator.generate_image(
            width=1024,
            height=1024,
            num_steps=num_steps,
            guidance=0.0,
            seed=seed,
            prompt=prompt,
            concepts=token_strings
        )
        # Pull out and average the attention maps
        cross_attention_maps = []
        for double_block in self.generator.model.double_blocks:
            cross_attention_map = torch.stack(
                double_block.cross_attention_maps
            ).squeeze(1)
            # Clear the cached vectors for this layer
            double_block.clear_cached_vectors()
            # Append to the list
            cross_attention_maps.append(cross_attention_map)
        # Stack layers: shape (layers, time, concepts, height, width)
        cross_attention_maps = torch.stack(cross_attention_maps).to(torch.float32)
        # Pull out the desired timesteps
        cross_attention_maps = cross_attention_maps[:, timesteps]
        # Pull out the desired layers
        cross_attention_maps = cross_attention_maps[layers]
        # Average over layers and time
        attention_maps = einops.reduce(
            cross_attention_maps,
            "layers time concepts height width -> concepts height width",
            reduction="mean"
        )
        # Keep only the attention maps that correspond to the prompt tokens
        attention_maps = attention_maps[:len(token_strings)]
        return attention_maps, token_strings, image
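

# Example usage sketch: assumes a CUDA device, the concept_attention package, and that
# matplotlib is installed; the prompt, seed, and output paths below are illustrative
# values, not part of the class above.
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    daam = DAAM(model_name="flux-schnell", device="cuda", offload=True)
    attention_maps, tokens, image = daam(
        "a photo of a dog playing in a park",
        seed=4,
        num_steps=4,
    )
    # attention_maps has shape (num_tokens, height, width); save one heatmap per token.
    for index, (token, heatmap) in enumerate(zip(tokens, attention_maps)):
        plt.imshow(heatmap.cpu().numpy(), cmap="viridis")
        plt.title(token)
        plt.axis("off")
        plt.savefig(f"heatmap_{index}.png", bbox_inches="tight")
        plt.close()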