"""
    Here we reproduce DAAM, but for Flux DiT models. This is effectively a visualization of the cross attention
    layers of a Flux model. 
"""
from torch import nn
import torch
import einops

from concept_attention.image_generator import FluxGenerator
from concept_attention.segmentation import SegmentationAbstractClass

class DAAM(nn.Module):

    def __init__(
        self,
        model_name: str = "flux-schnell",
        device: str = "cuda",
        offload: bool = True,
    ):
        """
            Initialize the DAAM model.
        """
        super().__init__()
        # Load up the flux generator
        self.generator = FluxGenerator(
            model_name=model_name,
            device=device,
            offload=offload,
        )
        # Keep a handle to the T5 tokenizer so prompts can be split into tokens
        self.tokenizer = self.generator.t5.tokenizer

    def __call__(
        self,
        prompt,
        seed=4,
        num_steps=4,
        timesteps=None,
        layers=None
    ):
        """
            Generate cross attention heatmap visualizations. 

            Args:
            - prompt: str, the prompt to generate the visualizations for
            - seed: int, the seed to use for the visualization

            Returns:
            - attention_maps: torch.Tensor, the attention maps for the prompt
            - tokens: list[str], the tokens in the prompt
            - image: torch.Tensor, the image generated by the
        """
        if timesteps is None:
            # Default to averaging over every denoising timestep
            timesteps = list(range(num_steps))
        if layers is None:
            # Default to averaging over all 19 double blocks
            layers = list(range(19))
        # Run the tokenizer to get the list of token strings in the prompt
        token_strings = self.tokenizer.tokenize(prompt)
        # Run the image generator
        image = self.generator.generate_image(
            width=1024,
            height=1024,
            num_steps=num_steps,
            guidance=0.0,
            seed=seed,
            prompt=prompt,
            concepts=token_strings
        )
        # Pull out the cached cross-attention maps from each double block
        cross_attention_maps = []
        for double_block in self.generator.model.double_blocks:
            # Stack the maps cached at each timestep: (time, concepts, height, width)
            cross_attention_map = torch.stack(
                double_block.cross_attention_maps
            ).squeeze(1)
            # Clear the cached vectors so they do not accumulate across calls
            double_block.clear_cached_vectors()
            # Append this layer's maps to the list
            cross_attention_maps.append(cross_attention_map)
        # Stack layers
        cross_attention_maps = torch.stack(cross_attention_maps).to(torch.float32)
        # Pull out the desired timesteps
        cross_attention_maps = cross_attention_maps[:, timesteps]
        # Pull out the desired layers
        cross_attention_maps = cross_attention_maps[layers]
        # Average over layers and time
        attention_maps = einops.reduce(
            cross_attention_maps,
            "layers time concepts height width -> concepts height width",
            reduction="mean"
        )
        # Keep only the attention maps corresponding to the prompt tokens
        attention_maps = attention_maps[:len(token_strings)]

        return attention_maps, token_strings, image
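
# Example usage (a minimal sketch, not part of the original module): assumes the
# concept_attention package is installed, a CUDA device is available, and the prompt
# below is purely illustrative. The exact type of `image` depends on what
# FluxGenerator.generate_image returns.
if __name__ == "__main__":
    daam = DAAM(model_name="flux-schnell", device="cuda", offload=True)
    attention_maps, tokens, image = daam("a photo of a dog on the beach")
    print(tokens)                  # token strings produced by the T5 tokenizer
    print(attention_maps.shape)    # (num_tokens, height, width)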