import PIL
import torch
from daam import trace
from diffusers import StableDiffusionPipeline
from diffusers.utils.torch_utils import randn_tensor

from concept_attention.segmentation import SegmentationAbstractClass
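
# Helper mirroring the retrieve_latents utility used in diffusers' img2img pipelines:
# pull latents out of a VAE encoder output by sampling the latent distribution,
# taking its mode, or reading precomputed latents directly.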
def retrieve_latents(encoder_output, generator, sample_mode="sample"):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
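
# DAAM-based segmentation baseline: encode a real image into Stable Diffusion 2 latents
# at a chosen timestep, run one denoising step while DAAM traces the cross-attention,
# and read a heat map per concept out of the trace.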
class DAAMStableDiffusion2SegmentationModel(SegmentationAbstractClass):

    def __init__(self, device='cuda:3'):
        # Load the Stable Diffusion 2 pipeline
        model_id = 'stabilityai/stable-diffusion-2-base'
        self.pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
        self.pipeline = self.pipeline.to(device)
        self.device = device

    def _encode_image(self, image: PIL.Image.Image, timestep, height=512, width=512):
        # Preprocess the image into a normalized tensor of shape (1, 3, height, width)
        init_image = self.pipeline.image_processor.preprocess(
            image,
            height=height,
            width=width,
        )
        # The VAE encoder expects float32 inputs
        init_image = init_image.to(dtype=torch.float32)
        init_image = init_image.to(device=self.device)
        # Encode to latents and apply the VAE scaling factor
        init_latents = retrieve_latents(self.pipeline.vae.encode(init_image), generator=None)
        init_latents = self.pipeline.vae.config.scaling_factor * init_latents
        init_latents = torch.cat([init_latents], dim=0)
        shape = init_latents.shape
        # Add noise at the requested timestep
        noise = randn_tensor(shape, generator=None, device=self.device, dtype=self.pipeline.dtype)
        init_latents = self.pipeline.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents
        return latents

    @torch.no_grad()
    def _model_forward_pass(
        self,
        image,
        prompt,
        timestep=49,
        guidance_scale=1.0,
        num_inference_steps=50,
        height=512,
        width=512,
        dtype=torch.float32,
        batch_size=1,
    ):
        # Set up the scheduler and look up the actual timestep from its schedule
        self.pipeline.scheduler.set_timesteps(num_inference_steps)
        timestep = self.pipeline.scheduler.timesteps[timestep]  # `timestep` is passed in as an index into scheduler.timesteps
        # # Encode the image
        # self.pipeline(
        #     image,
        #     device=self.device,
        #     num_images_per_prompt=1,
        #     output_hidden_states=None,
        # )
        ########################## Prepare latents ##########################
        image_latents = self._encode_image(
            image,
            timestep
        )
        # Noise is already added inside _encode_image at the chosen timestep
        # noise = randn_tensor(image_latents.shape, generator=generator, device=torch.device(self.device), dtype=dtype)
        # noisy_latents = self.pipeline.scheduler.add_noise(image_latents, noise, timestep.unsqueeze(0))
        # noisy_latents = self.pipeline.scheduler.scale_model_input(noisy_latents, timestep)
        # noisy_latents = noisy_latents.to(device=self.device, dtype=dtype)
        # Encode the prompt
        prompt_embeds, negative_prompt_embeds = self.pipeline.encode_prompt(
            prompt,
            self.device,
            1,     # num_images_per_prompt
            True,  # do_classifier_free_guidance
            None,  # negative_prompt
            # prompt_embeds=prompt_embeds,
            # negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=0.0,
            # clip_skip=self.pipeline.clip_skip,
        )
        ########################## Run forward pass ##########################
        # Single denoising step; only the conditional branch is run, so no classifier-free guidance is applied
        noise_pred = self.pipeline.unet(
            image_latents,
            timestep,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=None,
            cross_attention_kwargs=None,
            added_cond_kwargs=None,
            return_dict=False,
        )[0]
        ########################## Get and save predicted image ##########################
        # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        # do_denormalize = [True] * image.shape[0]
        # image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        # # Manually do the logic for the scheduler to get the original prediction
        # s_churn = 0.0
        # s_tmin = 0.0
        # s_tmax = float("inf")
        # s_noise = 1.0
        # # Upcast to avoid precision issues when computing prev_sample
        # sample = noisy_latents.to(torch.float32)
        # sigma = self.pipeline.scheduler.sigmas[self.pipeline.scheduler.index_for_timestep(timestep)]
        # gamma = min(s_churn / (len(self.pipeline.scheduler.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        # noise = randn_tensor(
        #     noise_pred.shape, dtype=noise_pred.dtype, device=noise_pred.device, generator=generator
        # )
        # eps = noise * s_noise
        # sigma_hat = sigma * (gamma + 1)
        # if gamma > 0:
        #     sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        # pred_original_sample = sample - sigma_hat * noise_pred
        # # For testing purposes, decode the predicted original latents to verify that the image was encoded properly
        # image = self.pipeline.vae.decode(pred_original_sample / self.pipeline.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        # image = self.pipeline.image_processor.postprocess(image, output_type="pil", do_denormalize=[True for _ in range(batch_size)])
        # The forward pass is only needed for its cross-attention side effects under DAAM tracing
        return None

    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
        # Concatenate the concepts onto the caption so every concept token appears in the prompt
        modified_caption = caption + ", " + ", ".join([f"a {concept}" for concept in concepts])
        # Run the forward pass inside the DAAM trace wrapper
        concept_heatmaps = []
        with trace(self.pipeline) as tc:
            _ = self._model_forward_pass(
                image,
                modified_caption,
                timestep=49,
                guidance_scale=7.0,
                num_inference_steps=50,
                height=512,
                width=512,
                dtype=torch.float32,
                batch_size=1,
            )
            heat_map = tc.compute_global_heat_map(prompt=modified_caption)
        # Extract one heat map per concept
        for concept in concepts:
            concept_heat_map = heat_map.compute_word_heat_map(concept).heatmap
            concept_heatmaps.append(concept_heat_map)
        concept_heatmaps = torch.stack(concept_heatmaps, dim=0)
        return concept_heatmaps, None
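

# Example usage sketch: assumes a CUDA device, access to the public
# stabilityai/stable-diffusion-2-base weights, and an illustrative local image
# path "example.png"; the caption and concepts are placeholders.
if __name__ == "__main__":
    model = DAAMStableDiffusion2SegmentationModel(device="cuda")
    # image_processor.preprocess also accepts PIL images, so a PIL image can be passed directly
    image = PIL.Image.open("example.png").convert("RGB").resize((512, 512))
    heatmaps, _ = model.segment_individual_image(
        image,
        concepts=["dog", "grass"],
        caption="a dog running on the grass",
    )
    print(heatmaps.shape)  # (num_concepts, H, W): one cross-attention heat map per concept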