# File size: 10,647 Bytes (revision 188fd89)
import inspect
from typing import List, Optional, Union, Tuple
import numpy as np
import torch
from PIL import Image
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
PNDMScheduler,
LMSDiscreteScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
def preprocess_init_image(image: Image.Image, width: int, height: int) -> torch.Tensor:
    """Convert a PIL image into a normalized, batched image tensor.

    The image is converted to RGB, resized to ``(width, height)`` with Lanczos
    resampling, rescaled from ``[0, 255]`` to ``[-1, 1]`` and rearranged from
    height/width/channel order into ``(1, channel, height, width)``.

    Args:
        image: Source PIL image (any mode; it is converted to RGB).
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A ``float32`` tensor of shape ``(1, 3, height, width)`` with values
        in ``[-1, 1]``.
    """
    # Force three channels: a grayscale image yields a 2-D array (the 4-axis
    # transpose below would raise) and an RGBA image a 4-channel tensor.
    image = image.convert("RGB")
    image = image.resize((width, height), resample=Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    # add a batch axis, then HWC -> CHW
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return 2.0 * torch.from_numpy(pixels) - 1.0
def preprocess_mask(mask: Image.Image, width: int, height: int) -> torch.Tensor:
    """Convert a PIL mask image into a batched mask tensor.

    The mask is reduced to a single 8-bit channel, resized to
    ``(width // 8, height // 8)`` with Lanczos resampling, rescaled to
    ``[0, 1]`` and repeated along a new leading axis of size 4.

    Args:
        mask: Source PIL image used as the mask (converted to mode ``"L"``).
        width: Full-resolution width in pixels.
        height: Full-resolution height in pixels.

    Returns:
        A ``float32`` tensor of shape ``(1, 4, height // 8, width // 8)``.
    """
    mask = mask.convert("L")
    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
    values = np.array(mask).astype(np.float32) / 255.0
    # Repeat the 2-D mask 4 times along a new first axis.
    # presumably 4 matches the UNet latent channel count -- verify against the model
    values = np.tile(values, (4, 1, 1))
    # Only a batch axis is needed here; the former `transpose(0, 1, 2, 3)`
    # was the identity permutation and has been dropped.
    return torch.from_numpy(values[None])
class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """
    Diffusion pipeline that generates images from a text prompt, optionally
    starting from an initial image and optionally keeping part of that image
    fixed through a mask.

    From https://github.com/huggingface/diffusers/pull/241
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Register every sub-model of the pipeline via ``register_modules``."""
        super().__init__()
        # The scheduler is rebound to the return value of `set_format("pt")`.
        # NOTE(review): presumably this switches the scheduler to torch
        # tensors and returns the scheduler itself -- confirm against the
        # installed diffusers version.
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> dict:
        """Run the denoising loop and decode the result into images.

        Args:
            prompt: A single prompt or a list of prompts; the list length
                sets the batch size.
            init_image: Optional image tensor to start from. When ``None``
                the latents are sampled from random noise.
            mask: Optional mask tensor. Where the mask is 1 the (re-noised)
                latents of ``init_image`` are kept after every step; where it
                is 0 the denoised latents are kept. Requires ``init_image``.
            width: Output width in pixels; must be divisible by 8.
            height: Output height in pixels; must be divisible by 8.
            prompt_strength: Value in ``[0, 1]`` controlling how many of the
                scheduler timesteps are applied on top of ``init_image``.
            num_inference_steps: Number of scheduler timesteps to set.
            guidance_scale: Classifier-free guidance weight; guidance is
                enabled only when it is greater than 1.
            eta: Forwarded to ``scheduler.step`` when that method accepts an
                ``eta`` argument, ignored otherwise.
            generator: Optional torch generator used for all noise sampling.

        Returns:
            A dict with ``"sample"`` (the list of PIL images) and
            ``"nsfw_content_detected"`` (the second output of the safety
            checker).

        Raises:
            ValueError: If ``prompt`` is neither ``str`` nor ``list``, if
                ``prompt_strength`` is outside ``[0, 1]``, if ``mask`` is
                given without ``init_image``, or if ``width``/``height`` is
                not divisible by 8.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should in [0.0, 1.0] but is {prompt_strength}"
            )

        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )

        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")

        # set timesteps; only some schedulers accept an `offset` argument
        accepts_offset = (
            "offset" in inspect.signature(self.scheduler.set_timesteps).parameters
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps

        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self.embed_text(
            prompt, do_classifier_free_guidance, batch_size
        )

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # Sampled unconditionally (even without a mask) so the sequence of
        # draws from `generator` is the same in every mode.
        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)

        if mask is not None:
            # The blend inside the loop mixes `mask` with latents that live on
            # `self.device`; move it once up front to avoid a device mismatch.
            mask = mask.to(self.device)

        is_lms = isinstance(self.scheduler, LMSDiscreteScheduler)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        # NOTE(review): this scaling is also applied to latents that were
        # already noised from `init_image` -- confirm that is intended for the
        # installed scheduler version.
        if is_lms:
            latents = latents * self.scheduler.sigmas[0]

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in enumerate(tqdm(self.scheduler.timesteps[t_start:])):
            # `i` counts from 0 inside the slice; `step_index` is the position
            # in the full schedule, which is what `sigmas` and the LMS `step`
            # are indexed by.
            step_index = t_start + i

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if is_lms:
                sigma = self.scheduler.sigmas[step_index]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            # (the LMS scheduler is stepped by schedule index, the others by timestep)
            step_arg = step_index if is_lms else t
            latents = self.scheduler.step(
                noise_pred, step_arg, latents, **extra_step_kwargs
            )["prev_sample"]

            # keep the original latents (with added noise) where mask is 1,
            # and the freshly denoised latents where mask is 0
            if mask is not None:
                timesteps = self.scheduler.timesteps[step_index]
                timesteps = torch.tensor(
                    [timesteps] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(
                    init_latents_orig, mask_noise, timesteps
                )
                latents = noisy_init_latents * mask + latents * (1 - mask)

        # undo the 0.18215 scaling applied when encoding, then decode with the vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        # map from [-1, 1] to [0, 1] and move to channel-last numpy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="pt"
        ).to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values
        )

        image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
        """Encode ``init_image`` into latents and noise them for the start step.

        Args:
            init_image: Image tensor passed to ``vae.encode``.
            prompt_strength: Fraction of ``num_inference_steps`` to run on
                top of the init image.
            offset: Scheduler offset added to the computed start step.
            num_inference_steps: Total number of scheduler timesteps.
            batch_size: Number of times the latents are repeated.
            generator: Optional torch generator for the added noise.

        Returns:
            A tuple ``(init_latents_orig, init_latents, init_timestep)``: the
            scaled latents before batching and noising, the batched noised
            latents, and the number of trailing timesteps to run.
        """
        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        init_latents_orig = init_latents

        # repeat the latents once per prompt in the batch
        init_latents = torch.cat([init_latents] * batch_size)

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        # NOTE(review): when init_timestep is 0 (prompt_strength 0 and no
        # scheduler offset) the index -0 selects the FIRST timestep rather
        # than the last -- confirm this is intended.
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size, dtype=torch.long, device=self.device
        )

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
        return init_latents_orig, init_latents, init_timestep

    def embed_text(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool,
        batch_size: int,
    ) -> torch.FloatTensor:
        """Encode ``prompt`` with the text encoder.

        Args:
            prompt: A single prompt or a list of prompts.
            do_classifier_free_guidance: When true, embeddings of empty
                prompts are prepended to the prompt embeddings.
            batch_size: Number of empty prompts to embed for guidance.

        Returns:
            The prompt embeddings, or the unconditional embeddings
            concatenated in front of the prompt embeddings when
            ``do_classifier_free_guidance`` is true.
        """
        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        if not do_classifier_free_guidance:
            return text_embeddings

        # get unconditional embeddings for classifier free guidance,
        # padded to the same length as the prompt tokens
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.device)
        )[0]

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        return torch.cat([uncond_embeddings, text_embeddings])
|