import PIL.Image
import torch
from daam import trace
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.utils.torch_utils import randn_tensor
import matplotlib.pyplot as plt
from concept_attention.segmentation import SegmentationAbstractClass
def retrieve_latents(encoder_output, generator, sample_mode="sample"):
    """Pull latents out of a VAE encoder output, handling its possible return types."""
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
class DAAMStableDiffusion2SegmentationModel(SegmentationAbstractClass):

    def __init__(self, device='cuda:3'):
        # Load the Stable Diffusion 2 base pipeline
        model_id = 'stabilityai/stable-diffusion-2-base'
        self.pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
        self.pipeline = self.pipeline.to(device)
        self.device = device
    def _encode_image(self, image: PIL.Image.Image, timestep, height=512, width=512):
        # Preprocess the image
        init_image = self.pipeline.image_processor.preprocess(
            image,
            height=height,
            width=width,
        )
        # Cast to float32, otherwise the VAE encoder fails
        init_image = init_image.to(dtype=torch.float32)
        init_image = init_image.to(device=self.device)
        # Encode the image into latents and apply the VAE scaling factor
        init_latents = retrieve_latents(self.pipeline.vae.encode(init_image), generator=None)
        init_latents = self.pipeline.vae.config.scaling_factor * init_latents
        init_latents = torch.cat([init_latents], dim=0)
        shape = init_latents.shape
        # Add noise at the given timestep
        noise = randn_tensor(shape, generator=None, device=self.device, dtype=self.pipeline.dtype)
        init_latents = self.pipeline.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents
        return latents
    @torch.no_grad()
    def _model_forward_pass(
        self,
        image,
        prompt,
        timestep=49,
        guidance_scale=1.0,
        num_inference_steps=50,
        height=512,
        width=512,
        dtype=torch.float32,
        batch_size=1,
        generator=None,
    ):
        # Set up the schedule and select the discrete timestep to use
        self.pipeline.scheduler.set_timesteps(num_inference_steps)
        timestep = self.pipeline.scheduler.timesteps[timestep]  # .to(device=device, dtype=dtype)
        # # Encode the image
        # self.pipeline(
        #     image,
        #     device=self.device,
        #     num_images_per_prompt=1,
        #     output_hidden_states=None,
        # )
        ########################## Prepare latents ##########################
        image_latents = self._encode_image(
            image,
            timestep
        )
        # Add noise at the appropriate timescale
        # noise = randn_tensor(image_latents.shape, generator=generator, device=torch.device(self.device), dtype=dtype)
        # noisy_latents = self.pipeline.scheduler.add_noise(image_latents, noise, timestep.unsqueeze(0))
        # noisy_latents = self.pipeline.scheduler.scale_model_input(noisy_latents, timestep)
        # noisy_latents = noisy_latents.to(device=self.device, dtype=dtype)
        # Encode the prompt (classifier-free guidance enabled, no negative prompt)
        prompt_embeds, negative_prompt_embeds = self.pipeline.encode_prompt(
            prompt,
            self.device,
            1,  # num_images_per_prompt
            True,  # do_classifier_free_guidance
            None,  # negative_prompt
            # prompt_embeds=prompt_embeds,
            # negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=0.0,
            # clip_skip=self.pipeline.clip_skip,
        )
        ########################## Run forward pass ##########################
        noise_pred = self.pipeline.unet(
            image_latents,
            timestep,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=None,
            cross_attention_kwargs=None,
            added_cond_kwargs=None,
            return_dict=False,
        )[0]
        ########################## Get and save predicted image ##########################
        # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        # do_denormalize = [True] * image.shape[0]
        # image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        # # Manually do the logic for the scheduler to get the original prediction
        # s_churn = 0.0
        # s_tmin = 0.0
        # s_tmax = float("inf")
        # s_noise = 1.0
        # # Upcast to avoid precision issues when computing prev_sample
        # sample = noisy_latents.to(torch.float32)
        # sigma = self.pipeline.scheduler.sigmas[self.pipeline.scheduler.index_for_timestep(timestep)]
        # gamma = min(s_churn / (len(self.pipeline.scheduler.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        # noise = randn_tensor(
        #     noise_pred.shape, dtype=noise_pred.dtype, device=noise_pred.device, generator=generator
        # )
        # eps = noise * s_noise
        # sigma_hat = sigma * (gamma + 1)
        # if gamma > 0:
        #     sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
        # pred_original_sample = sample - sigma_hat * noise_pred
        # # For testing purposes get the predicted original latents and generate the image for it to verify that the image was encoded properly.
        # image = self.pipeline.vae.decode(pred_original_sample / self.pipeline.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        # image = self.pipeline.image_processor.postprocess(image, output_type="pil", do_denormalize=[True for _ in range(batch_size)])
        # The noise prediction itself is discarded; this forward pass exists only so that the
        # DAAM trace hooks can record the cross-attention maps.
        return None
    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
        # Concatenate the concepts onto the caption
        modified_caption = caption + ", " + ", ".join([f"a {concept}" for concept in concepts])
        # Run the forward pass inside the DAAM trace wrapper
        concept_heatmaps = []
        with trace(self.pipeline) as tc:
            _ = self._model_forward_pass(
                image,
                caption,
                timestep=49,
                guidance_scale=7.0,
                num_inference_steps=50,
                height=512,
                width=512,
                dtype=torch.float32,
                batch_size=1,
            )
            heat_map = tc.compute_global_heat_map(prompt=modified_caption)
            # For each concept, compute a word-level heatmap
            for concept in concepts:
                concept_heat_map = heat_map.compute_word_heat_map(concept).heatmap
                concept_heatmaps.append(concept_heat_map)
        concept_heatmaps = torch.stack(concept_heatmaps, dim=0)
        return concept_heatmaps, None
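

# Example usage (sketch, not part of the original file): assumes a CUDA device is available,
# and that the image path, caption, and concepts below are placeholders to replace with your
# own data. The PIL image is passed directly, which the pipeline's image processor accepts.
if __name__ == "__main__":
    model = DAAMStableDiffusion2SegmentationModel(device="cuda:0")
    image = PIL.Image.open("example.jpg").convert("RGB")  # hypothetical input image
    concepts = ["dog", "grass"]  # hypothetical concepts to localize
    heatmaps, _ = model.segment_individual_image(
        image,
        concepts=concepts,
        caption="a dog running on the grass",  # hypothetical caption describing the image
    )
    # One heatmap per concept; visualize them side by side with matplotlib.
    fig, axes = plt.subplots(1, len(concepts))
    for ax, concept, heatmap in zip(axes, concepts, heatmaps):
        ax.imshow(heatmap.cpu().numpy())
        ax.set_title(concept)
        ax.axis("off")
    plt.savefig("daam_concept_heatmaps.png")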