import inspect
from typing import Callable, List, Optional, Union

import PIL
import torch
import torchvision
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers import DiffusionPipeline
from diffusers.schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler
from diffusers.utils import deprecate
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from packaging import version
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class MGDPipe(DiffusionPipeline):
    """Text-, pose- and sketch-conditioned latent inpainting pipeline.

    At every denoising step the UNet input is the channel-wise concatenation
    of the noisy latents, the downscaled mask, the VAE-encoded masked image,
    the downscaled pose map and the downscaled sketch; the text embeddings
    are passed as ``encoder_hidden_states``.
    """

    _optional_components = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker: bool = False,
    ):
        """Register the sub-models and patch outdated scheduler/UNet configs.

        Args:
            vae: Autoencoder used to encode the masked image and decode latents.
            text_encoder: Text model producing the prompt embeddings.
            tokenizer: Tokenizer paired with ``text_encoder``.
            unet: Denoising UNet; its ``in_channels`` must equal the sum of
                the channels concatenated in ``__call__``.
            scheduler: Noise scheduler driving the denoising loop.
            safety_checker: Optional module; only registered and offloaded by
                this class, never invoked by ``__call__``.
            feature_extractor: Required whenever ``safety_checker`` is given.
            requires_safety_checker: Stored in the pipeline config; when True
                and no checker is passed a warning is logged.

        Raises:
            ValueError: If ``safety_checker`` is given without a
                ``feature_extractor``.
        """
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly."
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            # Overwrite the frozen config in place with the corrected value.
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration"
                " `skip_prk_steps`. `skip_prk_steps` should be set to True."
            )
            deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["skip_prk_steps"] = True
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when using safety checker."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64. Please update the config accordingly."
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Spatial ratio between pixel space and latent space.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        """Offload the UNet, text encoder, VAE and safety checker via accelerate.

        Args:
            gpu_id: Index of the CUDA device used as execution device.

        Raises:
            ImportError: If ``accelerate`` is not available.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker.vision_model, device)

    @property
    def _execution_device(self):
        """Return the device the pipeline's models actually execute on.

        When the pipeline reports the ``meta`` device and the UNet carries an
        ``_hf_hook``, the first UNet sub-module whose hook exposes a non-None
        ``execution_device`` decides; otherwise ``self.device`` is returned.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        """Encode the prompt (and negative prompt) into text-encoder hidden states.

        Args:
            prompt: A string or list of strings.
            device: Device the token ids are moved to before encoding.
            num_images_per_prompt: Each prompt embedding is repeated this many
                times along the batch dimension.
            do_classifier_free_guidance: When True, unconditional embeddings
                are computed and prepended to the conditional ones.
            negative_prompt: ``None`` (empty strings are used), a string, or a
                list with one entry per prompt.

        Returns:
            The text embeddings; with guidance enabled the first half of the
            batch is unconditional and the second half conditional.

        Raises:
            TypeError: If ``negative_prompt`` and ``prompt`` differ in type.
            ValueError: If a list ``negative_prompt`` has a different length
                than ``prompt``.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # Duplicate each prompt's embedding `num_images_per_prompt` times.
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # One batch holding both halves lets the UNet run a single forward pass.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def prepare_extra_step_kwargs(self, generator, eta):
        """Build the optional kwargs accepted by ``self.scheduler.step``.

        ``eta`` and ``generator`` are forwarded only when the scheduler's
        ``step`` signature declares a parameter of that name.
        """
        step_parameters = inspect.signature(self.scheduler.step).parameters

        extra_step_kwargs = {"eta": eta} if "eta" in step_parameters else {}
        if "generator" in step_parameters:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def decode_latents(self, latents):
        """Decode latents with the VAE into a NumPy image batch in ``[0, 1]``.

        The latents are divided by 0.18215 (the factor applied when encoding
        in ``prepare_mask_latents``), decoded, mapped from ``[-1, 1]`` to
        ``[0, 1]`` and returned channels-last as float32 on the CPU.
        """
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, height, width, callback_steps):
        """Validate the user-facing arguments of ``__call__``.

        Raises:
            ValueError: If ``prompt`` is neither ``str`` nor ``list``, if
                ``height`` or ``width`` is not divisible by 8, or if
                ``callback_steps`` is ``None`` or not a positive integer.
        """
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Create (or validate) the initial noise latents.

        Args:
            batch_size: Effective batch size of the latents.
            num_channels_latents: Number of latent channels.
            height: Image height in pixels; divided by ``vae_scale_factor``.
            width: Image width in pixels; divided by ``vae_scale_factor``.
            dtype: dtype of freshly sampled noise.
            device: Target device of the returned latents.
            generator: A generator or a list with one generator per sample.
            latents: Optional pre-sampled latents of the expected shape.

        Returns:
            Latents on ``device`` scaled by ``scheduler.init_noise_sigma``.

        Raises:
            ValueError: If the generator list length differs from
                ``batch_size`` or provided ``latents`` have the wrong shape.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            # On "mps" the noise is drawn on the CPU and moved afterwards.
            rand_device = "cpu" if device.type == "mps" else device

            if isinstance(generator, list):
                # One sample per generator so each sample has its own seed.
                shape = (1,) + shape[1:]
                latents = [
                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
                    for i in range(batch_size)
                ]
                latents = torch.cat(latents, dim=0).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        """Resize the mask and VAE-encode the masked image for the UNet input.

        Args:
            mask: Mask tensor; resized to the latent resolution.
            masked_image: Image tensor encoded with the VAE.
            batch_size: Effective batch size both outputs are tiled to.
            height: Image height in pixels.
            width: Image width in pixels.
            dtype: dtype of the returned tensors.
            device: Device of the returned tensors.
            generator: A generator or list of generators used when sampling
                from the VAE latent distribution.
            do_classifier_free_guidance: When True, both outputs are
                duplicated along the batch dimension.

        Returns:
            Tuple ``(mask, masked_image_latents)``.

        Raises:
            ValueError: If ``batch_size`` is not a multiple of the number of
                masks or images passed.
        """
        mask = torch.nn.functional.interpolate(
            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )
        mask = mask.to(device=device, dtype=dtype)

        masked_image = masked_image.to(device=device, dtype=dtype)

        if isinstance(generator, list):
            masked_image_latents = [
                self.vae.encode(masked_image[i: i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(batch_size)
            ]
            masked_image_latents = torch.cat(masked_image_latents, dim=0)
        else:
            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = 0.18215 * masked_image_latents

        # Tile mask and latents up to the effective batch size.
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        return mask, masked_image_latents

    @staticmethod
    def _tile_to_batch_size(tensor, batch_size, name):
        """Tile ``tensor`` along dim 0 up to ``batch_size`` samples.

        Uses the same rule and ordering as the mask in
        ``prepare_mask_latents``: a tensor that already has at least
        ``batch_size`` samples is returned untouched.

        Raises:
            ValueError: If ``batch_size`` is not a multiple of the number of
                samples in ``tensor``.
        """
        current = tensor.shape[0]
        if current >= batch_size:
            return tensor
        if batch_size % current != 0:
            raise ValueError(
                f"The passed `{name}` and the required batch size don't match. `{name}` is supposed to be"
                f" duplicated to a total batch size of {batch_size}, but {current} entries were passed. Make sure"
                " the total requested batch size is divisible by the number of entries that you pass."
            )
        return tensor.repeat(batch_size // current, *([1] * (tensor.dim() - 1)))

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        pose_map: torch.FloatTensor,
        sketch: torch.FloatTensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        sketch_cond_rate: float = 1.0,
        start_cond_rate: float = 0,
        no_pose: bool = False,
    ):
        """Run the denoising loop and decode the result.

        Args:
            prompt: Prompt string or list of prompt strings.
            image: Masked image handed to ``prepare_mask_latents``.
                NOTE(review): annotated as tensor-or-PIL, but this method
                calls ``.to(...)`` on it and VAE-encodes it without any
                preprocessing, so a tensor appears to be required - confirm.
            mask_image: Mask resized to latent resolution; same caveat as
                ``image`` regarding PIL inputs.
            pose_map: Pose conditioning; bilinearly downscaled by 8 in both
                spatial dimensions.
            sketch: Sketch conditioning; resized by a factor of 8 with
                antialiased bilinear interpolation.
            height: Output height; defaults to
                ``unet.config.sample_size * vae_scale_factor``.
            width: Output width; same default as ``height``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; guidance is
                active when this is greater than 1.
            negative_prompt: Optional negative prompt(s).
            num_images_per_prompt: Images generated per prompt. Mask, image,
                pose map and sketch are tiled to
                ``batch_size * num_images_per_prompt``.
            eta: Forwarded to the scheduler step if it accepts ``eta``.
            generator: Generator(s) for latent and VAE sampling.
            latents: Optional pre-sampled initial latents.
            output_type: ``"pil"`` converts the output to PIL images; any
                other value returns the NumPy array.
            return_dict: When False a tuple ``(images, None)`` is returned.
            callback: Called as ``callback(i, t, latents)``.
            callback_steps: The callback fires when ``i % callback_steps == 0``.
            sketch_cond_rate: Together with ``start_cond_rate`` bounds the
                steps on which the sketch is used: step index ``i`` uses the
                sketch when ``sketch_start <= i <= sketch_end``, where
                ``sketch_end = sketch_cond_rate * num_inference_steps +
                sketch_start``; other steps see a zero sketch.
            start_cond_rate: ``sketch_start`` is
                ``int(num_inference_steps * start_cond_rate)``.
            no_pose: When True the pose map is replaced by zeros.

        Returns:
            ``StableDiffusionPipelineOutput`` with ``nsfw_content_detected``
            set to ``None`` (the safety checker is not run here), or the
            tuple ``(images, None)`` when ``return_dict`` is False.

        Raises:
            ValueError: On invalid inputs, mismatching batch sizes, or when
                the concatenated channel count differs from
                ``unet.config.in_channels``.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        self.check_inputs(prompt, height, width, callback_steps)

        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        total_batch_size = batch_size * num_images_per_prompt
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        mask, masked_image_latents = self.prepare_mask_latents(
            mask=mask_image,
            masked_image=image,
            batch_size=total_batch_size,
            height=height,
            width=width,
            dtype=text_embeddings.dtype,
            device=device,
            generator=generator,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        pose_map = torch.nn.functional.interpolate(
            pose_map, size=(pose_map.shape[2] // 8, pose_map.shape[3] // 8), mode="bilinear"
        )
        if no_pose:
            pose_map = torch.zeros_like(pose_map)

        sketch = torchvision.transforms.functional.resize(
            sketch,
            size=(sketch.shape[2] // 8, sketch.shape[3] // 8),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            antialias=True,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        start_cond_step = int(num_inference_steps * start_cond_rate)
        sketch_start = start_cond_step
        sketch_end = sketch_cond_rate * num_inference_steps + start_cond_step

        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            total_batch_size,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # Latents, mask and masked-image latents all have
        # `batch_size * num_images_per_prompt` samples; the pose map and sketch
        # must match or the channel-wise concatenation below fails whenever
        # `num_images_per_prompt > 1`.
        pose_map = self._tile_to_batch_size(pose_map, total_batch_size, "pose_map")
        sketch = self._tile_to_batch_size(sketch, total_batch_size, "sketch")

        # The unconditional half of a guided batch gets all-zero conditioning.
        pose_map = torch.cat([torch.zeros_like(pose_map), pose_map]) if do_classifier_free_guidance else pose_map
        sketch = torch.cat([torch.zeros_like(sketch), sketch]) if do_classifier_free_guidance else sketch

        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        num_channels_pose_map = pose_map.shape[1]
        num_channels_sketch = sketch.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image + num_channels_pose_map + num_channels_sketch != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image} +"
                f" `num_channels_pose_map`: {num_channels_pose_map} + `num_channels_sketch`: {num_channels_sketch}. Please"
                " verify the config of `pipeline.unet` or your `mask_image`, `image`, or `pose_map` input."
            )

        # Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Loop-invariant: cast the conditioning once instead of on every step.
        pose_map = pose_map.to(mask.dtype)
        sketch = sketch.to(mask.dtype)
        zero_sketch = torch.zeros_like(sketch)

        # Run the pipeline
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                # The sketch only conditions steps inside [sketch_start, sketch_end].
                if i < sketch_start or i > sketch_end:
                    local_sketch = zero_sketch
                else:
                    local_sketch = sketch

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat(
                    [latent_model_input, mask, masked_image_latents, pose_map, local_sketch],
                    dim=1,
                )

                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.to(self.vae.dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Decode latents to images
        image = self.decode_latents(latents)

        # Convert to PIL format if required
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Return final output
        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)