|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
from typing import Callable, List, Optional, Union |
|
|
|
import numpy as np |
|
import PIL.Image |
|
|
|
from ...utils import logging |
|
from .pipeline_stable_diffusion import StableDiffusionPipeline |
|
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline |
|
from .pipeline_stable_diffusion_inpaint_legacy import ( |
|
StableDiffusionInpaintPipelineLegacy, |
|
) |
|
|
|
# Module-level logger named after this module.
# NOTE(review): not referenced by any code visible in this module — keep for convention or remove.
logger = logging.get_logger(__name__)
|
|
|
|
|
class StableDiffusionMegaPipeline(StableDiffusionPipeline):
    r"""
    Pipeline bundling text-to-image, image-to-image and legacy inpainting generation with Stable Diffusion.

    Calling the pipeline directly runs [`~StableDiffusionMegaPipeline.text2img`]; use
    [`~StableDiffusionMegaPipeline.img2img`] and [`~StableDiffusionMegaPipeline.inpaint_legacy`] for the other
    tasks. All three tasks reuse this pipeline's components instead of loading their own.

    This model inherits from [`StableDiffusionPipeline`] (and therefore [`DiffusionPipeline`]). Check the superclass
    documentation for the generic methods the library implements for all the pipelines (such as downloading or
    saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`] or [`DPMSolverMultistepScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    # NOTE(review): the task methods annotate `generator`, `latents` and the callback's latents argument with NumPy
    # types, but forward them unchanged to the underlying pipelines, which may expect torch types instead
    # (e.g. `torch.Generator`) — confirm against those pipelines' signatures.

    def __call__(self, *args, **kwargs):
        """Run text-to-image generation; all arguments are forwarded to [`~StableDiffusionMegaPipeline.text2img`]."""
        return self.text2img(*args, **kwargs)

    def _build_sub_pipeline(self, pipeline_class):
        """Instantiate ``pipeline_class`` from this pipeline's components.

        Only the components whose names appear in ``pipeline_class.__init__``'s signature are passed, and the same
        component objects are handed over (nothing is reloaded). The ``requires_safety_checker`` flag recorded in
        this pipeline's config is forwarded explicitly.
        """
        accepted_params = inspect.signature(pipeline_class.__init__).parameters
        shared_components = {
            name: component
            for name, component in self.components.items()
            # `requires_safety_checker` is passed explicitly below; never forward it twice.
            if name in accepted_params and name != "requires_safety_checker"
        }
        return pipeline_class(**shared_components, requires_safety_checker=self.config.requires_safety_checker)

    def text2img(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        latents: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        """Generate images from a text prompt.

        Every argument is forwarded unchanged to [`StableDiffusionPipeline.__call__`]; see that method for the
        meaning of each parameter.

        Returns:
            Whatever [`StableDiffusionPipeline.__call__`] returns for the given `output_type` / `return_dict`.
        """
        # This class already *is* a StableDiffusionPipeline, so run the parent implementation on `self` instead of
        # building a throw-away copy on every call (which re-ran the constructor and any warnings it logs).
        # `super()` is required: `self.__call__` is overridden to route back here and would recurse.
        return super().__call__(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    def img2img(
        self,
        prompt: Union[str, List[str]],
        image: Union[np.ndarray, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        """Generate images from a text prompt and an initial `image`.

        A [`StableDiffusionImg2ImgPipeline`] is built around this pipeline's components and every argument is
        forwarded to it unchanged; see its `__call__` for the meaning of each parameter.

        Returns:
            Whatever [`StableDiffusionImg2ImgPipeline.__call__`] returns for the given `output_type` / `return_dict`.
        """
        img2img_pipeline = self._build_sub_pipeline(StableDiffusionImg2ImgPipeline)
        return img2img_pipeline(
            prompt=prompt,
            image=image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    def inpaint_legacy(
        self,
        prompt: Union[str, List[str]],
        image: Union[np.ndarray, PIL.Image.Image],
        mask_image: Union[np.ndarray, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        """Inpaint the region of `image` selected by `mask_image`, guided by a text prompt.

        A [`StableDiffusionInpaintPipelineLegacy`] is built around this pipeline's components and every argument is
        forwarded to it unchanged; see its `__call__` for the meaning of each parameter.

        Returns:
            Whatever [`StableDiffusionInpaintPipelineLegacy.__call__`] returns for the given `output_type` /
            `return_dict`.
        """
        inpaint_pipeline = self._build_sub_pipeline(StableDiffusionInpaintPipelineLegacy)
        return inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
|
|