import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    PNDMScheduler,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


def preprocess_init_image(image: Image.Image, width: int, height: int) -> torch.Tensor:
    """Convert a PIL image into a normalised tensor for the VAE encoder.

    Args:
        image: Source image. Any PIL mode is accepted; it is converted to RGB.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A float32 tensor of shape ``(1, 3, height, width)`` with values in
        ``[-1, 1]``.
    """
    # The transpose below needs an explicit channel axis (H, W, C); grayscale
    # or palette images would otherwise yield a 2-D array and crash, and RGBA
    # would carry a fourth channel. Converting an RGB image is a no-op.
    image = image.convert("RGB")
    image = image.resize((width, height), resample=Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    pixels = pixels[None].transpose(0, 3, 1, 2)
    # Rescale [0, 1] -> [-1, 1].
    return 2.0 * torch.from_numpy(pixels) - 1.0


def preprocess_mask(mask: Image.Image, width: int, height: int) -> torch.Tensor:
    """Convert a PIL mask into a latent-resolution tensor.

    The mask is reduced to a single luminance channel, resized to
    ``(width // 8, height // 8)`` and repeated over 4 channels so that it can
    be multiplied element-wise with the latents in the denoising loop.

    Args:
        mask: Mask image; it is converted to mode ``"L"``.
        width: Width of the full-resolution image in pixels.
        height: Height of the full-resolution image in pixels.

    Returns:
        A float32 tensor of shape ``(1, 4, height // 8, width // 8)`` with
        values in ``[0, 1]``.
    """
    mask = mask.convert("L")
    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
    values = np.array(mask).astype(np.float32) / 255.0
    # (H, W) -> (4, H, W): one copy of the mask per latent channel.
    values = np.tile(values, (4, 1, 1))
    # Add the batch axis. The original code followed this with
    # ``transpose(0, 1, 2, 3)``, an identity permutation, which was dropped.
    values = values[None]
    return torch.from_numpy(values)


class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """
    Text-to-image, image-to-image and masked (inpainting) generation in a
    single pipeline.

    From https://github.com/huggingface/diffusers/pull/241
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Register all sub-models; the scheduler is switched to torch format."""
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> Dict[str, Any]:
        """Generate images for ``prompt``.

        Args:
            prompt: A prompt or a list of prompts (one image per prompt).
            init_image: Optional starting image tensor (see
                ``preprocess_init_image``). When ``None`` the latents start
                from pure noise.
            mask: Optional latent-resolution mask (see ``preprocess_mask``).
                Where the mask is 1 the (re-noised) init image latents are
                kept; where it is 0 the generated latents are kept.
                Requires ``init_image``.
            width: Output width in pixels; must be divisible by 8.
            height: Output height in pixels; must be divisible by 8.
            prompt_strength: Value in ``[0, 1]`` controlling how much noise is
                added to ``init_image`` (and thus how many denoising steps run).
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; guidance is only
                performed when this is greater than 1.
            eta: Passed to ``scheduler.step`` when the scheduler accepts it.
            generator: Optional torch generator for reproducible noise.

        Returns:
            A dict with ``"sample"`` (list of PIL images) and
            ``"nsfw_content_detected"`` (output of the safety checker).

        Raises:
            ValueError: On an invalid ``prompt`` type, ``prompt_strength``
                outside ``[0, 1]``, a mask without an init image, or
                dimensions not divisible by 8.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should in [0.0, 1.0] but is {prompt_strength}"
            )

        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )

        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")

        # set timesteps; only some schedulers take an `offset` argument
        accepts_offset = (
            "offset" in inspect.signature(self.scheduler.set_timesteps).parameters
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self.embed_text(
            prompt, do_classifier_free_guidance, batch_size
        )

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # Noise used to re-noise the original latents inside the masked region;
        # drawn once so every step blends against the same noise sample.
        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)
        if mask is not None:
            # The blend below runs on the pipeline device.
            mask = mask.to(self.device)

        uses_lms = isinstance(self.scheduler, LMSDiscreteScheduler)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        # NOTE(review): this scaling is also applied when the latents come from
        # an init image that already went through `scheduler.add_noise`;
        # confirm that is intended for the LMS + init_image combination.
        if uses_lms:
            latents = latents * self.scheduler.sigmas[0]

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in enumerate(tqdm(self.scheduler.timesteps[t_start:])):
            # Position of `t` in the *full* schedule. The loop counter `i` is
            # relative to the slice, so it must not be used to index
            # scheduler state when t_start > 0.
            step_index = t_start + i

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if uses_lms:
                sigma = self.scheduler.sigmas[step_index]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            # (LMS is stepped by schedule index, the other schedulers by timestep)
            step_arg = step_index if uses_lms else t
            latents = self.scheduler.step(
                noise_pred, step_arg, latents, **extra_step_kwargs
            )["prev_sample"]

            # replace the unmasked part with original latents, with added noise
            if mask is not None:
                timesteps = self.scheduler.timesteps[step_index]
                timesteps = torch.tensor(
                    [timesteps] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(
                    init_latents_orig, mask_noise, timesteps
                )
                latents = noisy_init_latents * mask + latents * (1 - mask)

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        # [-1, 1] -> [0, 1], then (N, C, H, W) -> (N, H, W, C) numpy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="pt"
        ).to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values
        )

        image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
        """Encode ``init_image`` and noise it according to ``prompt_strength``.

        Args:
            init_image: Image tensor to encode with the VAE.
            prompt_strength: Fraction of ``num_inference_steps`` worth of noise
                to add.
            offset: Scheduler timestep offset chosen by the caller.
            num_inference_steps: Total number of scheduler steps.
            batch_size: Number of copies of the latents to produce.
            generator: Optional torch generator for the added noise.

        Returns:
            ``(init_latents_orig, init_latents, init_timestep)``: the clean
            scaled latents (not repeated over the batch), the noised latents
            repeated ``batch_size`` times, and the number of steps counted
            from the end of the schedule at which denoising starts.
        """
        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        init_latents_orig = init_latents

        # one copy of the latents per prompt
        init_latents = torch.cat([init_latents] * batch_size)

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size, dtype=torch.long, device=self.device
        )

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        return init_latents_orig, init_latents, init_timestep

    def embed_text(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool,
        batch_size: int,
    ) -> torch.FloatTensor:
        """Encode ``prompt`` with the text encoder.

        Args:
            prompt: A prompt or list of prompts.
            do_classifier_free_guidance: When true, embeddings of
                ``batch_size`` empty prompts are prepended to the result.
            batch_size: Number of prompts.

        Returns:
            The text embeddings; with guidance enabled the unconditional
            embeddings come first, followed by the prompt embeddings, along
            the batch axis.
        """
        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(self.device)
            )[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings