""" | |
This baseline just returns heatmaps as the raw cross attentions. | |
""" | |
from concept_attention.flux.src.flux.sampling import prepare, unpack | |
import torch | |
import einops | |
import PIL | |
from concept_attention.image_generator import FluxGenerator | |
from concept_attention.segmentation import SegmentationAbstractClass, add_noise_to_image, encode_image | |
from concept_attention.utils import embed_concepts, linear_normalization | |

class RawOutputSpaceBaseline:
    """
    Implements the raw output space baseline, which returns heatmaps
    computed from the raw cross attentions.
    """

    def __init__(
        self,
        model_name: str = "flux-schnell",
        device: str = "cuda",
        offload: bool = True,
        generator=None
    ):
        super().__init__()
        # Load up the flux generator
        if generator is None:
            self.generator = FluxGenerator(
                model_name=model_name,
                device=device,
                offload=offload,
            )
        else:
            self.generator = generator
        # Unpack the tokenizer
        self.tokenizer = self.generator.t5.tokenizer
    def __call__(
        self,
        prompt,
        concepts,
        seed=4,
        num_steps=4,
        timesteps=None,
        layers=None,
        softmax=False,
        height=1024,
        width=1024,
        guidance=0.0,
    ):
""" | |
Generate cross attention heatmap visualizations. | |
Args: | |
- prompt: str, the prompt to generate the visualizations for | |
- seed: int, the seed to use for the visualization | |
Returns: | |
- attention_maps: torch.Tensor, the attention maps for the prompt | |
- tokens: list[str], the tokens in the prompt | |
- image: torch.Tensor, the image generated by the | |
""" | |
        if timesteps is None:
            timesteps = list(range(num_steps))
        if layers is None:
            layers = list(range(19))
        # Run the image generator
        image, _, all_concept_heatmaps = self.generator.generate_image(
            width=width,
            height=height,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
            prompt=prompt,
            concepts=concepts
        )
        # Optionally normalize across concepts with a softmax over the concept axis
        if softmax:
            all_concept_heatmaps = torch.nn.functional.softmax(all_concept_heatmaps, dim=-2)
        # Select the requested layers and average over time and layers
        concept_heatmaps = all_concept_heatmaps[:, layers]
        concept_heatmaps = einops.reduce(
            concept_heatmaps,
            "time layers batch concepts patches -> batch concepts patches",
            reduction="mean"
        )
        # Convert to torch float32
        concept_heatmaps = concept_heatmaps.to(torch.float32)
        # Reshape the flattened patch axis into a square grid
        # (a 1024x1024 image corresponds to a 64x64 grid of latent patches)
        concept_heatmaps = einops.rearrange(
            concept_heatmaps,
            "batch concepts (h w) -> batch concepts h w",
            h=64,
            w=64
        )
        return concept_heatmaps, image
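
# A minimal usage sketch (assumptions: CUDA is available, the flux-schnell
# weights can be loaded, and the prompt/concepts below are illustrative):
#
#   baseline = RawOutputSpaceBaseline(model_name="flux-schnell", device="cuda")
#   concept_heatmaps, image = baseline(
#       prompt="a dog running on the beach",
#       concepts=["dog", "beach", "sky"],
#   )
#   # concept_heatmaps has shape (batch, num_concepts, 64, 64); each slice is
#   # a spatial heatmap for one concept.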

class RawOutputSpaceSegmentationModel(SegmentationAbstractClass):

    def __init__(
        self,
        model_name: str = "flux-schnell",
        device: str = "cuda",
        offload: bool = True,
        generator=None,
    ):
        """
        Initialize the segmentation model.
        """
        super().__init__()
        if generator is not None:
            self.generator = generator
        else:
            # Load up the flux generator
            self.generator = FluxGenerator(
                model_name=model_name,
                device=device,
                offload=offload,
            )
        self.is_schnell = "schnell" in model_name
    def segment_individual_image(
        self,
        image: PIL.Image.Image,
        concepts: list[str],
        caption: str,
        device: str = "cuda",
        offload: bool = False,
        num_samples: int = 1,
        num_steps: int = 4,
        noise_timestep: int = 2,
        seed: int = 4,
        width: int = 1024,
        height: int = 1024,
        stop_after_multimodal_attentions: bool = True,
        layers: list[int] | None = None,
        normalize_concepts=True,
        softmax: bool = False,
        joint_attention_kwargs=None,
        **kwargs
    ):
""" | |
Takes a real image and generates a segmentation map. | |
""" | |
# Encode the image into the VAE latent space | |
encoded_image_without_noise = encode_image( | |
image, | |
self.generator.ae, | |
offload=offload, | |
device=device, | |
) | |
        # Run num_samples independent noisy trials
        all_concept_heatmaps = []
        for i in range(num_samples):
            # Add noise to the encoded image
            encoded_image, timesteps = add_noise_to_image(
                encoded_image_without_noise,
                num_steps=num_steps,
                noise_timestep=noise_timestep,
                seed=seed + i,
                width=width,
                height=height,
                device=device,
                is_schnell=self.is_schnell,
            )
            # Encode the caption and the concept vectors
            if offload:
                self.generator.t5, self.generator.clip = self.generator.t5.to(device), self.generator.clip.to(device)
            inp = prepare(t5=self.generator.t5, clip=self.generator.clip, img=encoded_image, prompt=caption)
            concept_embeddings, concept_ids, concept_vec = embed_concepts(
                self.generator.clip,
                self.generator.t5,
                concepts,
            )
            inp["concepts"] = concept_embeddings.to(encoded_image.device)
            inp["concept_ids"] = concept_ids.to(encoded_image.device)
            inp["concept_vec"] = concept_vec.to(encoded_image.device)
            # Offload text encoders to CPU and load the diffusion model onto the GPU
            if offload:
                self.generator.t5, self.generator.clip = self.generator.t5.cpu(), self.generator.clip.cpu()
                torch.cuda.empty_cache()
                self.generator.model = self.generator.model.to(device)
            # Run the diffusion model once on the noisy image
            guidance_vec = torch.full((encoded_image.shape[0],), 0.0, device=encoded_image.device, dtype=encoded_image.dtype)
            t_curr = timesteps[0]
            t_prev = timesteps[1]
            t_vec = torch.full((encoded_image.shape[0],), t_curr, dtype=encoded_image.dtype, device=encoded_image.device)
            pred, _, concept_heatmaps = self.generator.model(
                img=inp["img"],
                img_ids=inp["img_ids"],
                txt=inp["txt"],
                txt_ids=inp["txt_ids"],
                concepts=inp["concepts"],
                concept_ids=inp["concept_ids"],
                concept_vec=inp["concept_vec"],
                y=inp["concept_vec"],
                timesteps=t_vec,
                guidance=guidance_vec,
                stop_after_multimodal_attentions=stop_after_multimodal_attentions,
                joint_attention_kwargs=joint_attention_kwargs,
            )
            all_concept_heatmaps.append(concept_heatmaps)
        all_concept_heatmaps = torch.stack(all_concept_heatmaps, dim=0)
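        # The stacked heatmaps have shape
        # (num_samples, num_layers, batch, num_concepts, num_patches),
        # matching the einops reduce pattern applied below.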
        if not stop_after_multimodal_attentions:
            # Take one Euler step and decode the latents to pixel space
            img = inp["img"] + (t_prev - t_curr) * pred
            img = unpack(img.float(), height, width)
            with torch.autocast(device_type=self.generator.device.type, dtype=torch.bfloat16):
                img = self.generator.ae.decode(img)
            if self.generator.offload:
                self.generator.ae.decoder.cpu()
                torch.cuda.empty_cache()
            img = img.clamp(-1, 1)
            img = einops.rearrange(img[0], "c h w -> h w c")
            # Map from [-1, 1] to [0, 255] and convert to a PIL image
            reconstructed_image = PIL.Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy())
        else:
            img = None
            reconstructed_image = None
        # Offload the diffusion model and move the decoder to the device
        if offload:
            self.generator.model.cpu()
            torch.cuda.empty_cache()
            self.generator.ae.decoder.to(device)
        # NOTE: `normalize_concepts` is currently unused; linear normalization
        # of the concept vectors is not applied in this baseline.
        # Optionally normalize across concepts with a softmax over the concept axis
        if softmax:
            all_concept_heatmaps = torch.nn.functional.softmax(all_concept_heatmaps, dim=-2)
        # Select the requested layers and average over samples and layers
        concept_heatmaps = all_concept_heatmaps[:, layers]
        concept_heatmaps = einops.reduce(
            concept_heatmaps,
            "samples layers batch concepts patches -> batch concepts patches",
            reduction="mean"
        )
        # Convert to torch float32
        concept_heatmaps = concept_heatmaps.to(torch.float32)
        # Reshape the flattened patch axis into a square grid
        # (a 1024x1024 image corresponds to a 64x64 grid of latent patches)
        concept_heatmaps = einops.rearrange(
            concept_heatmaps,
            "batch concepts (h w) -> batch concepts h w",
            h=64,
            w=64
        )
return concept_heatmaps, reconstructed_image |
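
# A minimal usage sketch (assumptions: CUDA is available, the flux-schnell
# weights can be loaded, and the image path, caption, and concepts below are
# illustrative):
#
#   model = RawOutputSpaceSegmentationModel()
#   concept_heatmaps, _ = model.segment_individual_image(
#       image=PIL.Image.open("dog.png").convert("RGB"),
#       concepts=["dog", "grass", "sky"],
#       caption="a dog lying on the grass",
#   )
#   # Assign each patch to its highest-scoring concept to get a segmentation map
#   segmentation = concept_heatmaps.argmax(dim=1)  # (batch, 64, 64)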