import PIL.Image
import torch
from daam import trace
from diffusers import StableDiffusionPipeline
from diffusers.utils.torch_utils import randn_tensor

from concept_attention.segmentation import SegmentationAbstractClass
def retrieve_latents(encoder_output, generator, sample_mode="sample"):
    """Extract latents from a VAE encoder output, either sampling or taking the mode."""
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
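
# A minimal usage sketch (illustrative, not called in this file): in diffusers,
# `vae.encode` returns an AutoencoderKLOutput whose `latent_dist` this helper samples.
#   posterior = pipeline.vae.encode(pixel_values)
#   latents = retrieve_latents(posterior, generator=None)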
class DAAMStableDiffusion2SegmentationModel(SegmentationAbstractClass):

    def __init__(self, device='cuda:3'):
        # Load the Stable Diffusion 2 base pipeline
        model_id = 'stabilityai/stable-diffusion-2-base'
        self.pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
        self.pipeline = self.pipeline.to(device)
        self.device = device
    def _encode_image(self, image: PIL.Image.Image, timestep, height=512, width=512):
        # Preprocess the image into a normalized tensor
        init_image = self.pipeline.image_processor.preprocess(
            image,
            height=height,
            width=width,
        )
        # Cast to float32: the VAE encoder fails on half-precision inputs here
        init_image = init_image.to(dtype=torch.float32)
        init_image = init_image.to(device=self.device)
        # Encode to latents and apply the VAE scaling factor
        init_latents = retrieve_latents(self.pipeline.vae.encode(init_image), generator=None)
        init_latents = self.pipeline.vae.config.scaling_factor * init_latents
        shape = init_latents.shape
        # Add noise at the given timestep
        noise = randn_tensor(shape, generator=None, device=torch.device(self.device), dtype=self.pipeline.dtype)
        init_latents = self.pipeline.scheduler.add_noise(init_latents, noise, timestep)
        return init_latents
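
    # Shape sanity check (illustrative; assumes a 512x512 RGB input): the SD2 VAE
    # downsamples by a factor of 8 into 4 latent channels, so the tensor returned
    # above comes out as [1, 4, 64, 64] before noise is added:
    #   latents = self._encode_image(image, timestep)
    #   assert latents.shape == (1, 4, 64, 64)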
    def _model_forward_pass(
        self,
        image,
        prompt,
        timestep=49,
        guidance_scale=1.0,
        num_inference_steps=50,
        height=512,
        width=512,
        dtype=torch.float32,
        batch_size=1,
        generator=None,
    ):
        # Set up timesteps and pick the requested one. Note: scheduler timesteps run
        # from high noise to low noise, so index 49 of 50 is the least noisy step.
        self.pipeline.scheduler.set_timesteps(num_inference_steps)
        timestep = self.pipeline.scheduler.timesteps[timestep]
        ########################## Prepare latents ##########################
        # Encode the image; noise at the chosen timestep is added inside _encode_image
        image_latents = self._encode_image(
            image,
            timestep,
        )
        # Encode the prompt into text-encoder hidden states
        prompt_embeds, negative_prompt_embeds = self.pipeline.encode_prompt(
            prompt,
            self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=None,
            lora_scale=0.0,
        )
        ########################## Run forward pass ##########################
        noise_pred = self.pipeline.unet(
            image_latents,
            timestep,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=None,
            cross_attention_kwargs=None,
            added_cond_kwargs=None,
            return_dict=False,
        )[0]
        ########################## Get and save predicted image ##########################
        # For testing: recover the predicted original latents and decode them, to
        # verify that the image was encoded properly (Euler-scheduler style math).
        # s_churn = 0.0
        # s_tmin = 0.0
        # s_tmax = float("inf")
        # s_noise = 1.0
        # # Upcast to avoid precision issues when computing prev_sample
        # sample = image_latents.to(torch.float32)
        # sigma = self.pipeline.scheduler.sigmas[self.pipeline.scheduler.index_for_timestep(timestep)]
        # gamma = min(s_churn / (len(self.pipeline.scheduler.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        # noise = randn_tensor(
        #     noise_pred.shape, dtype=noise_pred.dtype, device=noise_pred.device, generator=generator
        # )
        # eps = noise * s_noise
        # sigma_hat = sigma * (gamma + 1)
        # if gamma > 0:
        #     sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        # pred_original_sample = sample - sigma_hat * noise_pred
        # image = self.pipeline.vae.decode(pred_original_sample / self.pipeline.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        # image = self.pipeline.image_processor.postprocess(image, output_type="pil", do_denormalize=[True for _ in range(batch_size)])
        # Nothing is returned: DAAM's trace hooks capture the cross-attention maps
        # during the UNet call above, which is all the caller needs.
        return None
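
    # A simpler debug sketch (illustrative; assumes an epsilon-prediction model and
    # a scheduler exposing `alphas_cumprod`, e.g. the default PNDM scheduler), using
    # x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps solved for x_0:
    #   alpha_bar = self.pipeline.scheduler.alphas_cumprod[timestep]
    #   pred_x0 = (image_latents - (1 - alpha_bar).sqrt() * noise_pred) / alpha_bar.sqrt()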
    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
        # Concatenate the concepts onto the caption so DAAM can attribute them
        modified_caption = caption + ", " + ", ".join([f"a {concept}" for concept in concepts])
        # Run the forward pass inside the DAAM trace wrapper
        concept_heatmaps = []
        with trace(self.pipeline) as tc:
            _ = self._model_forward_pass(
                image,
                modified_caption,  # use the concept-augmented prompt so it matches the heat-map prompt below
                timestep=49,
                guidance_scale=7.0,
                num_inference_steps=50,
                height=512,
                width=512,
                dtype=torch.float32,
                batch_size=1,
            )
            heat_map = tc.compute_global_heat_map(prompt=modified_caption)
            # Extract a heat map for each concept word
            for concept in concepts:
                concept_heat_map = heat_map.compute_word_heat_map(concept).heatmap
                concept_heatmaps.append(concept_heat_map)
        concept_heatmaps = torch.stack(concept_heatmaps, dim=0)
        return concept_heatmaps, None
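
# A minimal usage sketch (illustrative; assumes a single available GPU and a
# local test image -- the file path and concept list here are made up):
#   from PIL import Image
#   model = DAAMStableDiffusion2SegmentationModel(device='cuda:0')
#   image = Image.open('example.jpg').convert('RGB')
#   heatmaps, _ = model.segment_individual_image(
#       image,
#       concepts=['cat', 'sky'],
#       caption='a photo of a cat under the sky',
#   )
#   # heatmaps: one 2D cross-attention heat map per concept, stacked along dim 0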