"""
Here we reproduce DAAM, but for Flux DiT models.

This is effectively a visualization of the cross attention layers of a Flux
model: the per-token cross attention maps recorded by each double block during
generation are averaged over the selected layers and timesteps.
"""
from torch import nn
import torch
import einops

from concept_attention.image_generator import FluxGenerator
from concept_attention.segmentation import SegmentationAbstractClass


class DAAM(nn.Module):
    """DAAM-style cross attention heatmap visualizer for Flux models."""

    def __init__(
        self,
        model_name: str = "flux-schnell",
        device: str = "cuda",
        offload: bool = True,
    ):
        """
        Initialize the DAAM model.

        Args:
            - model_name: str, model name forwarded to FluxGenerator
            - device: str, device forwarded to FluxGenerator
            - offload: bool, offload flag forwarded to FluxGenerator
        """
        super().__init__()
        # Load up the flux generator
        self.generator = FluxGenerator(
            model_name=model_name,
            device=device,
            offload=offload,
        )
        # Unpack the T5 tokenizer; its tokens are used as the "concepts"
        # whose attention maps are recorded during generation.
        self.tokenizer = self.generator.t5.tokenizer

    def _clear_attention_caches(self):
        """Clear the cached vectors/attention maps held by every double block."""
        for double_block in self.generator.model.double_blocks:
            double_block.clear_cached_vectors()

    def __call__(
        self,
        prompt,
        seed=4,
        num_steps=4,
        timesteps=None,
        layers=None,
        width=1024,
        height=1024,
    ):
        """
        Generate cross attention heatmap visualizations.

        Args:
            - prompt: str, the prompt to generate the visualizations for
            - seed: int, the seed to use for the visualization
            - num_steps: int, number of sampling steps passed to the generator
            - timesteps: list[int] | None, indices of the sampling steps whose
              attention maps are averaged; defaults to all `num_steps` steps
            - layers: list[int] | None, indices of the double blocks whose
              attention maps are averaged; defaults to every double block
            - width: int, width of the generated image (default 1024)
            - height: int, height of the generated image (default 1024)

        Returns:
            - attention_maps: torch.Tensor of shape (num_tokens, h, w), the
              float32 attention maps for the prompt tokens, averaged over the
              selected layers and timesteps. h/w are the resolution at which
              the blocks record their maps (presumably the latent patch grid,
              not the pixel size -- TODO confirm)
            - tokens: list[str], the tokens in the prompt
            - image: whatever FluxGenerator.generate_image returns for the prompt
        """
        double_blocks = self.generator.model.double_blocks
        if timesteps is None:
            timesteps = list(range(num_steps))
        if layers is None:
            # Use every double block (previously hard-coded as 19).
            layers = list(range(len(double_blocks)))
        # Run the tokenizer and get list of the tokens
        token_strings = self.tokenizer.tokenize(prompt)
        # Drop anything left over from an earlier run (e.g. one that raised)
        # so stale maps cannot be stacked together with this run's maps.
        self._clear_attention_caches()
        try:
            # Run the image generator
            image = self.generator.generate_image(
                width=width,
                height=height,
                num_steps=num_steps,
                guidance=0.0,
                seed=seed,
                prompt=prompt,
                concepts=token_strings,
            )
            # One (time, concepts, h, w) tensor per double block; squeeze(1)
            # drops what is presumably a batch dimension of size 1.
            per_layer_maps = [
                torch.stack(double_block.cross_attention_maps).squeeze(1)
                for double_block in double_blocks
            ]
        finally:
            # Always reset the blocks, even on failure, so the next call
            # starts from an empty cache.
            self._clear_attention_caches()
        # Stack layers -> (layers, time, concepts, h, w)
        cross_attention_maps = torch.stack(per_layer_maps).to(torch.float32)
        # Pull out the desired timesteps
        cross_attention_maps = cross_attention_maps[:, timesteps]
        # Pull out the desired layers
        cross_attention_maps = cross_attention_maps[layers]
        # Average over layers and time
        attention_maps = einops.reduce(
            cross_attention_maps,
            "layers time concepts height width -> concepts height width",
            reduction="mean",
        )
        # Pull out only token length attention maps
        attention_maps = attention_maps[:len(token_strings)]
        return attention_maps, token_strings, image