import PIL
import torch
from daam import trace
from diffusers import DiffusionPipeline
from diffusers.utils.torch_utils import randn_tensor

from concept_attention.segmentation import SegmentationAbstractClass


class DAAMStableDiffusionXLSegmentationModel(SegmentationAbstractClass):

    def __init__(self, device='cuda:3'):
        # Load the SDXL pipeline
        model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
        self.pipeline = DiffusionPipeline.from_pretrained(
            model_id,
            use_auth_token=True,
            torch_dtype=torch.float32,
            use_safetensors=True
        )
        self.pipeline = self.pipeline.to(device)
        self.device = device

    def _encode_prompt(self, prompt, guidance_scale=0.0, device="cuda:0"):
        # Get the prompt embeddings from both SDXL text encoders
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipeline.encode_prompt(
            prompt,
            None,
            device,
            True,
            negative_prompt=None,
            # lora_scale=None,
            # clip_skip=None,
        )
        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def _encode_image(self, image: PIL.Image.Image, generator=None):
        # Encode the image into VAE latent space and apply the scaling factor
        image_latents = self.pipeline.vae.encode(image)
        image_latents = image_latents.latent_dist.sample(generator)
        image_latents = self.pipeline.vae.config.scaling_factor * image_latents
        return image_latents

    def _process_added_kwargs(
        self,
        prompt_embeds,
        pooled_prompt_embeds,
        height=512,
        width=512,
    ):
        # Build the additional conditioning kwargs (time ids and pooled text embeddings) expected by the SDXL UNet
        add_text_embeds = pooled_prompt_embeds
        if self.pipeline.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.pipeline.text_encoder_2.config.projection_dim
        add_time_ids = self.pipeline._get_add_time_ids(
            (height, width),
            (0, 0),
            (height, width),
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        added_cond_kwargs = {
            "time_ids": add_time_ids.to(device=self.device),
            "text_embeds": add_text_embeds.to(device=self.device),
        }
        return added_cond_kwargs

    @torch.no_grad()
    def _model_forward_pass(
        self,
        image,
        prompt,
        timestep=49,
        guidance_scale=1.0,
        num_inference_steps=50,
        height=512,
        width=512,
        dtype=torch.float32,
        batch_size=1,
        generator=None,
    ):
        # Set up timesteps
        self.pipeline.scheduler.set_timesteps(num_inference_steps)
        ########################## Prepare latents ##########################
        init_image = self.pipeline.image_processor.preprocess(
            image,
            height=height,
            width=width,
            # crops_coords=None,
            # resize_mode="default"
        )
        init_image = init_image.to(dtype=torch.float32)  # Keep float32, otherwise the VAE encoder fails
        init_image = init_image.to(device=self.device)
        initial_image_latents = self._encode_image(init_image)
        # Convert the timestep index into the scheduler's actual timestep value
        timestep = self.pipeline.scheduler.timesteps[timestep]
        # Encode the prompt
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
            prompt,
            guidance_scale=guidance_scale,
            device=self.device
        )
        # Preprocess the text embeddings
        added_cond_kwargs = self._process_added_kwargs(
            prompt_embeds,
            pooled_prompt_embeds,
            width=width,
            height=height
        )
        # Add noise at the appropriate timestep
        noise = randn_tensor(initial_image_latents.shape, device=torch.device(self.device), dtype=dtype)
        noisy_latents = self.pipeline.scheduler.add_noise(initial_image_latents, noise, timestep.unsqueeze(0))
        noisy_latents = self.pipeline.scheduler.scale_model_input(noisy_latents, timestep)
        noisy_latents = noisy_latents.to(device=self.device, dtype=dtype)
        ########################## Run forward pass ##########################
        # Single denoising step using only the conditional embeddings (no classifier-free guidance)
        noise_pred = self.pipeline.unet(
            noisy_latents,
            timestep,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=None,
            cross_attention_kwargs=None,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
        ########################## Get and save predicted image ##########################
        # # Manually do the scheduler logic to get the predicted original sample
        # s_churn = 0.0
        # s_tmin = 0.0
        # s_tmax = float("inf")
        # s_noise = 1.0
        # # Upcast to avoid precision issues when computing prev_sample
        # sample = noisy_latents.to(torch.float32)
        # sigma = self.pipeline.scheduler.sigmas[self.pipeline.scheduler.index_for_timestep(timestep)]
        # gamma = min(s_churn / (len(self.pipeline.scheduler.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        # noise = randn_tensor(
        #     noise_pred.shape, dtype=noise_pred.dtype, device=noise_pred.device, generator=generator
        # )
        # eps = noise * s_noise
        # sigma_hat = sigma * (gamma + 1)
        # if gamma > 0:
        #     sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        # pred_original_sample = sample - sigma_hat * noise_pred
        # # For testing purposes, decode the predicted original latents and generate the image to verify that the input was encoded properly.
        # image = self.pipeline.vae.decode(pred_original_sample / self.pipeline.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        # image = self.pipeline.image_processor.postprocess(image, output_type="pil", do_denormalize=[True for _ in range(batch_size)])
        return None

    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, num_samples=1, num_inference_steps=50, **kwargs):
        # Concatenate the concepts onto the caption
        modified_caption = caption + "," + ", ".join([f"a {concept}" for concept in concepts])
        # Choose which timestep indices to sample
        if num_samples > 1:
            timesteps = [49 for _ in range(num_samples)]
            # timesteps = list(range(num_samples))
        else:
            timesteps = [49]
        # Run the forward pass under the DAAM trace wrapper
        all_heatmaps = []
        for timestep in timesteps:
            with trace(self.pipeline) as tc:
                _ = self._model_forward_pass(
                    image,
                    modified_caption,
                    timestep=timestep,
                    guidance_scale=7.0,
                    num_inference_steps=num_inference_steps,
                    height=512,
                    width=512,
                    dtype=torch.float32,
                    batch_size=1,
                )
                print(f"Modified Caption: {modified_caption}")
                heat_map = tc.compute_global_heat_map(prompt=modified_caption)
                # Make a heatmap for each concept
                concept_heatmaps = []
                for concept in concepts:
                    concept_heat_map = heat_map.compute_word_heat_map(concept).heatmap
                    concept_heatmaps.append(concept_heat_map)
                concept_heatmaps = torch.stack(concept_heatmaps, dim=0)
            all_heatmaps.append(concept_heatmaps)
        # Average the per-concept heatmaps over the sampled timesteps
        all_heatmaps = torch.stack(all_heatmaps, dim=0)
        all_heatmaps = all_heatmaps.mean(0)
        return all_heatmaps, None
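

# Minimal usage sketch (not part of the original module), assuming DAAM, diffusers,
# and the SDXL weights are available and a CUDA device is present. The image path,
# caption, and concept list below are hypothetical placeholders; any post-processing
# of the returned heatmaps (resizing, thresholding) is left to the caller.
if __name__ == "__main__":
    from PIL import Image

    model = DAAMStableDiffusionXLSegmentationModel(device="cuda:0")
    input_image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    concepts = ["dog", "grass", "sky"]                      # hypothetical concepts
    caption = "a dog running on grass under a blue sky"     # hypothetical caption
    heatmaps, _ = model.segment_individual_image(
        input_image,
        concepts,
        caption,
        num_samples=1,
        num_inference_steps=50,
    )
    # One heatmap per concept, e.g. torch.Size([3, H, W])
    print(heatmaps.shape)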