|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
from typing import Callable, List, Optional, Union |
|
|
|
import numpy as np |
|
import paddle |
|
import PIL |
|
from packaging import version |
|
|
|
from paddlenlp.transformers import ( |
|
CLIPTextModel, |
|
CLIPTokenizer, |
|
DPTForDepthEstimation, |
|
DPTImageProcessor, |
|
) |
|
|
|
from ...configuration_utils import FrozenDict |
|
from ...models import AutoencoderKL, UNet2DConditionModel |
|
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput |
|
from ...schedulers import ( |
|
DDIMScheduler, |
|
DPMSolverMultistepScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
EulerDiscreteScheduler, |
|
LMSDiscreteScheduler, |
|
PNDMScheduler, |
|
) |
|
from ...utils import PIL_INTERPOLATION, deprecate, logging |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
def preprocess(image):
    """Convert the pipeline's `image` argument into a batched paddle tensor.

    - A `paddle.Tensor` is passed through untouched.
    - A single `PIL.Image.Image` is treated as a one-element list.
    - For a list of PIL images, every image is resized (Lanczos) to the first
      image's size rounded down to a multiple of 32. The results are stacked
      into an NCHW float32 tensor scaled from [0, 255] to [-1, 1].
    - A list of paddle tensors is concatenated along axis 0.
    - Any other list is returned unchanged.
    """
    if isinstance(image, paddle.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        # Round both sides down to a multiple of 32 so the size divides cleanly through the VAE.
        width, height = (dim - dim % 32 for dim in first.size)
        resample = PIL_INTERPOLATION["lanczos"]
        arrays = []
        for img in image:
            resized = img.resize((width, height), resample=resample)
            arrays.append(np.array(resized)[None, :])
        batch = np.concatenate(arrays, axis=0).astype(np.float32) / 255.0
        # NHWC -> NCHW, then map [0, 1] to [-1, 1].
        batch = batch.transpose(0, 3, 1, 2)
        image = paddle.to_tensor(batch * 2.0 - 1.0)
    elif isinstance(first, paddle.Tensor):
        image = paddle.concat(image, axis=0)
    return image
|
|
|
|
|
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-guided, depth-conditioned image-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. At
            every step its input is the (scaled) latents concatenated with a single-channel depth map along the
            channel axis.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], or [`DPMSolverMultistepScheduler`].
        depth_estimator ([`DPTForDepthEstimation`]):
            Model that predicts a depth map from the input image when no `depth_map` is passed to `__call__`.
        feature_extractor ([`DPTImageProcessor`]):
            Image processor that prepares the input image(s) for `depth_estimator`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        depth_estimator: DPTForDepthEstimation,
        feature_extractor: DPTImageProcessor,
    ):
        super().__init__()

        # Older (< 0.9.0) unet configs may carry a `sample_size` below 64. In that case, warn and patch the
        # in-memory config to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_ppdiffusers_version") and version.parse(
            version.parse(unet.config._ppdiffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            depth_estimator=depth_estimator,
            feature_extractor=feature_extractor,
        )
        # Spatial downsampling factor between pixel space and latent space.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).

        Returns:
            `paddle.Tensor`: the prompt embeddings, each repeated `num_images_per_prompt` times. With guidance
            enabled, the unconditional embeddings come first, concatenated along axis 0.

        Raises:
            TypeError: if `negative_prompt` and `prompt` are of different types.
            ValueError: if a list `negative_prompt` does not match the batch size of `prompt`.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pd",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids

        # Warn the user about the part of the prompt that was cut off by truncation.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        config = (
            self.text_encoder.config
            if isinstance(self.text_encoder.config, dict)
            else self.text_encoder.config.to_dict()
        )
        # Only pass an attention mask when the text encoder's config asks for one.
        use_attention_mask = config.get("use_attention_mask", None) is not None and config["use_attention_mask"]
        attention_mask = text_inputs.attention_mask if use_attention_mask else None

        text_embeddings = self.text_encoder(
            text_input_ids,
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # Duplicate the embeddings for each generation per prompt.
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.tile([1, num_images_per_prompt, 1])
        text_embeddings = text_embeddings.reshape([bs_embed * num_images_per_prompt, seq_len, -1])

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Pad the unconditional tokens to the same length as the conditional ones.
            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pd",
            )

            attention_mask = uncond_input.attention_mask if use_attention_mask else None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids,
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.tile([1, num_images_per_prompt, 1])
            uncond_embeddings = uncond_embeddings.reshape([batch_size * num_images_per_prompt, seq_len, -1])

            # Stack unconditional and conditional embeddings so both UNet passes run as one batch.
            text_embeddings = paddle.concat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def run_safety_checker(self, image, dtype):
        """Run an optional safety checker over decoded numpy images.

        This pipeline does not register a `safety_checker` module, so the attribute is looked up defensively. When
        no checker is present, the images are returned unchanged together with `None`.
        """
        safety_checker = getattr(self, "safety_checker", None)
        if safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
            image, has_nsfw_concept = safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        """Decode VAE latents into float32 NHWC numpy images in [0, 1]."""
        # Undo the latent scaling applied in `prepare_latents`.
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clip(0, 1)
        image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        """Build the optional keyword arguments for `scheduler.step`.

        Schedulers have different `step` signatures. `eta` and `generator` are forwarded only when the scheduler's
        `step` accepts a parameter with that name.
        """
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, strength, callback_steps):
        """Validate the user-facing arguments of `__call__`.

        Raises:
            ValueError: if `prompt` is not a `str`/`list`, `strength` is outside [0, 1], or `callback_steps`
                is not a positive integer.
        """
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def get_timesteps(self, num_inference_steps, strength):
        """Return the tail of the scheduler's timesteps that `strength` selects, and its length.

        Only the last `int(num_inference_steps * strength)` timesteps are kept (capped at `num_inference_steps`),
        so a lower strength skips the earliest, noisiest steps.
        """
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None):
        """Encode `image` with the VAE and noise it to `timestep`.

        Args:
            image (`paddle.Tensor`): preprocessed image batch.
            timestep (`paddle.Tensor`): timestep(s) passed to `scheduler.add_noise`.
            batch_size (`int`): number of prompts.
            num_images_per_prompt (`int`): generations per prompt; the effective batch is the product of the two.
            dtype: target dtype for the image and the noise.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*): a list must have one
                generator per effective batch element.

        Raises:
            ValueError: if the generator list length does not match the effective batch, or if the image batch
                cannot be evenly duplicated to the effective batch.
        """
        image = image.cast(dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = paddle.concat(init_latents, axis=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)
        init_latents = 0.18215 * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # Deprecated: duplicate the images to match the number of prompts.
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = paddle.concat([init_latents] * additional_image_per_prompt, axis=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = paddle.concat([init_latents], axis=0)

        shape = init_latents.shape
        if isinstance(generator, list):
            # One noise sample per generator so each batch element is individually reproducible.
            shape = [
                1,
            ] + shape[1:]
            noise = [paddle.randn(shape, generator=generator[i], dtype=dtype) for i in range(batch_size)]
            noise = paddle.concat(noise, axis=0)
        else:
            noise = paddle.randn(shape, generator=generator, dtype=dtype)

        # Noise the encoded image to the starting timestep.
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype):
        """Build the depth conditioning tensor that is concatenated to the UNet input.

        The depth map is estimated with `depth_estimator` when `depth_map` is None. It is then resized to latent
        resolution with bicubic interpolation, min-max normalised per sample to [-1, 1], and cast to `dtype`. It
        is tiled up to `batch_size` and finally doubled when classifier-free guidance is enabled.

        Raises:
            ValueError: if the depth-map batch cannot be evenly tiled up to `batch_size`.
        """
        if isinstance(image, PIL.Image.Image):
            image = [image]
        else:
            image = [img for img in image]

        if isinstance(image[0], PIL.Image.Image):
            # `preprocess` resizes PIL images down to a multiple of 32 before VAE encoding. Apply the same rounding
            # so the depth map has the same spatial size as the latents it is concatenated with.
            width, height = (dim - dim % 32 for dim in image[0].size)
        else:
            # Assumes channels-first tensors whose last two axes are (height, width), matching the NCHW layout
            # used for the latents.
            height, width = image[0].shape[-2:]

        if depth_map is None:
            pixel_values = self.feature_extractor(images=image, return_tensors="pd").pixel_values
            depth_map = self.depth_estimator(pixel_values).predicted_depth
        else:
            depth_map = depth_map.cast(dtype)

        # Add a channel axis and resize to latent resolution.
        depth_map = paddle.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
            mode="bicubic",
            align_corners=False,
        )

        # Per-sample min-max normalisation to [-1, 1].
        depth_min = paddle.amin(depth_map, axis=[1, 2, 3], keepdim=True)
        depth_max = paddle.amax(depth_map, axis=[1, 2, 3], keepdim=True)
        depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
        depth_map = depth_map.cast(dtype)

        # Duplicate the depth maps up to the effective batch size. Tile by the ratio, not by `batch_size` itself,
        # so that a batch of several depth maps is not over-replicated.
        if depth_map.shape[0] < batch_size:
            if batch_size % depth_map.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `depth_map` of batch size {depth_map.shape[0]} to batch size {batch_size}."
                )
            depth_map = depth_map.tile([batch_size // depth_map.shape[0], 1, 1, 1])

        depth_map = paddle.concat([depth_map] * 2) if do_classifier_free_guidance else depth_map
        return depth_map

    @paddle.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[paddle.Tensor, PIL.Image.Image, List[PIL.Image.Image]],
        depth_map: Optional[paddle.Tensor] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`paddle.Tensor`, `PIL.Image.Image` or `List[PIL.Image.Image]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            depth_map (`paddle.Tensor`, *optional*):
                Depth map used as extra conditioning. It is unsqueezed on axis 1 before resizing, so it should have
                no channel axis. If not given, the depth is predicted from `image` with `depth_estimator`.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter will be modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only forwarded
                to schedulers whose `step` accepts an `eta` argument (e.g. [`schedulers.DDIMScheduler`]).
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                One or a list of paddle generator(s) to make generation deterministic. A list must contain one
                generator per generated image.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_utils.ImagePipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images.
        """
        # 1. Validate inputs.
        self.check_inputs(prompt, strength, callback_steps)

        # 2. Define call parameters.
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        # Classifier-free guidance is enabled only for guidance_scale > 1.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode the prompt (plus the negative prompt when guidance is on).
        text_embeddings = self._encode_prompt(
            prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare the depth conditioning from the raw (not yet preprocessed) image.
        depth_mask = self.prepare_depth_map(
            image,
            depth_map,
            batch_size * num_images_per_prompt,
            do_classifier_free_guidance,
            text_embeddings.dtype,
        )

        # 5. Preprocess the image for the VAE.
        image = preprocess(image)

        # 6. Select the timesteps that `strength` keeps.
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
        latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])

        # 7. Encode and noise the initial image.
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, generator
        )

        # 8. Scheduler-specific step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9. Denoising loop.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Duplicate the latents for the unconditional/conditional passes, then append the depth channel.
                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = paddle.concat([latent_model_input, depth_mask], axis=1)

                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Advance the progress bar once per full scheduler-order step after warm-up.
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 10. Decode the latents to images.
        image = self.decode_latents(latents)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|
|