import PIL
import torch
from daam import trace
from diffusers import DiffusionPipeline
from diffusers.utils.torch_utils import randn_tensor
from concept_attention.segmentation import SegmentationAbstractClass
class DAAMStableDiffusionXLSegmentationModel(SegmentationAbstractClass):
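    """DAAM-based segmentation baseline built on SDXL.

    Wraps the SDXL pipeline and traces a single denoising forward pass with
    `daam` to read out per-concept cross-attention heatmaps.
    """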
def __init__(self, device='cuda:3'):
# Load the SDXL Pipeline
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
self.pipeline = DiffusionPipeline.from_pretrained(
model_id,
use_auth_token=True,
torch_dtype=torch.float32,
use_safetensors=True
)
self.pipeline = self.pipeline.to(device)
self.device = device
def _encode_prompt(self, prompt, guidance_scale=0.0, device="cuda:0"):
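        # Uses both SDXL text encoders; `guidance_scale` is accepted for
        # interface symmetry but is not used in this helper.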
# Get the prompt embeddings
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipeline.encode_prompt(
prompt,
None,
device,
True,
negative_prompt=None,
# lora_scale=None,
# clip_skip=None,
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def _encode_image(self, image: PIL.Image.Image, generator=None):
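        # Encode a preprocessed image into VAE latents and apply the VAE scaling factor.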
image_latents = self.pipeline.vae.encode(image)
image_latents = image_latents.latent_dist.sample(generator)
image_latents = self.pipeline.vae.config.scaling_factor * image_latents
return image_latents
def _process_added_kwargs(
self,
prompt_embeds,
pooled_prompt_embeds,
height=512,
width=512,
):
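        # Build the SDXL micro-conditioning inputs (time ids and pooled text
        # embeddings) that the UNet consumes via `added_cond_kwargs`.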
add_text_embeds = pooled_prompt_embeds
if self.pipeline.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.pipeline.text_encoder_2.config.projection_dim
add_time_ids = self.pipeline._get_add_time_ids(
(height, width),
(0, 0),
(height, width),
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
        # Assemble the added conditioning kwargs dictionary
added_cond_kwargs = {
"time_ids": add_time_ids.to(device=self.device),
"text_embeds": pooled_prompt_embeds.to(device=self.device),
}
return added_cond_kwargs
@torch.no_grad()
def _model_forward_pass(
self,
image,
prompt,
timestep=49,
guidance_scale=1.0,
num_inference_steps=50,
height=512,
width=512,
dtype=torch.float32,
batch_size=1,
generator=None,
):
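        # Noise the input image to the chosen scheduler timestep and run a single
        # UNet forward pass so the DAAM hooks can record cross-attention maps;
        # the noise prediction itself is discarded and the method returns None.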
# Set up timesteps
self.pipeline.scheduler.set_timesteps(num_inference_steps)
########################## Prepare latents ##########################
init_image = self.pipeline.image_processor.preprocess(
image,
height=height,
width=width,
# crops_coords=None,
# resize_mode="default"
)
        init_image = init_image.to(dtype=torch.float32)  # Cast to float32; the VAE encoder fails in lower precision
init_image = init_image.to(device=self.device)
initial_image_latents = self._encode_image(init_image)
        # Look up the target timestep in the scheduler's schedule
timestep = self.pipeline.scheduler.timesteps[timestep]
# Encode the prompt
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
prompt,
guidance_scale=guidance_scale,
device=self.device
)
        # Build the additional conditioning kwargs
added_cond_kwargs = self._process_added_kwargs(
prompt_embeds,
pooled_prompt_embeds,
width=width,
height=height
)
        # Add noise at the selected timestep
noise = randn_tensor(initial_image_latents.shape, device=torch.device(self.device), dtype=dtype)
noisy_latents = self.pipeline.scheduler.add_noise(initial_image_latents, noise, timestep.unsqueeze(0))
noisy_latents = self.pipeline.scheduler.scale_model_input(noisy_latents, timestep)
noisy_latents = noisy_latents.to(device=self.device, dtype=dtype)
########################## Run forward pass ##########################
noise_pred = self.pipeline.unet(
noisy_latents,
timestep,
encoder_hidden_states=prompt_embeds,
timestep_cond=None,
cross_attention_kwargs=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
        ########################## Optional debug: reconstruct the predicted image ##########################
# # Manually do the logic for the scheduler to get the original prediction
# s_churn = 0.0
# s_tmin = 0.0
# s_tmax = float("inf")
# s_noise = 1.0
# # Upcast to avoid precision issues when computing prev_sample
# sample = noisy_latents.to(torch.float32)
# sigma = self.pipeline.scheduler.sigmas[self.pipeline.scheduler.index_for_timestep(timestep)]
# gamma = min(s_churn / (len(self.pipeline.scheduler.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
# noise = randn_tensor(
# noise_pred.shape, dtype=noise_pred.dtype, device=noise_pred.device, generator=generator
# )
# eps = noise * s_noise
# sigma_hat = sigma * (gamma + 1)
# if gamma > 0:
# sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# pred_original_sample = sample - sigma_hat * noise_pred
        # # For testing: decode the predicted original latents back to an image to verify that the input was encoded properly.
# image = self.pipeline.vae.decode(pred_original_sample / self.pipeline.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
# image = self.pipeline.image_processor.postprocess(image, output_type="pil", do_denormalize=[True for _ in range(batch_size)])
return None
def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, num_samples=1, num_inference_steps=50, **kwargs):
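        # Append the concept names to the caption, run one traced forward pass
        # per sample, extract a DAAM heatmap for each concept, and average the
        # heatmaps across samples.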
        # Concatenate the concept names onto the caption to form the prompt
        modified_caption = caption + "," + ", ".join([f"a {concept}" for concept in concepts])
        # Choose which timesteps to sample; a traced forward pass is run for each
if num_samples > 1:
timesteps = [49 for _ in range(num_samples)]
# timesteps = list(range(num_samples))
else:
timesteps = [49]
all_heatmaps = []
for timestep in timesteps:
with trace(self.pipeline) as tc:
_ = self._model_forward_pass(
image,
modified_caption,
timestep=timestep,
guidance_scale=7.0,
num_inference_steps=num_inference_steps,
height=512,
width=512,
dtype=torch.float32,
batch_size=1,
)
print(f"Modified Caption: {modified_caption}")
heat_map = tc.compute_global_heat_map(prompt=modified_caption)
concept_heatmaps = []
# For each concept make a heatmap
for concept in concepts:
concept_heat_map = heat_map.compute_word_heat_map(concept).heatmap
concept_heatmaps.append(concept_heat_map)
concept_heatmaps = torch.stack(concept_heatmaps, dim=0)
all_heatmaps.append(concept_heatmaps)
all_heatmaps = torch.stack(all_heatmaps, dim=0)
all_heatmaps = all_heatmaps.mean(0)
        return all_heatmaps, None
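

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a local
    # "example.jpg" (hypothetical path), an available CUDA device, and that the
    # SDXL weights can be fetched from the Hugging Face Hub.
    from PIL import Image

    model = DAAMStableDiffusionXLSegmentationModel(device="cuda:0")
    # The signature annotates `image` as a tensor, but the pipeline's image
    # processor also accepts PIL images, which is what we pass here.
    input_image = Image.open("example.jpg").convert("RGB").resize((512, 512))
    heatmaps, _ = model.segment_individual_image(
        input_image,
        concepts=["dog", "grass", "sky"],
        caption="a dog running on grass",
        num_samples=1,
    )
    print(heatmaps.shape)  # one heatmap per requested concept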