# Material / image_to_image.py
# Uploaded by clone3 ("Upload 4 files", commit 188fd89, verified); 10.6 kB.
import inspect
from typing import List, Optional, Union, Tuple
import numpy as np
import torch
from PIL import Image
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
PNDMScheduler,
LMSDiscreteScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
def preprocess_init_image(image: Image.Image, width: int, height: int):
    """Convert a PIL image into a normalized NCHW tensor for the VAE encoder.

    The image is converted to RGB, resized to ``(width, height)`` with Lanczos
    resampling, and rescaled from ``[0, 255]`` to ``[-1.0, 1.0]``.

    Args:
        image: Input PIL image in any mode Pillow can convert to RGB.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A float32 tensor of shape ``(1, 3, height, width)`` with values in
        ``[-1.0, 1.0]``.
    """
    # Force exactly three channels: a grayscale ("L") image would otherwise give
    # a 2-D array that the channel transpose below cannot handle, and an RGBA
    # image would carry an extra alpha channel.
    image = image.convert("RGB").resize((width, height), resample=Image.LANCZOS)
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W): add a batch axis and move channels first.
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return 2.0 * torch.from_numpy(pixels) - 1.0
def preprocess_mask(mask: Image.Image, width: int, height: int):
    """Convert a PIL mask into a latent-resolution blending tensor.

    The mask is converted to grayscale and downsampled to
    ``(width // 8, height // 8)`` with Lanczos resampling. It is then rescaled
    to ``[0.0, 1.0]`` and repeated across 4 channels.

    Args:
        mask: Input PIL mask image (any mode; converted to "L").
        width: Width in pixels of the image the mask belongs to.
        height: Height in pixels of the image the mask belongs to.

    Returns:
        A float32 tensor of shape ``(1, 4, height // 8, width // 8)`` with
        values in ``[0.0, 1.0]``.
    """
    # NOTE(review): the 1/8 resolution and 4 channels are assumed to match the
    # latent shape the pipeline blends this mask with (unet.in_channels,
    # height // 8, width // 8) -- confirm for non-default models.
    mask = mask.convert("L").resize((width // 8, height // 8), resample=Image.LANCZOS)
    values = np.array(mask).astype(np.float32) / 255.0
    # np.tile promotes the (h, w) array to (1, h, w) and repeats it to
    # (4, h, w); [None] then adds the batch axis. The result is already in
    # (N, C, H, W) order, so no transpose is needed (an earlier
    # `.transpose(0, 1, 2, 3)` here was an identity permutation).
    values = np.tile(values, (4, 1, 1))[None]
    return torch.from_numpy(values)
class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """Stable Diffusion pipeline for text-to-image, image-to-image and masked edits.

    From https://github.com/huggingface/diffusers/pull/241

    * Without ``init_image`` the latents start as Gaussian noise.
    * With ``init_image`` the image is VAE-encoded, noised to a level chosen by
      ``prompt_strength``, and denoised from there.
    * With ``mask`` (requires ``init_image``), after every denoising step the
      latents are blended with a re-noised copy of the original image latents:
      where the mask is 1 the original is kept, where it is 0 the model output
      is kept, and intermediate values blend linearly.
    """

    # Scaling applied to VAE latents after encoding and undone before decoding.
    _LATENT_SCALING = 0.18215

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Register the model components with the pipeline."""
        super().__init__()
        # Ask the scheduler to work with PyTorch ("pt") tensors. `set_format`
        # only exists in older diffusers releases.
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> dict:
        """Run the denoising loop and decode the result into images.

        Args:
            prompt: A prompt, or a list of prompts (one image per prompt).
            init_image: Optional image tensor to start from, e.g. the output
                of ``preprocess_init_image``.
            mask: Optional latent-resolution mask, e.g. the output of
                ``preprocess_mask``. Where it is 1 the original image latents
                are kept; where it is 0 the model output is kept. Requires
                ``init_image``.
            width: Output width in pixels; must be divisible by 8. Only used
                to size the initial noise when there is no ``init_image``.
            height: Output height in pixels; must be divisible by 8. Same
                caveat as ``width``.
            prompt_strength: In ``[0, 1]``. With ``init_image``, larger values
                noise the image further and run more denoising steps.
            num_inference_steps: Number of steps in the full scheduler schedule.
            guidance_scale: Classifier-free guidance weight; guidance is only
                applied when it is greater than 1.
            eta: Passed to ``scheduler.step`` only if that method accepts it.
            generator: Optional torch RNG used for all noise draws.

        Returns:
            ``{"sample": images, "nsfw_content_detected": flags}``, where
            ``images`` come from ``self.numpy_to_pil`` and ``flags`` is the
            second value returned by the safety checker.

        Raises:
            ValueError: if ``prompt`` is not a str/list, ``prompt_strength`` is
                outside ``[0, 1]``, ``mask`` is given without ``init_image``,
                or ``width``/``height`` is not divisible by 8.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )
        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should in [0.0, 1.0] but is {prompt_strength}"
            )
        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )
        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")

        # LMSDiscreteScheduler is driven by step *indices* (into its `sigmas`
        # table) rather than timestep values, so it needs special handling below.
        is_lms = isinstance(self.scheduler, LMSDiscreteScheduler)

        # set timesteps; schedulers whose set_timesteps accepts `offset` get 1.
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps
            if is_lms:
                # Scale unit-variance noise to the first sigma. Init-image
                # latents are not scaled here, because the LMS add_noise call
                # in latents_from_init_image already multiplies the noise by
                # the matching sigma.
                latents = latents * self.scheduler.sigmas[0]

        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self.embed_text(
            prompt, do_classifier_free_guidance, batch_size
        )

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # Noise for re-noising the original latents in the masked region. It
        # is drawn even without a mask, so the generator advances the same way
        # in both cases.
        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)
        if mask is not None:
            # preprocess_mask builds the mask on the CPU; it has to live on the
            # same device as the latents it is blended with.
            mask = mask.to(self.device)

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in enumerate(tqdm(self.scheduler.timesteps[t_start:])):
            # Position of `t` in the full schedule (the loop starts at t_start).
            t_index = t_start + i

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if is_lms:
                sigma = self.scheduler.sigmas[t_index]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            step_arg = t_index if is_lms else t
            latents = self.scheduler.step(
                noise_pred, step_arg, latents, **extra_step_kwargs
            )["prev_sample"]

            # Where mask == 1, replace the latents with the original image
            # latents re-noised to the current step's level.
            if mask is not None:
                noise_level = t_index if is_lms else self.scheduler.timesteps[t_index]
                timesteps = torch.tensor(
                    [noise_level] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(
                    init_latents_orig, mask_noise, timesteps
                )
                latents = noisy_init_latents * mask + latents * (1 - mask)

        # undo the latent scaling and decode the image latents with vae
        latents = 1 / self._LATENT_SCALING * latents
        image = self.vae.decode(latents)
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="pt"
        ).to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values
        )
        image = self.numpy_to_pil(image)
        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
        """Encode ``init_image`` and noise it to the starting level.

        Args:
            init_image: Image tensor passed to ``self.vae.encode``.
            prompt_strength: Fraction of the schedule used for noising.
            offset: Timestep offset chosen by the caller (0 or 1).
            num_inference_steps: Number of steps in the full schedule.
            batch_size: Number of copies of the latents to produce.
            generator: Optional torch RNG for the noise draw.

        Returns:
            A tuple ``(init_latents_orig, init_latents, init_timestep)``:

            * ``init_latents_orig``: the scaled, un-noised encoded latents.
            * ``init_latents``: those latents repeated ``batch_size`` times
              along dim 0, then noised.
            * ``init_timestep``: ``int(num_inference_steps * prompt_strength)
              + offset``, capped at ``num_inference_steps``.
        """
        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = self._LATENT_SCALING * init_latents
        init_latents_orig = init_latents

        # prepare init_latents noise to latents
        init_latents = torch.cat([init_latents] * batch_size)

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            # LMS add_noise looks the value up in its `sigmas` table, so it needs
            # the step index at which denoising will start (the caller's t_start),
            # not a timestep value.
            noise_level = num_inference_steps - init_timestep
        else:
            noise_level = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [noise_level] * batch_size, dtype=torch.long, device=self.device
        )

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
        return init_latents_orig, init_latents, init_timestep

    def embed_text(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool,
        batch_size: int,
    ) -> torch.FloatTensor:
        """Encode ``prompt`` with the CLIP text encoder.

        Args:
            prompt: Prompt string or list of prompt strings.
            do_classifier_free_guidance: If true, also encode ``batch_size``
                empty prompts for the unconditional branch.
            batch_size: Number of unconditional embeddings to produce.

        Returns:
            The prompt embeddings. With guidance enabled, the unconditional
            embeddings come first and the prompt embeddings second, concatenated
            along dim 0.
        """
        # get prompt text embeddings (padded/truncated to the tokenizer maximum)
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(self.device)
            )[0]
            # Concatenate unconditional and text embeddings into one batch so
            # the UNet needs a single forward pass instead of two.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings