"""
Here we reproduce DAAM, but for Flux DiT models. This is effectively a visualization of the cross attention
layers of a Flux model.
"""
from torch import nn
import torch
import einops
from concept_attention.image_generator import FluxGenerator
from concept_attention.segmentation import SegmentationAbstractClass

class DAAM(nn.Module):
    def __init__(
        self,
        model_name: str = "flux-schnell",
        device: str = "cuda",
        offload: bool = True,
    ):
        """
        Initialize the DAAM model.
        """
        super(DAAM, self).__init__()
        # Load up the flux generator
        self.generator = FluxGenerator(
            model_name=model_name,
            device=device,
            offload=offload,
        )
        # Unpack the T5 tokenizer
        self.tokenizer = self.generator.t5.tokenizer

    def __call__(
        self,
        prompt,
        seed=4,
        num_steps=4,
        timesteps=None,
        layers=None
    ):
        """
        Generate cross-attention heatmap visualizations.

        Args:
        - prompt: str, the prompt to generate the visualizations for
        - seed: int, the seed to use for the visualization
        - num_steps: int, the number of denoising steps
        - timesteps: list[int] or None, the timesteps to average over (defaults to all steps)
        - layers: list[int] or None, the double-block layers to average over (defaults to all 19)

        Returns:
        - attention_maps: torch.Tensor, the attention maps for the prompt
        - tokens: list[str], the tokens in the prompt
        - image: torch.Tensor, the image generated by the model
        """
        if timesteps is None:
            timesteps = list(range(num_steps))
        if layers is None:
            layers = list(range(19))
        # Run the tokenizer and get the list of tokens
        token_strings = self.tokenizer.tokenize(prompt)
        # Run the image generator
        image = self.generator.generate_image(
            width=1024,
            height=1024,
            num_steps=num_steps,
            guidance=0.0,
            seed=seed,
            prompt=prompt,
            concepts=token_strings
        )
        # Pull out and average the attention maps
        cross_attention_maps = []
        for double_block in self.generator.model.double_blocks:
            # Stack the per-timestep maps and drop the singleton batch dimension
            cross_attention_map = torch.stack(
                double_block.cross_attention_maps
            ).squeeze(1)
            # Clear the cached vectors so they don't accumulate across calls
            double_block.clear_cached_vectors()
            # Append to the list
            cross_attention_maps.append(cross_attention_map)
        # Stack layers: (layers, time, concepts, height, width)
        cross_attention_maps = torch.stack(cross_attention_maps).to(torch.float32)
        # Pull out the desired timesteps
        cross_attention_maps = cross_attention_maps[:, timesteps]
        # Pull out the desired layers
        cross_attention_maps = cross_attention_maps[layers]
        # Average over layers and time
        attention_maps = einops.reduce(
            cross_attention_maps,
            "layers time concepts height width -> concepts height width",
            reduction="mean"
        )
        # Keep only the attention maps corresponding to the prompt tokens
        attention_maps = attention_maps[:len(token_strings)]
        return attention_maps, token_strings, image
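
# --- Usage sketch (not part of the original module) ---
# A minimal example of how this class might be driven: build a DAAM instance,
# run it on a prompt, and dump one heatmap image per prompt token. It assumes a
# CUDA device, that the flux-schnell weights are available to FluxGenerator, and
# that matplotlib is installed; the saving loop is purely illustrative.
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    daam = DAAM(model_name="flux-schnell", device="cuda", offload=True)
    attention_maps, tokens, image = daam(
        "a photo of a dog chasing a ball in a park",
        seed=4,
        num_steps=4,
    )
    # attention_maps: (num_tokens, height, width); one 2D heatmap per token
    for i, (token, heatmap) in enumerate(zip(tokens, attention_maps)):
        plt.imshow(heatmap.cpu().numpy(), cmap="inferno")
        plt.title(token)
        plt.axis("off")
        plt.savefig(f"daam_heatmap_{i}.png", bbox_inches="tight")
        plt.close()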