import inspect
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    PNDMScheduler,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


def preprocess_init_image(image: Image.Image, width: int, height: int) -> torch.FloatTensor:
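    # Resize to the target size, scale pixel values to [0, 1], reorder from
    # HWC to NCHW, and rescale to the [-1, 1] range the VAE expects.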
    image = image.resize((width, height), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask: Image.Image, width: int, height: int) -> torch.FloatTensor:
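    # Downsample the mask to latent resolution (1/8 of pixel space) and tile
    # it across the 4 latent channels so it can be applied to latents directly.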
    mask = mask.convert("L")
    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None]  # add a batch dimension (the original transpose(0, 1, 2, 3) was a no-op)
    mask = torch.from_numpy(mask)
    return mask


class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """
    Img2img and inpainting pipeline for Stable Diffusion, adapted from
    https://github.com/huggingface/diffusers/pull/241
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> dict:
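        """
        Generate images for `prompt`, optionally starting from `init_image`
        (img2img) and restricting changes to a latent-space `mask` (inpainting).
        Returns a dict with the generated PIL images under "sample".
        """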
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should be in [0.0, 1.0] but is {prompt_strength}"
            )

        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )

        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")

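        # Some schedulers (e.g. PNDM) accept a timestep offset; detect support
        # by inspecting the signature of set_timesteps.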
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps

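        # Classifier-free guidance doubles the batch (unconditional +
        # conditional); guidance_scale == 1.0 disables it.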
        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self.embed_text(
            prompt, do_classifier_free_guidance, batch_size
        )

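        # eta only applies to the DDIMScheduler; other schedulers' step()
        # does not accept it, so pass it along only when supported.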
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)

        # The LMS scheduler expects the initial noise scaled by its first (largest) sigma.
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in enumerate(tqdm(self.scheduler.timesteps[t_start:])):
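            # Duplicate the latents so the conditional and unconditional
            # predictions run in a single batch.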
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )

            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

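            # Predict the noise residual with the UNet.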
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

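            # Compute the previous noisy sample x_t -> x_t-1; LMS steps by
            # index, the other schedulers by timestep value.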
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(
                    noise_pred, i, latents, **extra_step_kwargs
                )["prev_sample"]
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                )["prev_sample"]

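            # Where the mask is 1, keep the original image latents (re-noised
            # to the current timestep); where it is 0, keep the newly
            # generated latents.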
            if mask is not None:
                timesteps = self.scheduler.timesteps[t_start + i]
                timesteps = torch.tensor(
                    [timesteps] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(
                    init_latents_orig, mask_noise, timesteps
                )
                latents = noisy_init_latents * mask + latents * (1 - mask)

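        # Undo the 0.18215 latent scaling factor and decode back to pixel space.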
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # Run the safety checker on the decoded images.
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="pt"
        ).to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values
        )

        image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
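        # Encode the init image into latents and apply the Stable Diffusion
        # latent scaling factor (0.18215), keeping a pre-batching copy for
        # inpainting.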
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        init_latents_orig = init_latents

        init_latents = torch.cat([init_latents] * batch_size)

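        # prompt_strength controls how much noise is added to the init image:
        # 1.0 denoises from (almost) pure noise, values near 0.0 keep it
        # mostly intact.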
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size, dtype=torch.long, device=self.device
        )

        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        return init_latents_orig, init_latents, init_timestep

    def embed_text(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool,
        batch_size: int,
    ) -> torch.FloatTensor:
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

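        # For classifier-free guidance we also need the embeddings of an
        # empty ("unconditional") prompt.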
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(self.device)
            )[0]

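            # Concatenate so both halves go through the UNet in one batch.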
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings
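
# Example usage (a minimal sketch, not part of the original file): the model
# id, device, and file paths below are assumptions for illustration only.
#
#     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
#         "CompVis/stable-diffusion-v1-4"
#     ).to("cuda")
#     init = preprocess_init_image(Image.open("init.png").convert("RGB"), 512, 512)
#     result = pipe(
#         prompt="a fantasy landscape",
#         init_image=init.to("cuda"),
#         mask=None,
#         width=512,
#         height=512,
#     )
#     result["sample"][0].save("out.png")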