""" This baseline just returns heatmaps as the raw cross attentions. """ from concept_attention.flux.src.flux.sampling import prepare, unpack import torch import einops import PIL from concept_attention.image_generator import FluxGenerator from concept_attention.segmentation import SegmentationAbstractClass, add_noise_to_image, encode_image from concept_attention.utils import embed_concepts, linear_normalization class RawOutputSpaceBaseline(): """ This class implements the cross attention baseline. """ def __init__( self, model_name: str = "flux-schnell", device: str = "cuda", offload: bool = True, generator = None ): super(RawOutputSpaceBaseline, self).__init__() # Load up the flux generator if generator is None: self.generator = FluxGenerator( model_name=model_name, device=device, offload=offload, ) else: self.generator = generator # Unpack the tokenizer self.tokenizer = self.generator.t5.tokenizer def __call__( self, prompt, concepts, seed=4, num_steps=4, timesteps=None, layers=list(range(19)), softmax=False, height=1024, width=1024, guidance=0.0, ): """ Generate cross attention heatmap visualizations. Args: - prompt: str, the prompt to generate the visualizations for - seed: int, the seed to use for the visualization Returns: - attention_maps: torch.Tensor, the attention maps for the prompt - tokens: list[str], the tokens in the prompt - image: torch.Tensor, the image generated by the """ if timesteps is None: timesteps = list(range(num_steps)) if layers is None: layers = list(range(19)) # Run the image generator image, _, all_concept_heatmaps = self.generator.generate_image( width=height, height=width, num_steps=num_steps, guidance=guidance, seed=seed, prompt=prompt, concepts=concepts ) # Apply softmax if softmax: all_concept_heatmaps = torch.nn.functional.softmax(all_concept_heatmaps, dim=-2) concept_heatmaps = all_concept_heatmaps[:, layers] concept_heatmaps = einops.reduce( concept_heatmaps, "time layers batch concepts patches -> batch concepts patches", reduction="mean" ) # Convert to torch float32 concept_heatmaps = concept_heatmaps.to(torch.float32) concept_heatmaps = einops.rearrange( concept_heatmaps, "batch concepts (h w) -> batch concepts h w", h=64, w=64 ) return concept_heatmaps, image class RawOutputSpaceSegmentationModel(SegmentationAbstractClass): def __init__( self, model_name: str = "flux-schnell", device: str = "cuda", offload: bool = True, generator=None, ): """ Initialize the segmentation model. """ super(RawOutputSpaceSegmentationModel, self).__init__() if generator is not None: self.generator = generator else: # Load up the flux generator self.generator = FluxGenerator( model_name=model_name, device=device, offload=offload, ) self.is_schnell = "schnell" in model_name def segment_individual_image( self, image: PIL.Image.Image, concepts: list[str], caption: str, device: str = "cuda", offload: bool = False, num_samples: int = 1, num_steps: int = 4, noise_timestep: int = 2, seed: int = 4, width: int = 1024, height: int = 1024, stop_after_multimodal_attentions: bool = True, layers: list[int] = list(range(19)), normalize_concepts=True, softmax: bool = False, joint_attention_kwargs=None, **kwargs ): """ Takes a real image and generates a segmentation map. 
""" # Encode the image into the VAE latent space encoded_image_without_noise = encode_image( image, self.generator.ae, offload=offload, device=device, ) # Do N trials all_concept_heatmaps = [] for i in range(num_samples): # Add noise to image encoded_image, timesteps = add_noise_to_image( encoded_image_without_noise, num_steps=num_steps, noise_timestep=noise_timestep, seed=seed + i, width=width, height=height, device=device, is_schnell=self.is_schnell, ) # Now run the diffusion model once on the noisy image # Encode the concept vectors if offload: self.generator.t5, self.generator.clip = self.generator.t5.to(device), self.generator.clip.to(device) inp = prepare(t5=self.generator.t5, clip=self.generator.clip, img=encoded_image, prompt=caption) concept_embeddings, concept_ids, concept_vec = embed_concepts( self.generator.clip, self.generator.t5, concepts, ) inp["concepts"] = concept_embeddings.to(encoded_image.device) inp["concept_ids"] = concept_ids.to(encoded_image.device) inp["concept_vec"] = concept_vec.to(encoded_image.device) # offload TEs to CPU, load model to gpu if offload: self.generator.t5, self.generator.clip = self.generator.t5.cpu(), self.generator.clip.cpu() torch.cuda.empty_cache() self.generator.model = self.generator.model.to(device) # Denoise the intermediate images guidance_vec = torch.full((encoded_image.shape[0],), 0.0, device=encoded_image.device, dtype=encoded_image.dtype) t_curr = timesteps[0] t_prev = timesteps[1] t_vec = torch.full((encoded_image.shape[0],), t_curr, dtype=encoded_image.dtype, device=encoded_image.device) pred, _, concept_heatmaps = self.generator.model( img=inp["img"], img_ids=inp["img_ids"], txt=inp["txt"], txt_ids=inp["txt_ids"], concepts=inp["concepts"], concept_ids=inp["concept_ids"], concept_vec=inp["concept_vec"], y=inp["concept_vec"], timesteps=t_vec, guidance=guidance_vec, stop_after_multimodal_attentions=stop_after_multimodal_attentions, joint_attention_kwargs=joint_attention_kwargs, ) all_concept_heatmaps.append(concept_heatmaps) all_concept_heatmaps = torch.stack(all_concept_heatmaps, dim=0) if not stop_after_multimodal_attentions: img = inp["img"] + (t_prev - t_curr) * pred # decode latents to pixel space img = unpack(img.float(), height, width) with torch.autocast(device_type=self.generator.device.type, dtype=torch.bfloat16): img = self.generator.ae.decode(img) if self.generator.offload: self.generator.ae.decoder.cpu() torch.cuda.empty_cache() img = img.clamp(-1, 1) img = einops.rearrange(img[0], "c h w -> h w c") # reconstructed_image = PIL.Image.fromarray(img.cpu().byte().numpy()) reconstructed_image = PIL.Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy()) else: img = None reconstructed_image = None # Decode the image if offload: self.generator.model.cpu() torch.cuda.empty_cache() self.generator.ae.decoder.to(device) # if layers is not None: # # Pull out the layer index # concept_vectors = concept_vectors[layers] # image_vectors = image_vectors[layers] # Apply linear normalization to concepts # if normalize_concepts: # concept_vectors = linear_normalization(concept_vectors, dim=-2) # Apply softmax if softmax: all_concept_heatmaps = torch.nn.functional.softmax(all_concept_heatmaps, dim=-2) concept_heatmaps = all_concept_heatmaps[:, layers] concept_heatmaps = einops.reduce( concept_heatmaps, "samples layers batch concepts patches -> batch concepts patches", reduction="mean" ) # Convert to torch float32 concept_heatmaps = concept_heatmaps.to(torch.float32) concept_heatmaps = einops.rearrange( concept_heatmaps, "batch concepts 
(h w) -> batch concepts h w", h=64, w=64 ) return concept_heatmaps, reconstructed_image
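

# Minimal usage sketch (illustrative only, not part of the original pipeline):
# it assumes a CUDA GPU with the Flux weights available to FluxGenerator; the
# prompt, concepts, caption, and image path "example.png" are placeholders.
if __name__ == "__main__":
    # Generate an image and per-concept cross attention heatmaps for it.
    baseline = RawOutputSpaceBaseline(model_name="flux-schnell", device="cuda", offload=True)
    concept_heatmaps, image = baseline(
        prompt="a dog running on the beach",
        concepts=["dog", "beach", "sky"],
        num_steps=4,
        seed=4,
    )
    print(concept_heatmaps.shape)  # expected: (batch, num_concepts, 64, 64)

    # Segment a real image by noising it and reading out the attention heatmaps,
    # reusing the generator that is already loaded above.
    segmentation_model = RawOutputSpaceSegmentationModel(generator=baseline.generator)
    real_image = PIL.Image.open("example.png").convert("RGB").resize((1024, 1024))
    seg_heatmaps, _ = segmentation_model.segment_individual_image(
        image=real_image,
        concepts=["dog", "background"],
        caption="a dog running on the beach",
        device="cuda",
        num_samples=1,
        num_steps=4,
        noise_timestep=2,
    )
    print(seg_heatmaps.shape)  # expected: (batch, num_concepts, 64, 64)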