diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c6b73e65a1f431318adf7f92b0d4fbb37464af
--- /dev/null
+++ b/app.py
@@ -0,0 +1,99 @@
+import os
+from PIL import Image
+import gradio as gr
+
+from src.utils.gradio_utils import *
+
+
+if __name__ == "__main__":
+    step_dict = {'800': 800, '900': 900, '1000': 1000, '1100': 1100}
+
+    with gr.Blocks(css=CSS_main) as demo:
+        gr.HTML(HTML_header)
+
+        with gr.Row():
+            # col A: Optimize personalized embedding
+            with gr.Column(scale=2) as gc_left:
+                gr.HTML("[Step 1] Optimize personalized embedding")
+                img_in_real = gr.Image(type="pil", label="Start by uploading the source image", elem_id="input_image").style(height=300, width=300)
+                gr.Examples(examples="src_image", inputs=[img_in_real])
+                prompt = gr.Textbox(value="a standing dog", label="Source text prompt (Describe the source image)", interactive=True)
+                n_hiper = gr.Slider(5, 10, 5, label="Number of personalized embeddings (Tip: 5 for animals / 10 for humans)", interactive=True, step=1)
+
+                btn_optimize = gr.Button("Optimize", label="")
+                fpath_z_gen = gr.Textbox(value="placeholder", visible=False)
+
+                gr.HTML("See the [Step 1] results with different optimization steps")
+                with gr.Row():
+                    with gr.Column(scale=0.3, min_width=0.7) as gc_left:
+                        # btn_source = gr.Button("Source image", label="")
+                        btn_opt_step800 = gr.Button("Step 800", label="")
+                        btn_opt_step900 = gr.Button("Step 900", label="")
+                        btn_opt_step1000 = gr.Button("Step 1000", label="")
+                        btn_opt_step1100 = gr.Button("Step 1100", label="")
+                    with gr.Column(scale=0.5, min_width=0.8) as gc_left:
+                        img_src = gr.Image(type="pil", label="Source image", visible=True).style(height=250, width=250)
+                    with gr.Column(scale=0.5, min_width=0.8) as gc_left:
+                        img_out_opt = gr.Image(type="pil", label="Optimization step output", visible=True).style(height=250, width=250)
+
+
+            # col B: Generate target image
+            with gr.Column(scale=2) as gc_left:
+
+                gr.HTML("[Step 2] Generate target image")
+                with gr.Row():
+
+                    with gr.Column():
+                        dest = gr.Textbox(value="a sitting dog", label="Target text prompt", interactive=True)
+                        step = gr.Radio(["Step 800", "Step 900", "Step 1000", "Step 1100"], value="Step 1000", label="Training optimization step \n (Refer to the personalized results corresponding to each optimization step listed in the left column.)")
+                        seed = gr.Number(value=111111, label="Random seed", interactive=True)
+                        with gr.Row():
+                            btn_generate = gr.Button("Generate", label="")
+                        img_out = gr.Image(type="pil", label="Output Image", visible=True)
+
+        with gr.Accordion("Instruction", open=True):
+            gr.Textbox("On an NVIDIA GeForce RTX 3090, [step 1] takes about 4 minutes and [step 2] takes about 1 minute.", show_label=False)
+            gr.Textbox("At [step 1], upload the desired source image and write the source text that describes it. If it is difficult to describe, you can use a simple noun phrase such as 'a dog' or 'a woman.' Then decide on the number of personalized embeddings.", show_label=False)
+            gr.Textbox("After [step 1], you can check the personalized results at different optimization steps and select one. First, check whether the image at step 1000 has a subject similar to the source image. In the paper, we use 1000 optimization steps in most cases.", show_label=False)
+            gr.Textbox("At [step 2], write the desired target text. Then, refer to the generated personalized images in the bottom left and choose an optimization step. If the desired image is not obtained, try another random seed.", show_label=False)
+
+
+        ############
+        btn_optimize.click(launch_optimize, [img_in_real, prompt, n_hiper], [fpath_z_gen, img_src])
+        def fn_set_none():
+            return gr.update(value=None)
+        btn_optimize.click(fn_set_none, [], img_in_real)
+        # btn_optimize.click(set_visible_true, [], img_in_synth)
+        btn_optimize.click(set_visible_false, [], img_in_real)
+
+
+        ############
+        def fn_clear_all():
+            # Two outputs are wired below (img_out, img_in_real); the third
+            # update belonged to the commented-out img_in_synth component.
+            return gr.update(value=None), gr.update(value=None)
+
+        img_in_real.clear(fn_clear_all, [], [img_out, img_in_real])  # , img_in_synth])
+        # img_in_real.clear(set_visible_true, [], img_in_synth)
+        img_in_real.clear(set_visible_false, [], img_in_real)
+
+        img_out.clear(fn_clear_all, [], [img_out, img_in_real])  # , img_in_synth])
+
+
+        ############
+        btn_generate.click(launch_main,
+            [
+                dest, step,
+                fpath_z_gen, seed,
+            ],
+            [img_out]
+        )
+        ############
+        btn_opt_step800.click(launch_opt800, [], [img_out_opt])
+        btn_opt_step900.click(launch_opt900, [], [img_out_opt])
+        btn_opt_step1000.click(launch_opt1000, [], [img_out_opt])
+        btn_opt_step1100.click(launch_opt1100, [], [img_out_opt])
+        gr.HTML("")
+
+    gr.close_all()
+    demo.queue(concurrency_count=1)
+    demo.launch(server_port=2222, server_name="0.0.0.0", debug=True, share=True)
diff --git a/src/diffusers_/__init__.py b/src/diffusers_/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9fed1145b143677eec35dd5da7b30d1738724bd
--- /dev/null
+++ b/src/diffusers_/__init__.py
@@ -0,0 +1,15 @@
+from .utils import (
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+__version__ = "0.9.0"
+
+
+if is_torch_available() and is_transformers_available():
+    from .stable_diffusion import (
+        StableDiffusionPipeline,
+    )
+else:
+    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
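The `__init__.py` above only exposes `StableDiffusionPipeline` when both torch and transformers are importable, falling back to generated dummy objects. A self-contained illustration of that availability-gating pattern (the names here are illustrative, not the vendored implementation):

```python
# Illustrative availability-gating: import the real symbol when the optional
# dependency is installed, otherwise bind a stub that fails loudly on use.
import importlib.util

def is_torch_available() -> bool:
    # Mirrors the spirit of diffusers' utils.is_torch_available()
    return importlib.util.find_spec("torch") is not None

if is_torch_available():
    from torch import Tensor  # real import
else:
    class Tensor:  # placeholder, analogous to diffusers' DummyObject
        def __init__(self, *args, **kwargs):
            raise ImportError("Tensor requires `pip install torch`.")
```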
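The vendored `configuration_utils.py` that follows provides `ConfigMixin` and the `register_to_config` decorator. A small usage sketch of the register/save/load round trip — the scheduler class and its parameters are hypothetical:

```python
# Hypothetical example of the ConfigMixin API defined in the file below.
from src.diffusers_.configuration_utils import ConfigMixin, register_to_config

class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # filename used by save_config()

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        pass  # arguments are captured into self.config automatically

s = MyScheduler(num_train_timesteps=500)
print(s.config.num_train_timesteps)    # 500, via the FrozenDict
s.save_config("./my_scheduler")        # writes scheduler_config.json
s2 = MyScheduler.from_config(s.config) # re-instantiate from the config
```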
diff --git a/src/diffusers_/configuration_utils.py b/src/diffusers_/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f06586b236987f04b4067ccbc92f00f49bed4589
--- /dev/null
+++ b/src/diffusers_/configuration_utils.py
@@ -0,0 +1,605 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" ConfigMixin base class and utilities."""
+import dataclasses
+import functools
+import importlib
+import inspect
+import json
+import os
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Tuple, Union
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__
+from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class FrozenDict(OrderedDict):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        for key, value in self.items():
+            setattr(self, key, value)
+
+        self.__frozen = True
+
+    def __delitem__(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+    def setdefault(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+    def pop(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+    def update(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+    def __setattr__(self, name, value):
+        # Name mangling stores the flag as `_FrozenDict__frozen`; checking the
+        # literal string "__frozen" would never find it, so the freeze guard
+        # would silently never fire.
+        if hasattr(self, "_FrozenDict__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+        super().__setattr__(name, value)
+
+    def __setitem__(self, name, value):
+        if hasattr(self, "_FrozenDict__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
+        super().__setitem__(name, value)
+
+
+class ConfigMixin:
+    r"""
+    Base class for all configuration classes.
+    Stores all configuration parameters under `self.config`. Also handles all
+    methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+        - [`~ConfigMixin.from_config`]
+        - [`~ConfigMixin.save_config`]
+
+    Class attributes:
+        - **config_name** (`str`) -- A filename under which the config should be stored when calling
+          [`~ConfigMixin.save_config`] (should be overridden by parent class).
+        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+          overridden by subclass).
+        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+          subclass).
+    """
+    config_name = None
+    ignore_for_config = []
+    has_compatibles = False
+
+    _deprecated_kwargs = []
+
+    def register_to_config(self, **kwargs):
+        if self.config_name is None:
+            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+        # Special case for `kwargs` used in deprecation warning added to schedulers
+        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+        # or solve in a more general way.
+        kwargs.pop("kwargs", None)
+        for key, value in kwargs.items():
+            try:
+                setattr(self, key, value)
+            except AttributeError as err:
+                logger.error(f"Can't set {key} with value {value} for {self}")
+                raise err
+
+        if not hasattr(self, "_internal_dict"):
+            internal_dict = kwargs
+        else:
+            previous_dict = dict(self._internal_dict)
+            internal_dict = {**self._internal_dict, **kwargs}
+            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+        self._internal_dict = FrozenDict(internal_dict)
+
+    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+        """
+        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+        [`~ConfigMixin.from_config`] class method.
+
+        Args:
+            save_directory (`str` or `os.PathLike`):
+                Directory where the configuration JSON file will be saved (will be created if it does not exist).
+        """
+        if os.path.isfile(save_directory):
+            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # If we save using the predefined names, we can load using `from_config`
+        output_config_file = os.path.join(save_directory, self.config_name)
+
+        self.to_json_file(output_config_file)
+        logger.info(f"Configuration saved in {output_config_file}")
+
+    @classmethod
+    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+        r"""
+        Instantiate a Python class from a config dictionary.
+
+        Parameters:
+            config (`Dict[str, Any]`):
+                A config dictionary from which the Python class will be instantiated. Make sure to only load
+                configuration files of compatible classes.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether kwargs that are not consumed by the Python class should be returned or not.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it has been loaded) and initialize the Python class.
+                `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
+                overwrite same-named arguments of `config`.
+
+        Examples:
+
+        ```python
+        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+        >>> # Download scheduler from huggingface.co and cache.
+        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+        >>> # Instantiate DDIM scheduler class with same config as DDPM
+        >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+        >>> # Instantiate PNDM scheduler class with same config as DDPM
+        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+        ```
+        """
+        # <===== TO BE REMOVED WITH DEPRECATION
+        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+        if "pretrained_model_name_or_path" in kwargs:
+            config = kwargs.pop("pretrained_model_name_or_path")
+
+        if config is None:
+            raise ValueError("Please make sure to provide a config as the first positional argument.")
+        # ======>
+
+        if not isinstance(config, dict):
+            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+            if "Scheduler" in cls.__name__:
+                deprecation_message += (
+                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+                    " be removed in v1.0.0."
+                )
+            elif "Model" in cls.__name__:
+                deprecation_message += (
+                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+                    " instead. This functionality will be removed in v1.0.0."
+                )
+            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+        # Allow dtype to be specified on initialization
+        if "dtype" in unused_kwargs:
+            init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+        # add possible deprecated kwargs
+        for deprecated_kwarg in cls._deprecated_kwargs:
+            if deprecated_kwarg in unused_kwargs:
+                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+        # Return model and optionally state and/or unused_kwargs
+        model = cls(**init_dict)
+
+        # make sure to also save config parameters that might be used for compatible classes
+        model.register_to_config(**hidden_dict)
+
+        # add hidden kwargs of compatible classes to unused_kwargs
+        unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+        if return_unused_kwargs:
+            return (model, unused_kwargs)
+        else:
+            return model
+
+    @classmethod
+    def get_config_dict(cls, *args, **kwargs):
+        deprecation_message = (
+            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+            " removed in version v1.0.0"
+        )
+        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+        return cls.load_config(*args, **kwargs)
+
+    @classmethod
+    def load_config(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        r"""
+        Load a model or scheduler configuration dictionary.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                      organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+                      `./my_model_directory/`.
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `transformers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        </Tip>
+        """
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        _ = kwargs.pop("mirror", None)
+        subfolder = kwargs.pop("subfolder", None)
+
+        user_agent = {"file_type": "config"}
+
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+        if cls.config_name is None:
+            raise ValueError(
+                "`self.config_name` is not defined. Note that one should not load a config from "
+                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+            )
+
+        if os.path.isfile(pretrained_model_name_or_path):
+            config_file = pretrained_model_name_or_path
+        elif os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                # Load from a PyTorch checkpoint
+                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+            elif subfolder is not None and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+            ):
+                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+            else:
+                raise EnvironmentError(
+                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+                )
+        else:
+            try:
+                # Load from URL or cache if already cached
+                config_file = hf_hub_download(
+                    pretrained_model_name_or_path,
+                    filename=cls.config_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                    subfolder=subfolder,
+                    revision=revision,
+                )
+
+            except RepositoryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+                    " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+                    " login`."
+                )
+            except RevisionNotFoundError:
+                raise EnvironmentError(
+                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                    " this model name. Check the model page at"
+                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                )
+            except EntryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                )
+            except HTTPError as err:
+                raise EnvironmentError(
+                    "There was a specific connection error when trying to load"
+                    f" {pretrained_model_name_or_path}:\n{err}"
+                )
+            except ValueError:
+                raise EnvironmentError(
+                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                    f" directory containing a {cls.config_name} file.\nCheck your internet connection or see how to"
+                    " run the library in offline mode at"
+                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                )
+            except EnvironmentError:
+                raise EnvironmentError(
+                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                    f"containing a {cls.config_name} file"
+                )
+
+        try:
+            # Load config dict
+            config_dict = cls._dict_from_json_file(config_file)
+        except (json.JSONDecodeError, UnicodeDecodeError):
+            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+        if return_unused_kwargs:
+            return config_dict, kwargs
+
+        return config_dict
+
+    @staticmethod
+    def _get_init_keys(cls):
+        return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+    @classmethod
+    def extract_init_dict(cls, config_dict, **kwargs):
+        # 0. Copy the original config dict
+        original_dict = {k: v for k, v in config_dict.items()}
+
+        # 1. Retrieve expected config attributes from __init__ signature
+        expected_keys = cls._get_init_keys(cls)
+        expected_keys.remove("self")
+        # remove general kwargs if present in dict
+        if "kwargs" in expected_keys:
+            expected_keys.remove("kwargs")
+        # remove flax internal keys
+        if hasattr(cls, "_flax_internal_args"):
+            for arg in cls._flax_internal_args:
+                expected_keys.remove(arg)
+
+        # 2. Remove attributes that cannot be expected from expected config attributes
+        # remove keys to be ignored
+        if len(cls.ignore_for_config) > 0:
+            expected_keys = expected_keys - set(cls.ignore_for_config)
+
+        # load diffusers library to import compatible and original scheduler
+        diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+        if cls.has_compatibles:
+            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+        else:
+            compatible_classes = []
+
+        expected_keys_comp_cls = set()
+        for c in compatible_classes:
+            expected_keys_c = cls._get_init_keys(c)
+            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+        # remove attributes from orig class that cannot be expected
+        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+        if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
+            orig_cls = getattr(diffusers_library, orig_cls_name)
+            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+
+        # remove private attributes
+        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+        init_dict = {}
+        for key in expected_keys:
+            # if config param is passed to kwarg and is present in config dict
+            # it should overwrite existing config dict key
+            if key in kwargs and key in config_dict:
+                config_dict[key] = kwargs.pop(key)
+
+            if key in kwargs:
+                # overwrite key
+                init_dict[key] = kwargs.pop(key)
+            elif key in config_dict:
+                # use value from config dict
+                init_dict[key] = config_dict.pop(key)
+
+        # 4. Give nice warning if unexpected values have been passed
+        if len(config_dict) > 0:
+            logger.warning(
+                f"The config attributes {config_dict} were passed to {cls.__name__}, "
+                "but are not expected and will be ignored. Please verify your "
+                f"{cls.config_name} configuration file."
+            )
+
+        # 5. Give nice info if config attributes are initialized to default because they have not been passed
+        passed_keys = set(init_dict.keys())
+        if len(expected_keys - passed_keys) > 0:
+            logger.info(
+                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+            )
+
+        # 6. Define unused keyword arguments
+        unused_kwargs = {**config_dict, **kwargs}
+
+        # 7. Define "hidden" config parameters that were saved for compatible classes
+        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
+
+        return init_dict, unused_kwargs, hidden_config_dict
+
+    @classmethod
+    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+        with open(json_file, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        return json.loads(text)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    @property
+    def config(self) -> Dict[str, Any]:
+        """
+        Returns the config of the class as a frozen dictionary
+
+        Returns:
+            `Dict[str, Any]`: Config of the class.
+        """
+        return self._internal_dict
+
+    def to_json_string(self) -> str:
+        """
+        Serializes this instance to a JSON string.
+
+        Returns:
+            `str`: String containing all the attributes that make up this configuration instance in JSON format.
+        """
+        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+        config_dict["_class_name"] = self.__class__.__name__
+        config_dict["_diffusers_version"] = __version__
+
+        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this configuration instance's parameters will be saved.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            writer.write(self.to_json_string())
+
+
+def register_to_config(init):
+    r"""
+    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+    shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+    """
+
+    @functools.wraps(init)
+    def inner_init(self, *args, **kwargs):
+        # Ignore private kwargs in the init.
+        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+        if not isinstance(self, ConfigMixin):
+            raise RuntimeError(
+                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+                "not inherit from `ConfigMixin`."
+            )
+
+        ignore = getattr(self, "ignore_for_config", [])
+        # Get positional arguments aligned with kwargs
+        new_kwargs = {}
+        signature = inspect.signature(init)
+        parameters = {
+            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+        }
+        for arg, name in zip(args, parameters.keys()):
+            new_kwargs[name] = arg
+
+        # Then add all kwargs
+        new_kwargs.update(
+            {
+                k: init_kwargs.get(k, default)
+                for k, default in parameters.items()
+                if k not in ignore and k not in new_kwargs
+            }
+        )
+        new_kwargs = {**config_init_kwargs, **new_kwargs}
+        getattr(self, "register_to_config")(**new_kwargs)
+        init(self, *args, **init_kwargs)
+
+    return inner_init
+
+
+def flax_register_to_config(cls):
+    original_init = cls.__init__
+
+    @functools.wraps(original_init)
+    def init(self, *args, **kwargs):
+        if not isinstance(self, ConfigMixin):
+            raise RuntimeError(
+                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+                "not inherit from `ConfigMixin`."
+            )
+
+        # Ignore private kwargs in the init. Retrieve all passed attributes
+        init_kwargs = {k: v for k, v in kwargs.items()}
+
+        # Retrieve default values
+        fields = dataclasses.fields(self)
+        default_kwargs = {}
+        for field in fields:
+            # ignore flax-specific attributes
+            if field.name in self._flax_internal_args:
+                continue
+            if type(field.default) == dataclasses._MISSING_TYPE:
+                default_kwargs[field.name] = None
+            else:
+                default_kwargs[field.name] = getattr(self, field.name)
+
+        # Make sure init_kwargs override default kwargs
+        new_kwargs = {**default_kwargs, **init_kwargs}
+        # dtype should be part of `init_kwargs`, but not `new_kwargs`
+        if "dtype" in new_kwargs:
+            new_kwargs.pop("dtype")
+
+        # Get positional arguments aligned with kwargs
+        for i, arg in enumerate(args):
+            name = fields[i].name
+            new_kwargs[name] = arg
+
+        getattr(self, "register_to_config")(**new_kwargs)
+        original_init(self, *args, **kwargs)
+
+    cls.__init__ = init
+    return cls
diff --git a/src/diffusers_/dynamic_modules_utils.py b/src/diffusers_/dynamic_modules_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f3bed2ecf9794b1bf9dab265af32f98dbb7afc
--- /dev/null
+++ b/src/diffusers_/dynamic_modules_utils.py
@@ -0,0 +1,428 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub.""" + +import importlib +import inspect +import os +import re +import shutil +import sys +from pathlib import Path +from typing import Dict, Optional, Union + +from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info + +from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging + + +COMMUNITY_PIPELINES_URL = ( + "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py" +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def create_dynamic_module(name: Union[str, os.PathLike]): + """ + Creates a dynamic module in the cache directory for modules. + """ + init_hf_modules() + dynamic_module_path = Path(HF_MODULES_CACHE) / name + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def get_relative_imports(module_file): + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file): + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [str(module_path / m) for m in new_imports] + new_import_files = [f for f in new_import_files if f not in all_relative_imports] + files_to_check = [f"{f}.py" for f in new_import_files] + + no_change = len(new_import_files) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def check_imports(filename): + """ + Check if the current Python environment contains all the libraries that are imported in a file. 
+ """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import xxx` + imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + + # Unique-ify and test we got them all + imports = list(set(imports)) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError: + missing_packages.append(imp) + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module(class_name, module_path): + """ + Import a module on the cache directory for modules and extract a class from it. + """ + module_path = module_path.replace(os.path.sep, ".") + module = importlib.import_module(module_path) + + if class_name is None: + return find_pipeline_class(module) + return getattr(module, class_name) + + +def find_pipeline_class(loaded_module): + """ + Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class + inheriting from `DiffusionPipeline`. + """ + from .pipeline_utils import DiffusionPipeline + + cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass)) + + pipeline_class = None + for cls_name, cls in cls_members.items(): + if ( + cls_name != DiffusionPipeline.__name__ + and issubclass(cls, DiffusionPipeline) + and cls.__module__.split(".")[0] != "diffusers" + ): + if pipeline_class is not None: + raise ValueError( + f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:" + f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in" + f" {loaded_module}." + ) + pipeline_class = cls + + return pipeline_class + + +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. 
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to force to (re-)download the configuration files and override the cached versions if they
+            exist.
+        resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+        proxies (`Dict[str, str]`, *optional*):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        use_auth_token (`str` or *bool*, *optional*):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `transformers-cli login` (stored in `~/.huggingface`).
+        revision (`str`, *optional*, defaults to `"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+            identifier allowed by git.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the tokenizer configuration from local files.
+
+    <Tip>
+
+    You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
+    or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+    </Tip>
+
+    Returns:
+        `str`: The path to the module inside the cache.
+    """
+    # Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
+    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+    module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+
+    if os.path.isfile(module_file_or_url):
+        resolved_module_file = module_file_or_url
+        submodule = "local"
+    elif pretrained_model_name_or_path.count("/") == 0:
+        # community pipeline on GitHub
+        github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
+        try:
+            resolved_module_file = cached_download(
+                github_url,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=False,
+            )
+            submodule = "git"
+            module_file = pretrained_model_name_or_path + ".py"
+        except EnvironmentError:
+            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+            raise
+    else:
+        try:
+            # Load from URL or cache if already cached
+            resolved_module_file = hf_hub_download(
+                pretrained_model_name_or_path,
+                module_file,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+            )
+            submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
+        except EnvironmentError:
+            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+            raise
+
+    # Check we have all the requirements in our environment
+    modules_needed = check_imports(resolved_module_file)
+
+    # Now we move the module inside our cached dynamic modules.
+    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+    create_dynamic_module(full_submodule)
+    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+    if submodule == "local" or submodule == "git":
+        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+        # that hash, to only copy when there is a modification but it seems overkill for now).
+        # The only reason we do the copy is to avoid putting too many folders in sys.path.
+        shutil.copy(resolved_module_file, submodule_path / module_file)
+        for module_needed in modules_needed:
+            module_needed = f"{module_needed}.py"
+            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+    else:
+        # Get the commit hash
+        # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
+        if isinstance(use_auth_token, str):
+            token = use_auth_token
+        elif use_auth_token is True:
+            token = HfFolder.get_token()
+        else:
+            token = None
+
+        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
+
+        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+        # benefit of versioning.
+        submodule_path = submodule_path / commit_hash
+        full_submodule = full_submodule + os.path.sep + commit_hash
+        create_dynamic_module(full_submodule)
+
+        if not (submodule_path / module_file).exists():
+            shutil.copy(resolved_module_file, submodule_path / module_file)
+        # Make sure we also have every file with relative imports
+        for module_needed in modules_needed:
+            if not (submodule_path / module_needed).exists():
+                get_cached_module_file(
+                    pretrained_model_name_or_path,
+                    f"{module_needed}.py",
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    local_files_only=local_files_only,
+                )
+    return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+    pretrained_model_name_or_path: Union[str, os.PathLike],
+    module_file: str,
+    class_name: Optional[str] = None,
+    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    **kwargs,
+):
+    """
+    Extracts a class from a module file, present in the local folder or repository of a model.
+
+    <Tip warning={true}>
+
+    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+    therefore only be called on trusted repos.
+
+    </Tip>
+
+    Args:
+        pretrained_model_name_or_path (`str` or `os.PathLike`):
+            This can be either:
+
+            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+              under a user or organization name, like `dbmdz/bert-base-german-cased`.
+            - a path to a *directory* containing a configuration file saved using the
+              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+        module_file (`str`):
+            The name of the module file containing the class to look for.
+        class_name (`str`):
+            The name of the class to import in the module.
+        cache_dir (`str` or `os.PathLike`, *optional*):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+            cache should not be used.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to force to (re-)download the configuration files and override the cached versions if they
+            exist.
+        resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+        proxies (`Dict[str, str]`, *optional*):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        use_auth_token (`str` or `bool`, *optional*):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `transformers-cli login` (stored in `~/.huggingface`).
+        revision (`str`, *optional*, defaults to `"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+            identifier allowed by git.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the tokenizer configuration from local files.
+
+    <Tip>
+
+    You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
+    or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+    </Tip>
+
+    Returns:
+        `type`: The class, dynamically imported from the module.
+
+    Examples:
+
+    ```python
+    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+    # module.
+    cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+    ```"""
+    # And lastly we get the class inside our newly created module
+    final_module = get_cached_module_file(
+        pretrained_model_name_or_path,
+        module_file,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        resume_download=resume_download,
+        proxies=proxies,
+        use_auth_token=use_auth_token,
+        revision=revision,
+        local_files_only=local_files_only,
+    )
+    return get_class_in_module(class_name, final_module.replace(".py", ""))
diff --git a/src/diffusers_/hub_utils.py b/src/diffusers_/hub_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bf0933a1dceff4dbc1a16d261559e104026551c
--- /dev/null
+++ b/src/diffusers_/hub_utils.py
@@ -0,0 +1,246 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+from uuid import uuid4
+
+from huggingface_hub import HfFolder, Repository, whoami
+
+from . import __version__
import __version__ +from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging +from .utils.import_utils import ( + _flax_version, + _jax_version, + _onnxruntime_version, + _torch_version, + is_flax_available, + is_modelcards_available, + is_onnx_available, + is_torch_available, +) + + +if is_modelcards_available(): + from modelcards import CardData, ModelCard + + +logger = logging.get_logger(__name__) + + +MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" +SESSION_ID = uuid4().hex +DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES + + +def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: + """ + Formats a user-agent string with basic info about a request. + """ + ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" + if DISABLE_TELEMETRY: + return ua + "; telemetry/off" + if is_torch_available(): + ua += f"; torch/{_torch_version}" + if is_flax_available(): + ua += f"; jax/{_jax_version}" + ua += f"; flax/{_flax_version}" + if is_onnx_available(): + ua += f"; onnxruntime/{_onnxruntime_version}" + # CI will set this value to True + if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: + ua += "; is_ci/true" + if isinstance(user_agent, dict): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + return ua + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def init_git_repo(args, at_init: bool = False): + """ + Args: + Initializes a git repo in `args.hub_model_id`. + at_init (`bool`, *optional*, defaults to `False`): + Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` + and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. + """ + deprecation_message = ( + "Please use `huggingface_hub.Repository`. " + "See `examples/unconditional_image_generation/train_unconditional.py` for an example." 
+ ) + deprecate("init_git_repo()", "0.10.0", deprecation_message) + + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: + return + hub_token = args.hub_token if hasattr(args, "hub_token") else None + use_auth_token = True if hub_token is None else hub_token + if not hasattr(args, "hub_model_id") or args.hub_model_id is None: + repo_name = Path(args.output_dir).absolute().name + else: + repo_name = args.hub_model_id + if "/" not in repo_name: + repo_name = get_full_repo_name(repo_name, token=hub_token) + + try: + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + private=args.hub_private_repo, + ) + except EnvironmentError: + if args.overwrite_output_dir and at_init: + # Try again after wiping output_dir + shutil.rmtree(args.output_dir) + repo = Repository( + args.output_dir, + clone_from=repo_name, + use_auth_token=use_auth_token, + ) + else: + raise + + repo.git_pull() + + # By default, ignore the checkpoint folders + if not os.path.exists(os.path.join(args.output_dir, ".gitignore")): + with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: + writer.writelines(["checkpoint-*/"]) + + return repo + + +def push_to_hub( + args, + pipeline, + repo: Repository, + commit_message: Optional[str] = "End of training", + blocking: bool = True, + **kwargs, +) -> str: + """ + Parameters: + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + kwargs: + Additional keyword arguments passed along to [`create_model_card`]. + Returns: + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the + commit and an object to track the progress of the commit if `blocking=True` + """ + deprecation_message = ( + "Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. " + "See `examples/unconditional_image_generation/train_unconditional.py` for an example." + ) + deprecate("push_to_hub()", "0.10.0", deprecation_message) + + if not hasattr(args, "hub_model_id") or args.hub_model_id is None: + model_name = Path(args.output_dir).name + else: + model_name = args.hub_model_id.split("/")[-1] + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving pipeline checkpoint to {output_dir}") + pipeline.save_pretrained(output_dir) + + # Only push from one node. + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: + return + + # Cancel any async push in progress if blocking=True. The commits will all be pushed together. + if ( + blocking + and len(repo.command_queue) > 0 + and repo.command_queue[-1] is not None + and not repo.command_queue[-1].is_done + ): + repo.command_queue[-1]._process.kill() + + git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True) + # push separately the model card to be independent from the rest of the model + create_model_card(args, model_name=model_name) + try: + repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True) + except EnvironmentError as exc: + logger.error(f"Error pushing update to the model card. 
Please read logs and retry.\n{exc}")

+
+    return git_head_commit_url
+
+
+def create_model_card(args, model_name):
+    if not is_modelcards_available():
+        raise ValueError(
+            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+            " install the package with `pip install modelcards`."
+        )
+
+    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+        return
+
+    hub_token = args.hub_token if hasattr(args, "hub_token") else None
+    repo_name = get_full_repo_name(model_name, token=hub_token)
+
+    model_card = ModelCard.from_template(
+        card_data=CardData(  # Card metadata object that will be converted to YAML block
+            language="en",
+            license="apache-2.0",
+            library_name="diffusers",
+            tags=[],
+            datasets=args.dataset_name,
+            metrics=[],
+        ),
+        template_path=MODEL_CARD_TEMPLATE_PATH,
+        model_name=model_name,
+        repo_name=repo_name,
+        dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+        learning_rate=args.learning_rate,
+        train_batch_size=args.train_batch_size,
+        eval_batch_size=args.eval_batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps
+        if hasattr(args, "gradient_accumulation_steps")
+        else None,
+        adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+        adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+        adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+        adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+        lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+        lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+        ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+        ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+        ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+        mixed_precision=args.mixed_precision,
+    )
+
+    card_path = os.path.join(args.output_dir, "README.md")
+    model_card.save(card_path)
diff --git a/src/diffusers_/modeling_utils.py b/src/diffusers_/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cb0acf52f429f7bd82983fdaa542421b1da11e7
--- /dev/null
+++ b/src/diffusers_/modeling_utils.py
@@ -0,0 +1,693 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from functools import partial
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, device
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . 
import __version__ +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + is_accelerate_available, + is_torch_version, + logging, +) + + +logger = logging.get_logger(__name__) + + +if is_torch_version(">=", "1.9.0"): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False + + +if is_accelerate_available(): + import accelerate + from accelerate.utils import set_module_tensor_to_device + from accelerate.utils.versions import is_torch_version + + +def get_parameter_device(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + try: + return torch.load(checkpoint_file, map_location="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs + + +class ModelMixin(torch.nn.Module): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models. 
+ + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~modeling_utils.ModelMixin.save_pretrained`]. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + + def __init__(self): + super().__init__() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self._supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = torch.save, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~modeling_utils.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. 
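+            # `WEIGHTS_NAME[:-4]` drops the ".bin" extension, so the prefix check below matches both the exact
+            # weights file and any sharded variant sharing that prefix (e.g. an illustrative shard name like
+            # `diffusion_pytorch_model-00001-of-00002.bin`).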
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process: + os.remove(full_filename) + + # Save the model + save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME)) + + logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. 
+ + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. 
Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # Load model + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + else: + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}." + ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {WEIGHTS_NAME} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. 
" + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {WEIGHTS_NAME}" + ) + + # restore default dtype + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + config, unused_kwargs = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) + model = cls.from_config(config, **unused_kwargs) + + # if device_map is Non,e load the state dict on move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file) + # move the parms from meta device to cpu + for param_name, param in state_dict.items(): + set_module_tensor_to_device(model, param_name, param_device, value=param) + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by deafult the device_map is None and the weights are loaded on the CPU + accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + config, unused_kwargs = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file) + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = [k for k in state_dict.keys()] + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." 
+ ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + +def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. + """ + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model diff --git a/src/diffusers_/pipeline_utils.py b/src/diffusers_/pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71a1a6859e81a1643dceb34b53d6ce118d398309 --- /dev/null +++ b/src/diffusers_/pipeline_utils.py @@ -0,0 +1,755 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch + +import diffusers +import PIL +from huggingface_hub import snapshot_download +from packaging import version +from PIL import Image +from tqdm.auto import tqdm + +from .configuration_utils import ConfigMixin +from .dynamic_modules_utils import get_class_from_dynamic_module +from .hub_utils import http_user_agent +from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from .scheduling_utils import SCHEDULER_CONFIG_NAME +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + ONNX_WEIGHTS_NAME, + WEIGHTS_NAME, + BaseOutput, + deprecate, + is_accelerate_available, + is_torch_version, + is_transformers_available, + logging, +) + + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel + + +INDEX_FILE = "diffusion_pytorch_model.bin" +CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" +DUMMY_MODULES_FOLDER = "diffusers.utils" +TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class AudioPipelineOutput(BaseOutput): + """ + Output class for audio pipelines. + + Args: + audios (`np.ndarray`) + List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the + denoised audio samples of the diffusion pipeline. + """ + + audios: np.ndarray + + +class DiffusionPipeline(ConfigMixin): + r""" + Base class for all models. 
+ + [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines + and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: + + - move all PyTorch modules to the device of your choice + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** (`str`) -- name of the config file that will store the class and module names of all + components of the diffusion pipeline. + - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be + passed for the pipeline to function (should be overridden by subclasses). + """ + config_name = "model_index.json" + _optional_components = [] + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrieve library + if module is None: + register_dict = {name: (None, None)} + else: + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. 
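+
+        Example (a minimal sketch; the repo id and output directory are illustrative):
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        >>> pipeline.save_pretrained("./my_pipeline_directory/")
+        ```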
+ """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + expected_modules, optional_kwargs = self._get_signature_keys(self) + + def is_saveable_module(name, value): + if name not in expected_modules: + return False + if name in self._optional_components and value[0] is None: + return False + return True + + model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + if save_method_name is not None: + save_method = getattr(sub_model, save_method_name) + save_method(os.path.join(save_directory, pipeline_component_name)) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + if module.dtype == torch.float16 and str(torch_device) in ["cpu"]: + logger.warning( + "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + module.to(torch_device) + return self + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + return module.device + return torch.device("cpu") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. 
+ + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + custom_pipeline (`str`, *optional*): + + + + This is an experimental feature and is likely to change in the future. + + + + Can be either: + + - A string, the *repo id* of a custom pipeline hosted inside a model repo on + https://huggingface.co/. Valid repo ids have to be located under a user or organization name, + like `hf-internal-testing/diffusers-dummy-pipeline`. + + + + It is required that the model repo has a file, called `pipeline.py` that defines the custom + pipeline. + + + + - A string, the *file name* of a community pipeline hosted on GitHub under + https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to + match exactly the file name without `.py` located under the above link, *e.g.* + `clip_guided_stable_diffusion`. + + + + Community pipelines are always loaded from the current `main` branch of GitHub. + + + + - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`. + + + + It is required that the directory has a file, called `pipeline.py` that defines the custom + pipeline. + + + + For more information on how to load and create custom pipelines, please have a look at [Loading and + Adding Custom + Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) + + torch_dtype (`str` or `torch.dtype`, *optional*): + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. 
If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load- and saveable variables, *i.e.* the pipeline components, of the
+                specific pipeline class. The overwritten components are then directly passed to the pipeline's
+                `__init__` method. See example below for more information.
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + + >>> # Use a different scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.scheduler = scheduler + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + custom_pipeline = kwargs.pop("custom_pipeline", None) + provider = kwargs.pop("provider", None) + sess_options = kwargs.pop("sess_options", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # 1. 
Download the checkpoints and configs
+        # use snapshot download here to get it working from from_pretrained
+        if not os.path.isdir(pretrained_model_name_or_path):
+            config_dict = cls.load_config(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+            # make sure we only download sub-folders and `diffusers` filenames
+            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+            allow_patterns = [os.path.join(k, "*") for k in folder_names]
+            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+
+            # make sure we don't download flax weights
+            ignore_patterns = "*.msgpack"
+
+            if custom_pipeline is not None:
+                allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
+
+            if cls != DiffusionPipeline:
+                requested_pipeline_class = cls.__name__
+            else:
+                requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
+            user_agent = {"pipeline_class": requested_pipeline_class}
+            if custom_pipeline is not None:
+                user_agent["custom_pipeline"] = custom_pipeline
+            user_agent = http_user_agent(user_agent)
+
+            # download all allow_patterns
+            cached_folder = snapshot_download(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns,
+                user_agent=user_agent,
+            )
+        else:
+            cached_folder = pretrained_model_name_or_path
+
+        config_dict = cls.load_config(cached_folder)
+
+        # 2. Load the pipeline class, if using custom module then load it from the hub
+        # if we load from explicit class, let's use it
+        if custom_pipeline is not None:
+            if custom_pipeline.endswith(".py"):
+                path = Path(custom_pipeline)
+                # decompose into folder & file
+                file_name = path.name
+                custom_pipeline = path.parent.absolute()
+            else:
+                file_name = CUSTOM_PIPELINE_FILE_NAME
+
+            pipeline_class = get_class_from_dynamic_module(
+                custom_pipeline, module_file=file_name, cache_dir=custom_pipeline
+            )
+        elif cls != DiffusionPipeline:
+            pipeline_class = cls
+        else:
+            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+        # To be removed in 1.0.0
+        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
+            version.parse(config_dict["_diffusers_version"]).base_version
+        ) <= version.parse("0.5.1"):
+            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
+
+            pipeline_class = StableDiffusionInpaintPipelineLegacy
+
+            deprecation_message = (
+                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
+                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
+                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
+                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+                f" checkpoint {pretrained_model_name_or_path} to the format of"
+                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+                f" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
+ ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + # define init kwargs + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + if len(unused_kwargs) > 0: + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + if class_name.startswith("Flax"): + class_name = class_name[4:] + + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. 
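+                # For example, ("transformers", "CLIPTextModel") would resolve here to the class
+                # `transformers.CLIPTextModel`, whose load method is then looked up through the
+                # LOADABLE_CLASSES table via a base class it inherits from (`PreTrainedModel`).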
+ library = importlib.import_module(library_name) + + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + if loaded_sub_model is None: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + if load_method_name is None: + none_module = class_obj.__module__ + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." + ) + + load_method = getattr(class_obj, load_method_name) + loading_kwargs = {} + + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + + is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") + ) + + # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. + # This makes sure that the weights won't be initialized which significantly speeds up loading. + if is_diffusers_model or is_transformers_model: + loading_kwargs["device_map"] = device_map + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + # raise ValueError( + # f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + # ) + + # 5. 
Instantiate the pipeline
+        model = pipeline_class(**init_kwargs)
+        return model
+
+    @staticmethod
+    def _get_signature_keys(obj):
+        parameters = inspect.signature(obj.__init__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+        expected_modules = set(required_parameters.keys()) - set(["self"])
+        return expected_modules, optional_parameters
+
+    @property
+    def components(self) -> Dict[str, Any]:
+        r"""
+
+        The `self.components` property can be useful to run different pipelines with the same weights and
+        configurations without reallocating additional memory.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import (
+        ...     StableDiffusionPipeline,
+        ...     StableDiffusionImg2ImgPipeline,
+        ...     StableDiffusionInpaintPipeline,
+        ... )
+
+        >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+        >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
+        ```
+
+        Returns:
+            A dictionary containing all the modules needed to initialize the pipeline.
+        """
+        expected_modules, optional_parameters = self._get_signature_keys(self)
+        components = {
+            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
+        }
+
+        if set(components.keys()) != expected_modules:
+            raise ValueError(
+                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
+                f" {expected_modules} to be defined, but {components} are defined."
+            )
+
+        return components
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+        else:
+            pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
+
+    def progress_bar(self, iterable):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        return tqdm(iterable, **self._progress_bar_config)
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs
diff --git a/src/diffusers_/scheduling_utils.py b/src/diffusers_/scheduling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8aa34a55af79840defa75b983d50cdeeca65c27
--- /dev/null
+++ b/src/diffusers_/scheduling_utils.py
@@ -0,0 +1,154 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
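For reference, the `_get_signature_keys` helper above is what drives the `components` property: it splits the `__init__` signature into required modules (no default value) and optional keyword arguments. A minimal, self-contained sketch of that introspection, where `ToyPipeline` is a hypothetical stand-in:

```py
# Illustrative sketch of the signature introspection in `_get_signature_keys`.
import inspect


class ToyPipeline:
    def __init__(self, unet, scheduler, requires_safety_checker: bool = False):
        pass


parameters = inspect.signature(ToyPipeline.__init__).parameters
# Parameters without a default become the expected modules of the pipeline.
required = {k for k, v in parameters.items() if v.default is inspect.Parameter.empty} - {"self"}
# Parameters with a default are treated as optional init kwargs.
optional = {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}

print(required)  # {'unet', 'scheduler'} -> these end up in `components`
print(optional)  # {'requires_safety_checker'} -> excluded from `components`
```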
+import importlib
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from .utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+    """
+    Base class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+    """
+
+    prev_sample: torch.FloatTensor
+
+
+class SchedulerMixin:
+    """
+    Mixin containing common functions for the schedulers.
+
+    Class attributes:
+        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+          `from_config` can be used from a class different than the one used to save the config (should be overridden
+          by parent class).
+    """
+
+    config_name = SCHEDULER_CONFIG_NAME
+    _compatibles = []
+    has_compatibles = True
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        subfolder: Optional[str] = None,
+        return_unused_kwargs=False,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                      organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing the scheduler configurations saved using
+                      [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
+            subfolder (`str`, *optional*):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether kwargs that are not consumed by the Python class should be returned or not.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading from an incompletely received file if one exists. If set to
+                `False`, incompletely received files are deleted and the download restarts.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `transformers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use.
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~SchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes diff --git a/src/diffusers_/stable_diffusion/__init__.py b/src/diffusers_/stable_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8d86d8b806693ce333f4891fafaf6495cf7551 --- /dev/null +++ b/src/diffusers_/stable_diffusion/__init__.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np + +import PIL +from PIL import Image + +from ..utils import ( + BaseOutput, + is_torch_available, + is_transformers_available, +) + + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_transformers_available() and is_torch_available(): + from .pipeline_stable_diffusion import StableDiffusionPipeline diff --git a/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-310.pyc b/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efd61d88cdcb5451bcd3923607cd16de6bc02d58 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49f5429831011b561744df75f7172886e69077b0 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2e82e9c1a95913e18a8e4ad2c83c0989015754a Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24999214a55b71f076d242d76da3a1d41adea29e Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca19e980b29e1e98eeea2a673dca5435976fa129 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b47ba63d68231308f2b9177e94e44df475007262 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a66de5ceb90b5c06eff58b2cb47bc9f5293afb8 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c4c831c7e8f8c180de0c82193a7c8c3e62c61eb Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-38.pyc differ diff --git 
a/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d845713b32c5e11b98d1afe098dad5eff332a7a Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d37745844d9df646ebcc712088fcd92f702c5cb Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d231c7c9e5123f18fa68b5f2c5db1f70fc5065b4 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36f280c6cda8b42cb6a1924e52984fde73ff2a3d Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27037014c06da3cb51dff5d63ce1368a5d3928bf Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d834eb4db51136de7faf4ed02341092a9f8ed8a0 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b292b42fddbf37af501eb29e4bf97fb007c2987c Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e70c380949bd3f6c4439abae6368cd98603297d3 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-37.pyc differ diff --git 
a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bec33ad65b9c15fd5e61df223a49cc5e8940096 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3073c9ac2a7dddccb5fe2aa9df00e040d5788023 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27d3a8eeabe258a55bff6beb4946782867a54232 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf95a5b2b5ce3a6fcb75fcc571217d2c91a2038f Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4255739af31ae1c98b30ba6953d68102f3c2554b Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cda99a530218816f6df2ceb32add991da121c74a Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22174ed49c44d11104e21405069d4fde6b51615b Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc b/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08bfc4784cf78454339e29faebfe219c1fd628d5 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc differ diff --git 
a/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-37.pyc b/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7a0851fe2b0c2019d154fd6eb1dbdf6a54397b4 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-37.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e3fe0776e44a397eab82957957ac171db81417 Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/__pycache__/safety_checker_flax.cpython-38.pyc b/src/diffusers_/stable_diffusion/__pycache__/safety_checker_flax.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b8f6fdac209b21af0e428689d5c664985f2d8fd Binary files /dev/null and b/src/diffusers_/stable_diffusion/__pycache__/safety_checker_flax.cpython-38.pyc differ diff --git a/src/diffusers_/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers_/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c8b03150fde8849eff76cc0e8c8df5c0d29e97 --- /dev/null +++ b/src/diffusers_/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,578 @@ +import inspect +from typing import Callable, List, Optional, Union + +import torch + +from diffusers.utils import is_accelerate_available +from packaging import version + +from ..configuration_utils import FrozenDict +from ..pipeline_utils import DiffusionPipeline +from ..utils import deprecate, logging +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+    _optional_components = ["safety_checker", "feature_extractor"]
+
+    def __init__(
+        self,
+        vae,
+        text_encoder,
+        tokenizer,
+        unet,
+        scheduler,
+        safety_checker,
+        feature_extractor,
+        requires_safety_checker: bool = False,
+    ):
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,  # safety_checker is intentionally disabled in this fork
+            feature_extractor=None,  # feature_extractor is intentionally disabled in this fork
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
+    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                `attention_head_dim` must be a multiple of `slice_size`.
+ """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device("cuda") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). 
+ """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. 
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        else:
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(self, prompt, height, width, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+ ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = "", + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + text_embeddings: Optional[torch.FloatTensor] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        # In this fork the prompt is expected to arrive pre-encoded as `text_embeddings`
+        # (e.g. the optimized personalized embedding), so the default encoding is disabled:
+        # if text_embeddings is None:
+        #     text_embeddings = self._encode_prompt(
+        #         prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        #     )
+
+        if num_images_per_prompt != 1:
+            seq_len = text_embeddings.shape[1]
+            text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+            text_embeddings = text_embeddings.view(num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            uncond_tokens = [""]
+            max_length = self.tokenizer.model_max_length
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            if num_images_per_prompt != 1:
+                uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+                uncond_embeddings = uncond_embeddings.view(num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        # 4.
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file diff --git a/src/diffusers_/stable_diffusion/safety_checker.py b/src/diffusers_/stable_diffusion/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..45835f24eddc66590a6688ae61376ccf51fd67af --- /dev/null +++ b/src/diffusers_/stable_diffusion/safety_checker.py @@ -0,0 +1,123 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
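To make the control flow of the `__call__` method above concrete, here is a minimal usage sketch. It assumes a Stable-Diffusion-v1-style checkpoint (CLIP hidden size 768, 77 tokens), a CUDA device, the usual `DiffusionPipeline.to` device helper, and that `src` is importable from the repo root; the random `text_embeddings` merely stand in for the personalized embedding optimized in Step 1 of the demo, since this fork expects the prompt to arrive pre-encoded:

```py
# Minimal usage sketch for the pipeline's __call__ above (illustrative only;
# the checkpoint id, import path, and embedding shape are assumptions).
import torch
from src.diffusers_.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

# Stand-in for the optimized personalized embedding from Step 1:
# (1, 77, 768) CLIP text-encoder hidden states.
text_embeddings = torch.randn(1, 77, 768, device="cuda")

generator = torch.Generator(device="cuda").manual_seed(111111)
image = pipe(
    text_embeddings=text_embeddings,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=generator,
).images[0]
image.save("out.png")
```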
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+    normalized_image_embeds = nn.functional.normalize(image_embeds)
+    normalized_text_embeds = nn.functional.normalize(text_embeds)
+    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+    config_class = CLIPConfig
+
+    _no_split_modules = ["CLIPEncoderLayer"]
+
+    def __init__(self, config: CLIPConfig):
+        super().__init__(config)
+
+        self.vision_model = CLIPVisionModel(config.vision_config)
+        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+
+    @torch.no_grad()
+    def forward(self, clip_input, images):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+
+        result = []
+        batch_size = image_embeds.shape[0]
+        for i in range(batch_size):
+            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+            # increase this value to create a stronger `nsfw` filter
+            # at the cost of increasing the possibility of filtering benign images
+            adjustment = 0.0
+
+            for concept_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concept_idx]
+                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["special_scores"][concept_idx] > 0:
+                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+                    adjustment = 0.01
+
+            for concept_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concept_idx]
+                concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["concept_scores"][concept_idx] > 0:
+                    result_img["bad_concepts"].append(concept_idx)
+
+            result.append(result_img)
+
+        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                images[idx] = np.zeros(images[idx].shape)  # black image
+
+        if any(has_nsfw_concepts):
+            logger.warning(
+                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+                " Try again with a different prompt and/or seed."
+ ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/src/diffusers_/utils/__init__.py b/src/diffusers_/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e86f3b801ae44b75f0a3b0d6ee764b5763906a45 --- /dev/null +++ b/src/diffusers_/utils/__init__.py @@ -0,0 +1,86 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
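The safety checker above flags an image when the cosine similarity between its projected CLIP embedding and any concept embedding exceeds that concept's learned threshold, and a match on a "special care" concept tightens all subsequent thresholds by 0.01. A toy sketch of that thresholding with made-up numbers:

```py
# Toy sketch (made-up scores and thresholds) of the safety checker's decision rule.
import torch

cos_dist = torch.tensor([[0.21, 0.17, 0.30]])   # one image vs. 3 concept embeddings
thresholds = torch.tensor([0.25, 0.18, 0.28])   # learned per-concept thresholds
special_hit = True                              # pretend a special-care concept matched
adjustment = 0.01 if special_hit else 0.0       # special-care match tightens thresholds

concept_scores = cos_dist - thresholds + adjustment
has_nsfw = bool(torch.any(concept_scores > 0))
print(concept_scores, has_nsfw)  # tensor([[-0.0300, 0.0000, 0.0300]]) True
```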
+
+
+import os
+
+from .deprecation_utils import deprecate
+from .import_utils import (
+    ENV_VARS_TRUE_AND_AUTO_VALUES,
+    ENV_VARS_TRUE_VALUES,
+    USE_JAX,
+    USE_TF,
+    USE_TORCH,
+    DummyObject,
+    is_accelerate_available,
+    is_flax_available,
+    is_inflect_available,
+    is_modelcards_available,
+    is_onnx_available,
+    is_scipy_available,
+    is_tf_available,
+    is_torch_available,
+    is_torch_version,
+    is_transformers_available,
+    is_transformers_version,
+    is_unidecode_available,
+    requires_backends,
+)
+from .logging import get_logger
+from .outputs import BaseOutput
+from .pil_utils import PIL_INTERPOLATION
+
+
+if is_torch_available():
+    from .testing_utils import (
+        floats_tensor,
+        load_hf_numpy,
+        load_image,
+        load_numpy,
+        parse_flag_from_env,
+        require_torch_gpu,
+        slow,
+        torch_all_close,
+        torch_device,
+    )
+
+
+logger = get_logger(__name__)
+
+
+hf_cache_home = os.path.expanduser(
+    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "diffusers")
+
+
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
+ONNX_WEIGHTS_NAME = "model.onnx"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
+_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
+    "DDIMScheduler",
+    "DDPMScheduler",
+    "PNDMScheduler",
+    "LMSDiscreteScheduler",
+    "EulerDiscreteScheduler",
+    "EulerAncestralDiscreteScheduler",
+    "DPMSolverMultistepScheduler",
+]
new file mode 100644 index 0000000000000000000000000000000000000000..fb2fb7554f5ce3b9d50b9e9c263d58004a21b96f Binary files /dev/null and b/src/diffusers_/utils/__pycache__/pil_utils.cpython-310.pyc differ diff --git a/src/diffusers_/utils/__pycache__/pil_utils.cpython-37.pyc b/src/diffusers_/utils/__pycache__/pil_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72c2f108e984855a2e04cd4c171e350f0140604c Binary files /dev/null and b/src/diffusers_/utils/__pycache__/pil_utils.cpython-37.pyc differ diff --git a/src/diffusers_/utils/__pycache__/pil_utils.cpython-38.pyc b/src/diffusers_/utils/__pycache__/pil_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6a6faf693df5d9d6c4588f0f80fcff564c3631b Binary files /dev/null and b/src/diffusers_/utils/__pycache__/pil_utils.cpython-38.pyc differ diff --git a/src/diffusers_/utils/__pycache__/testing_utils.cpython-310.pyc b/src/diffusers_/utils/__pycache__/testing_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56b8a93a76acf69fdad0e1f78c3e48a073735b39 Binary files /dev/null and b/src/diffusers_/utils/__pycache__/testing_utils.cpython-310.pyc differ diff --git a/src/diffusers_/utils/__pycache__/testing_utils.cpython-37.pyc b/src/diffusers_/utils/__pycache__/testing_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2607e3c0340b605f330971dbbff66360af03263a Binary files /dev/null and b/src/diffusers_/utils/__pycache__/testing_utils.cpython-37.pyc differ diff --git a/src/diffusers_/utils/__pycache__/testing_utils.cpython-38.pyc b/src/diffusers_/utils/__pycache__/testing_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3676e4e73b47dba85aec07c70d5a0953e8964a58 Binary files /dev/null and b/src/diffusers_/utils/__pycache__/testing_utils.cpython-38.pyc differ diff --git a/src/diffusers_/utils/deprecation_utils.py b/src/diffusers_/utils/deprecation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8bfc901b31c92a93e9764506682210f5072c0a --- /dev/null +++ b/src/diffusers_/utils/deprecation_utils.py @@ -0,0 +1,49 @@ +import inspect +import warnings +from typing import Any, Dict, Optional, Union + +from packaging import version + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): + from .. import __version__ + + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError( + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" + ) + + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." 
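+        # At this point `warning` is set when the deprecated name was actually found
+        # (a kwarg popped from a `take_from` dict or an attribute read from an object)
+        # or when no `take_from` was given at all; otherwise it stays None and no
+        # FutureWarning is emitted for this tuple.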
+ + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, FutureWarning) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values diff --git a/src/diffusers_/utils/import_utils.py b/src/diffusers_/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c0294b4a3d2389e4c6bf8b7e7be9a6a900f67033 --- /dev/null +++ b/src/diffusers_/utils/import_utils.py @@ -0,0 +1,374 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" +import importlib.util +import operator as op +import os +import sys +from collections import OrderedDict +from typing import Union + +from packaging import version +from packaging.version import Version, parse + +from . import logging + + +# The package importlib_metadata is in a different place, depending on the python version. 
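+# (Python < 3.8 has no importlib.metadata in the standard library, so the
+# importlib_metadata backport package is used there instead.)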
+if sys.version_info < (3, 8):
+    import importlib_metadata
+else:
+    import importlib.metadata as importlib_metadata
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
+
+_torch_version = "N/A"
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+    _torch_available = importlib.util.find_spec("torch") is not None
+    if _torch_available:
+        try:
+            _torch_version = importlib_metadata.version("torch")
+            logger.info(f"PyTorch version {_torch_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _torch_available = False
+else:
+    logger.info("Disabling PyTorch because USE_TF is set")
+    _torch_available = False
+
+
+_tf_version = "N/A"
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+    _tf_available = importlib.util.find_spec("tensorflow") is not None
+    if _tf_available:
+        candidates = (
+            "tensorflow",
+            "tensorflow-cpu",
+            "tensorflow-gpu",
+            "tf-nightly",
+            "tf-nightly-cpu",
+            "tf-nightly-gpu",
+            "intel-tensorflow",
+            "intel-tensorflow-avx512",
+            "tensorflow-rocm",
+            "tensorflow-macos",
+            "tensorflow-aarch64",
+        )
+        _tf_version = None
+        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+        for pkg in candidates:
+            try:
+                _tf_version = importlib_metadata.version(pkg)
+                break
+            except importlib_metadata.PackageNotFoundError:
+                pass
+        _tf_available = _tf_version is not None
+    if _tf_available:
+        if version.parse(_tf_version) < version.parse("2"):
+            logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
+            _tf_available = False
+        else:
+            logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+    logger.info("Disabling TensorFlow because USE_TORCH is set")
+    _tf_available = False
+
+_jax_version = "N/A"
+_flax_version = "N/A"
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+    if _flax_available:
+        try:
+            _jax_version = importlib_metadata.version("jax")
+            _flax_version = importlib_metadata.version("flax")
+            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _flax_available = False
+else:
+    _flax_available = False
+
+
+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+    _transformers_version = importlib_metadata.version("transformers")
+    logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _transformers_available = False
+
+
+_inflect_available = importlib.util.find_spec("inflect") is not None
+try:
+    _inflect_version = importlib_metadata.version("inflect")
+    logger.debug(f"Successfully imported inflect version {_inflect_version}")
+except importlib_metadata.PackageNotFoundError:
+    _inflect_available = False
+
+
+_unidecode_available = importlib.util.find_spec("unidecode") is not None
+try:
+    _unidecode_version = importlib_metadata.version("unidecode")
+    logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
+except importlib_metadata.PackageNotFoundError:
+    _unidecode_available = False
+
+
+_modelcards_available = importlib.util.find_spec("modelcards") is not None
+try:
+    _modelcards_version = importlib_metadata.version("modelcards")
+    logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
+except importlib_metadata.PackageNotFoundError:
+    _modelcards_available = False
+
+
+_onnxruntime_version = "N/A"
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+if _onnx_available:
+    candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")
+    _onnxruntime_version = None
+    # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
+    for pkg in candidates:
+        try:
+            _onnxruntime_version = importlib_metadata.version(pkg)
+            break
+        except importlib_metadata.PackageNotFoundError:
+            pass
+    _onnx_available = _onnxruntime_version is not None
+    if _onnx_available:
+        logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+
+
+_scipy_available = importlib.util.find_spec("scipy") is not None
+try:
+    _scipy_version = importlib_metadata.version("scipy")
+    logger.debug(f"Successfully imported scipy version {_scipy_version}")
+except importlib_metadata.PackageNotFoundError:
+    _scipy_available = False
+
+_accelerate_available = importlib.util.find_spec("accelerate") is not None
+try:
+    _accelerate_version = importlib_metadata.version("accelerate")
+    logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
+except importlib_metadata.PackageNotFoundError:
+    _accelerate_available = False
+
+_xformers_available = importlib.util.find_spec("xformers") is not None
+try:
+    _xformers_version = importlib_metadata.version("xformers")
+    if _torch_available:
+        import torch
+
+        if version.parse(torch.__version__) < version.parse("1.12"):
+            raise ValueError("PyTorch should be >= 1.12")
+
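+    # (The ValueError above is intentionally not caught: xformers alongside
+    # torch < 1.12 is a hard error, while a missing xformers package merely
+    # flips the availability flag below.)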
logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available + + +def is_flax_available(): + return _flax_available + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_modelcards_available(): + return _modelcards_available + + +def is_onnx_available(): + return _onnx_available + + +def is_scipy_available(): + return _scipy_available + + +def is_xformers_available(): + return _xformers_available + + +def is_accelerate_available(): + return _accelerate_available + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +""" + +# docstyle-ignore +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install +inflect` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +ONNX_IMPORT_ERROR = """ +{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip +install onnxruntime` +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +""" + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + +# docstyle-ignore +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. 
You can install it with pip: `pip install +Unidecode` +""" + + +BACKENDS_MAPPING = OrderedDict( + [ + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + if name in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + ] and is_transformers_version("<", "4.25.0.dev0"): + raise ImportError( + f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install" + " git+https://github.com/huggingface/transformers \n```" + ) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_version(operation: str, version: str): + """ + Args: + Compares the current Transformers version to a given reference with an operation. 
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A string version of Transformers
+    """
+    if not _transformers_available:
+        return False
+    return compare_versions(parse(_transformers_version), operation, version)
diff --git a/src/diffusers_/utils/logging.py b/src/diffusers_/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c1c77d10b2a6b06a0c57d4fdf1802e3bd5f705f
--- /dev/null
+++ b/src/diffusers_/utils/logging.py
@@ -0,0 +1,340 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import CRITICAL  # NOQA
+from logging import DEBUG  # NOQA
+from logging import ERROR  # NOQA
+from logging import FATAL  # NOQA
+from logging import INFO  # NOQA
+from logging import NOTSET  # NOQA
+from logging import WARN  # NOQA
+from logging import WARNING  # NOQA
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+    "debug": logging.DEBUG,
+    "info": logging.INFO,
+    "warning": logging.WARNING,
+    "error": logging.ERROR,
+    "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+    """
+    If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
+    not - fall back to `_default_log_level`
+    """
+    env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
+    if env_level_str:
+        if env_level_str in log_levels:
+            return log_levels[env_level_str]
+        else:
+            logging.getLogger().warning(
+                f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+                f"has to be one of: { ', '.join(log_levels.keys()) }"
+            )
+    return _default_log_level
+
+
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+    return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+    global _default_handler
+
+    with _lock:
+        if _default_handler:
+            # This library has already configured the library root logger.
+            return
+        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
+        _default_handler.flush = sys.stderr.flush
+
+        # Apply our default configuration to the library root logger.
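+        # The default handler streams to sys.stderr, the level is taken from the
+        # DIFFUSERS_VERBOSITY environment variable (see _get_default_logging_level
+        # above), and propagation is switched off so records are not printed twice.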
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_get_default_logging_level())
+        library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+    global _default_handler
+
+    with _lock:
+        if not _default_handler:
+            return
+
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.removeHandler(_default_handler)
+        library_root_logger.setLevel(logging.NOTSET)
+        _default_handler = None
+
+
+def get_log_levels_dict():
+    return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Return a logger with the specified name.
+
+    This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
+    """
+
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+    return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+    """
+    Return the current level for the 🤗 Diffusers' root logger as an int.
+
+    Returns:
+        `int`: The logging level.
+
+
+
+    🤗 Diffusers has the following logging levels:
+
+    - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+    - 40: `diffusers.logging.ERROR`
+    - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+    - 20: `diffusers.logging.INFO`
+    - 10: `diffusers.logging.DEBUG`
+
+    """
+
+    _configure_library_root_logger()
+    return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+    """
+    Set the verbosity level for the 🤗 Diffusers' root logger.
+
+    Args:
+        verbosity (`int`):
+            Logging level, e.g., one of:
+
+            - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+            - `diffusers.logging.ERROR`
+            - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+            - `diffusers.logging.INFO`
+            - `diffusers.logging.DEBUG`
+    """
+
+    _configure_library_root_logger()
+    _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+    """Set the verbosity to the `INFO` level."""
+    return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+    """Set the verbosity to the `WARNING` level."""
+    return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+    """Set the verbosity to the `DEBUG` level."""
+    return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+    """Set the verbosity to the `ERROR` level."""
+    return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+    """Disable the default handler of the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert _default_handler is not None
+    _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+    """Enable the default handler of the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert _default_handler is not None
+    _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+    """adds a handler to the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert handler is not None
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    """removes given handler from the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert handler is not None and handler in _get_library_root_logger().handlers
+    _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+    """
+    Disable propagation of the library log outputs.
Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for HuggingFace Diffusers' loggers. + + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/src/diffusers_/utils/outputs.py b/src/diffusers_/utils/outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..5d902dd394ccddc408d85b48e4142facc7242550 --- /dev/null +++ b/src/diffusers_/utils/outputs.py @@ -0,0 +1,108 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic utilities +""" + +from collections import OrderedDict +from dataclasses import fields +from typing import Any, Tuple + +import numpy as np + +from .import_utils import is_torch_available + + +def is_tensor(x): + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + before. + + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for (k, v) in self.items()} + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. 
+ """ + return tuple(self[k] for k in self.keys()) diff --git a/src/diffusers_/utils/pil_utils.py b/src/diffusers_/utils/pil_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39d0a15a4e2fe39fecb01951b36c43368492f983 --- /dev/null +++ b/src/diffusers_/utils/pil_utils.py @@ -0,0 +1,21 @@ +import PIL.Image +import PIL.ImageOps +from packaging import version + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } diff --git a/src/diffusers_/utils/testing_utils.py b/src/diffusers_/utils/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bf398e5b6fe5b1b2c5a909bcd43a9fd772d250af --- /dev/null +++ b/src/diffusers_/utils/testing_utils.py @@ -0,0 +1,393 @@ +import inspect +import logging +import os +import random +import re +import unittest +import urllib.parse +from distutils.util import strtobool +from io import BytesIO, StringIO +from pathlib import Path +from typing import Union + +import numpy as np + +import PIL.Image +import PIL.ImageOps +import requests +from packaging import version + +from .import_utils import is_flax_available, is_onnx_available, is_torch_available + + +global_rng = random.Random() + + +if is_torch_available(): + import torch + + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse( + "1.12" + ) + + if is_torch_higher_equal_than_1_12: + # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details + mps_backend_registered = hasattr(torch.backends, "mps") + torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device + + +def torch_all_close(a, b, *args, **kwargs): + if not is_torch_available(): + raise ValueError("PyTorch needs to be installed to use this function.") + if not torch.allclose(a, b, *args, **kwargs): + assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}." + return True + + +def get_tests_dir(append_path=None): + """ + Args: + append_path: optional path to append to the tests dir path + Return: + The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is + joined after the `tests` dir the former is provided. + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + tests_dir = os.path.abspath(os.path.dirname(caller__file__)) + + while not tests_dir.endswith("tests"): + tests_dir = os.path.dirname(tests_dir) + + if append_path: + return os.path.join(tests_dir, append_path) + else: + return tests_dir + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. 
+ raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. + """ + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( + test_case + ) + + +def require_flax(test_case): + """ + Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed + """ + return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) + + +def require_onnxruntime(test_case): + """ + Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed. + """ + return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case) + + +def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray: + if isinstance(arry, str): + if arry.startswith("http://") or arry.startswith("https://"): + response = requests.get(arry) + response.raise_for_status() + arry = np.load(BytesIO(response.content)) + elif os.path.isfile(arry): + arry = np.load(arry) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path" + ) + elif isinstance(arry, np.ndarray): + pass + else: + raise ValueError( + "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a" + " ndarray." + ) + + return arry + + +def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: + """ + Args: + Loads `image` to a PIL Image. + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + Returns: + `PIL.Image.Image`: A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." 
+        )
+    image = PIL.ImageOps.exif_transpose(image)
+    image = image.convert("RGB")
+    return image
+
+
+def load_hf_numpy(path) -> np.ndarray:
+    if not path.startswith("http://") and not path.startswith("https://"):
+        path = os.path.join(
+            "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
+        )
+
+    return load_numpy(path)
+
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+    """
+    This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+    It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+    option.
+
+    """
+    option = "--make-reports"
+    if option not in pytest_opt_registered:
+        parser.addoption(
+            option,
+            action="store",
+            default=False,
+            help="generate report files. The value of this option is used as a prefix to report names",
+        )
+        pytest_opt_registered[option] = 1
+
+
+def pytest_terminal_summary_main(tr, id):
+    """
+    Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+    directory. The report files are prefixed with the test suite name.
+
+    This function emulates --duration and -rA pytest arguments.
+
+    This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+    there.
+
+    Args:
+    - tr: `terminalreporter` passed from `conftest.py`
+    - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+      needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+    NB: this function taps into a private _pytest API and while unlikely, it could break should
+    pytest do internal changes - also it calls default internal methods of terminalreporter which
+    can be hijacked by various `pytest-` plugins and interfere.
+ + """ + from _pytest.config import create_terminal_writer + + if not len(id): + id = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dir = "reports" + Path(dir).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dir}/{id}_{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. 
+ # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + with open(report_files["passes"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle + + +class CaptureLogger: + """ + Args: + Context manager to capture `logging` streams + logger: 'logging` logger object + Returns: + The captured output is available via `self.out` + Example: + ```python + >>> from diffusers import logging + >>> from diffusers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py") + >>> with CaptureLogger(logger) as cl: + ... 
logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" diff --git a/src/utils/__pycache__/base_pipeline.cpython-310.pyc b/src/utils/__pycache__/base_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37f97a4d9182ecd32949b6a8b99c45fea92ef1e4 Binary files /dev/null and b/src/utils/__pycache__/base_pipeline.cpython-310.pyc differ diff --git a/src/utils/__pycache__/base_pipeline.cpython-38.pyc b/src/utils/__pycache__/base_pipeline.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b39d912e25381ab5cced07e8fae56a9506323ee Binary files /dev/null and b/src/utils/__pycache__/base_pipeline.cpython-38.pyc differ diff --git a/src/utils/__pycache__/cross_attention.cpython-310.pyc b/src/utils/__pycache__/cross_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..860ecfc9666e4493be12bab72b32ef3fde33069c Binary files /dev/null and b/src/utils/__pycache__/cross_attention.cpython-310.pyc differ diff --git a/src/utils/__pycache__/ddim_inv.cpython-310.pyc b/src/utils/__pycache__/ddim_inv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c10adae8707bd13af660f6877f1322b26e3a256 Binary files /dev/null and b/src/utils/__pycache__/ddim_inv.cpython-310.pyc differ diff --git a/src/utils/__pycache__/ddim_inv.cpython-38.pyc b/src/utils/__pycache__/ddim_inv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50963a3a064c8021facb9faeb3c1649c25c1498e Binary files /dev/null and b/src/utils/__pycache__/ddim_inv.cpython-38.pyc differ diff --git a/src/utils/__pycache__/edit_directions.cpython-310.pyc b/src/utils/__pycache__/edit_directions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd16ae5d5465ee259fb6c207eec897ad10d3be99 Binary files /dev/null and b/src/utils/__pycache__/edit_directions.cpython-310.pyc differ diff --git a/src/utils/__pycache__/edit_directions.cpython-38.pyc b/src/utils/__pycache__/edit_directions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c10e8e7f17a9febf4f595f40f6e805aded18bd29 Binary files /dev/null and b/src/utils/__pycache__/edit_directions.cpython-38.pyc differ diff --git a/src/utils/__pycache__/edit_pipeline.cpython-310.pyc b/src/utils/__pycache__/edit_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e680e0a87ff03ea5cea7893e783aefe08fcf10e1 Binary files /dev/null and b/src/utils/__pycache__/edit_pipeline.cpython-310.pyc differ diff --git a/src/utils/__pycache__/edit_pipeline.cpython-38.pyc b/src/utils/__pycache__/edit_pipeline.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dbfb2f1bfbd6ae8560c581d9a0f485837ccee67 Binary files /dev/null and b/src/utils/__pycache__/edit_pipeline.cpython-38.pyc differ diff --git a/src/utils/__pycache__/gradio_utils.cpython-310.pyc b/src/utils/__pycache__/gradio_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa1d3aa4e505c0bb412bf874b92d6e98e46720ec Binary files /dev/null and 
b/src/utils/__pycache__/gradio_utils.cpython-310.pyc differ diff --git a/src/utils/__pycache__/gradio_utils.cpython-38.pyc b/src/utils/__pycache__/gradio_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2243a3bf3595901e75f5c841a7f8ebe053eba658 Binary files /dev/null and b/src/utils/__pycache__/gradio_utils.cpython-38.pyc differ diff --git a/src/utils/__pycache__/huggingface_utils.cpython-310.pyc b/src/utils/__pycache__/huggingface_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..035e5cbb3a30f636d3087286e4a24d3825191d14 Binary files /dev/null and b/src/utils/__pycache__/huggingface_utils.cpython-310.pyc differ diff --git a/src/utils/__pycache__/huggingface_utils.cpython-38.pyc b/src/utils/__pycache__/huggingface_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92be665590201df44f7de9473def6595352a714b Binary files /dev/null and b/src/utils/__pycache__/huggingface_utils.cpython-38.pyc differ diff --git a/src/utils/__pycache__/scheduler.cpython-310.pyc b/src/utils/__pycache__/scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a96862bbd62f5cb5a45ef0a52e654a2353e8c1f Binary files /dev/null and b/src/utils/__pycache__/scheduler.cpython-310.pyc differ diff --git a/src/utils/__pycache__/scheduler.cpython-38.pyc b/src/utils/__pycache__/scheduler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40bf7f24b83895b972aaffc54c180387554fd801 Binary files /dev/null and b/src/utils/__pycache__/scheduler.cpython-38.pyc differ diff --git a/src/utils/gradio_utils.py b/src/utils/gradio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4dd6bdf84ea85561247d5ad17cdda9a1744d82 --- /dev/null +++ b/src/utils/gradio_utils.py @@ -0,0 +1,892 @@ +import numpy +from tqdm import tqdm +import gradio as gr +import os + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, UNet2DConditionModel +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +import numpy as np +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer +from torch import autocast +from src.diffusers_ import StableDiffusionPipeline + + + +def launch_source(): + image = Image.open('./tmp/train_images_source.png').convert("RGB") + output = temp_save([image],num_rows=1) + return output + +def launch_opt800(): + image = Image.open('./tmp/train_images_step800.png').convert("RGB") + output = temp_save([image],num_rows=1) + return output + +def launch_opt900(): + image = Image.open('./tmp/train_images_step900.png').convert("RGB") + output = temp_save([image],num_rows=1) + return output + +def launch_opt1000(): + image = Image.open('./tmp/train_images_step1000.png').convert("RGB") + output = temp_save([image],num_rows=1) + return output + +def launch_opt1100(): + image = Image.open('./tmp/train_images_step1100.png').convert("RGB") + output = temp_save([image],num_rows=1) + return output + +def launch_optimize(img_in_real, prompt, n_hiper): + + os.makedirs("tmp", exist_ok=True) + + + # Setting + accelerator = Accelerator( + gradient_accumulation_steps=1, + mixed_precision="fp16", + ) + + seed = 2220000 + set_seed(seed) + g_cuda = torch.Generator(device='cuda') + 
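+    # Seed the dedicated CUDA generator as well, so the preview images sampled
+    # during optimization are reproducible for a fixed seed value.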
g_cuda.manual_seed(seed) + + optimizer_class = torch.optim.Adam + weight_dtype = torch.float16 + + pretrained_model_name = 'CompVis/stable-diffusion-v1-4' + + # Load pretrained models + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer", use_auth_token=True) + CLIP_text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder", use_auth_token=True) + vae = AutoencoderKL.from_pretrained(pretrained_model_name, subfolder="vae", use_auth_token=True) + unet = UNet2DConditionModel.from_pretrained(pretrained_model_name, subfolder="unet", use_auth_token=True) + noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) + + + # Encode the input image. + vae.to(accelerator.device, dtype=weight_dtype) + input_image = img_in_real.convert("RGB") + img_in_real.save(os.path.join("tmp", "train_images_source.png")) + image_transforms = transforms.Compose( + [ + transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(512), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + init_image = image_transforms(input_image) + init_image = init_image[None].to(device=accelerator.device, dtype=weight_dtype) + with torch.inference_mode(): + init_latents = vae.encode(init_image).latent_dist.sample() + init_latents = 0.18215 * init_latents + + # Encode the source and target text. + CLIP_text_encoder.to(accelerator.device, dtype=weight_dtype) + text_ids_src = tokenizer(prompt,padding="max_length",truncation=True,max_length=tokenizer.model_max_length,return_tensors="pt").input_ids + text_ids_src = text_ids_src.to(device=accelerator.device) + with torch.inference_mode(): + source_embeddings = CLIP_text_encoder(text_ids_src)[0].float() + + + # del vae, CLIP_text_encoder + del vae, CLIP_text_encoder + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + + # For inference + ddim_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) + pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, scheduler=ddim_scheduler, torch_dtype=torch.float16).to("cuda") + num_samples = 1 + guidance_scale = 7.5 + num_inference_steps = 50 + height = 512 + width = 512 + + + # Optimize hiper embedding + n_hiper = int(n_hiper) + hiper_embeddings = source_embeddings[:,-n_hiper:].clone().detach() + src_embeddings = source_embeddings[:,:-n_hiper].clone().detach() + hiper_embeddings.requires_grad_(True) + + + optimizer = optimizer_class( + [hiper_embeddings], + lr=5e-3, + betas=(0.9, 0.999), + eps=1e-08, + ) + + unet, optimizer = accelerator.prepare(unet, optimizer) + + emb_train_steps = 1101 + # emb_train_steps = 201 + def train_loop(optimizer, hiper_embeddings): + inf_images=[] + for step in tqdm(range(emb_train_steps)): + with accelerator.accumulate(unet): + + noise = torch.randn_like(init_latents) + bsz = init_latents.shape[0] + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latents.device) + timesteps = timesteps.long() + + noisy_latents = noise_scheduler.add_noise(init_latents, noise, timesteps) + + source_embeddings = torch.cat([src_embeddings, hiper_embeddings], 1) + noise_pred = unet(noisy_latents, timesteps, source_embeddings).sample + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + # 
Check inference + if step in [800,900,1000,1100]: + inf_emb = torch.cat([src_embeddings, hiper_embeddings.clone().detach()], 1) + + with autocast("cuda"), torch.inference_mode(): + images = pipe(text_embeddings=inf_emb, height=height, width=width, num_images_per_prompt=num_samples, + num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=g_cuda).images + inf_images.append(images[0]) + images[0].save(os.path.join("tmp", "train_images_step{}.png".format(step))) + del images + + if step in [800,900,1000,1100]: + torch.save(hiper_embeddings.cpu(), os.path.join("tmp", "hiper_embeddings_step{}.pt".format(step))) + + + accelerator.wait_for_everyone() + + out_image = train_loop(optimizer, hiper_embeddings) + image = Image.open('./tmp/train_images_source.png').convert("RGB") + output = temp_save([image],num_rows=1) + + return "tmp", output + + + + + +def launch_main(dest, step, fpath_z_gen, seed): + seed = int(seed) + set_seed(seed) + g_cuda = torch.Generator(device='cuda') + g_cuda.manual_seed(seed) + + + # Load pretrained models + pretrained_model_name = 'CompVis/stable-diffusion-v1-4' + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) + pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, scheduler=scheduler, torch_dtype=torch.float16).to("cuda") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer", use_auth_token=True) + CLIP_text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder", use_auth_token=True) + + + # Encode the target text. + text_ids_tgt = tokenizer(dest, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt").input_ids + CLIP_text_encoder.to('cuda', dtype=torch.float32) + with torch.inference_mode(): + target_embedding = CLIP_text_encoder(text_ids_tgt.to('cuda'))[0].to('cuda') + del CLIP_text_encoder + + + # Concat target and hiper embeddings + step = int(step.replace("Step ","")) + hiper_embeddings = torch.load('./tmp/hiper_embeddings_step{}.pt'.format(step)).to("cuda") + n_hiper = hiper_embeddings.shape[1] + inference_embeddings =torch.cat([target_embedding[:, :-n_hiper], hiper_embeddings*0.8], 1) + + + # Generate target images + num_samples = 1 + guidance_scale = 7.5 + num_inference_steps = 50 + height = 512 + width = 512 + + with autocast("cuda"), torch.inference_mode(): + image_set = [] + for idx, embd in enumerate([inference_embeddings]): + for i in range(10): + images = pipe( + text_embeddings=embd, + height=height, + width=width, + num_images_per_prompt=num_samples, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=g_cuda + ).images + image_set.append(images[0]) + + out_image = temp_save(image_set,num_rows=5) + return out_image + + + + +def set_visible_true(): + return gr.update(visible=True) + +def set_visible_false(): + return gr.update(visible=False) + + +CSS_main = """ + body { + font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif; + font-weight:300; + font-size:18px; + margin-left: auto; + margin-right: auto; + padding-left: 10px; + padding-right: 10px; + width: 800px; + } + + h1 { + font-size:32px; + font-weight:300; + text-align: center; + } + + h2 { + font-size:32px; + font-weight:300; + text-align: center; + } + + #lbl_gallery_input{ + font-family: 'Helvetica', 'Arial', sans-serif; + text-align: center; + color: 
#fff; + font-size: 28px; + display: inline + } + + + #lbl_gallery_comparision{ + font-family: 'Helvetica', 'Arial', sans-serif; + text-align: center; + color: #fff; + font-size: 28px; + } + + .disclaimerbox { + background-color: #eee; + border: 1px solid #eeeeee; + border-radius: 10px ; + -moz-border-radius: 10px ; + -webkit-border-radius: 10px ; + padding: 20px; + } + + video.header-vid { + height: 140px; + border: 1px solid black; + border-radius: 10px ; + -moz-border-radius: 10px ; + -webkit-border-radius: 10px ; + } + + img.header-img { + height: 140px; + border: 1px solid black; + border-radius: 10px ; + -moz-border-radius: 10px ; + -webkit-border-radius: 10px ; + } + + img.rounded { + border: 1px solid #eeeeee; + border-radius: 10px ; + -moz-border-radius: 10px ; + -webkit-border-radius: 10px ; + } + + a:link + { + color: #941120; + text-decoration: none; + } + a:visited + { + color: #941120; + text-decoration: none; + } + a:hover { + color: #941120; + } + + td.dl-link { + height: 160px; + text-align: center; + font-size: 22px; + } + + .layered-paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ + box-shadow: + 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */ + 5px 5px 0 0px #fff, /* The second layer */ + 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */ + 10px 10px 0 0px #fff, /* The third layer */ + 10px 10px 1px 1px rgba(0,0,0,0.35), /* The third layer shadow */ + 15px 15px 0 0px #fff, /* The fourth layer */ + 15px 15px 1px 1px rgba(0,0,0,0.35), /* The fourth layer shadow */ + 20px 20px 0 0px #fff, /* The fifth layer */ + 20px 20px 1px 1px rgba(0,0,0,0.35), /* The fifth layer shadow */ + 25px 25px 0 0px #fff, /* The fifth layer */ + 25px 25px 1px 1px rgba(0,0,0,0.35); /* The fifth layer shadow */ + margin-left: 10px; + margin-right: 45px; + } + + .paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ + box-shadow: + 0px 0px 1px 1px rgba(0,0,0,0.35); /* The top layer shadow */ + + margin-left: 10px; + margin-right: 45px; + } + + + .layered-paper { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ + box-shadow: + 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */ + 5px 5px 0 0px #fff, /* The second layer */ + 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */ + 10px 10px 0 0px #fff, /* The third layer */ + 10px 10px 1px 1px rgba(0,0,0,0.35); /* The third layer shadow */ + margin-top: 5px; + margin-left: 10px; + margin-right: 30px; + margin-bottom: 5px; + } + + .vert-cent { + position: relative; + top: 50%; + transform: translateY(-50%); + } + + hr + { + border: 0; + height: 1px; + background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); + } + + .card { + /* width: 130px; + height: 195px; + width: 1px; + height: 1px; */ + position: relative; + display: inline-block; + /* margin: 50px; */ + } + .card .img-top { + display: none; + position: absolute; + top: 0; + left: 0; + z-index: 99; + } + .card:hover .img-top { + display: inline; + } + details { + user-select: none; + } + + details>summary span.icon { + width: 24px; + height: 24px; + transition: all 0.3s; + margin-left: auto; + } + + details[open] summary span.icon { + transform: rotate(180deg); + } + + summary { + display: flex; + cursor: pointer; + } + + summary::-webkit-details-marker { + display: none; + } + + ul { + display: table; + margin: 0 auto; + text-align: left; + } + + .dark { + padding: 1em 2em; + background-color: #333; + box-shadow: 3px 3px 3px 
#333; + border: 1px #333; + } + .column { + float: left; + width: 20%; + padding: 0.5%; + } + + .galleryImg { + transition: opacity 0.3s; + -webkit-transition: opacity 0.3s; + filter: grayscale(100%); + /* filter: blur(2px); */ + -webkit-transition : -webkit-filter 250ms linear; + /* opacity: 0.5; */ + cursor: pointer; + } + + + + .selected { + /* outline: 100px solid var(--hover-background) !important; */ + /* outline-offset: -100px; */ + filter: grayscale(0%); + -webkit-transition : -webkit-filter 250ms linear; + /*opacity: 1.0 !important; */ + } + + .galleryImg:hover { + filter: grayscale(0%); + -webkit-transition : -webkit-filter 250ms linear; + + } + + .row { + margin-bottom: 1em; + padding: 0px 1em; + } + /* Clear floats after the columns */ + .row:after { + content: ""; + display: table; + clear: both; + } + + /* The expanding image container */ + #gallery { + position: relative; + /*display: none;*/ + } + + #section_comparison{ + position: relative; + width: 100%; + height: max-content; + } + + /* SLIDER + -------------------------------------------------- */ + + .slider-container { + position: relative; + height: 384px; + width: 512px; + cursor: grab; + overflow: hidden; + margin: auto; + } + .slider-after { + display: block; + position: absolute; + top: 0; + right: 0; + bottom: 0; + left: 0; + width: 100%; + height: 100%; + overflow: hidden; + } + .slider-before { + display: block; + position: absolute; + top: 0; + /* right: 0; */ + bottom: 0; + left: 0; + width: 100%; + height: 100%; + z-index: 15; + overflow: hidden; + } + .slider-before-inset { + position: absolute; + top: 0; + bottom: 0; + left: 0; + } + .slider-after img, + .slider-before img { + object-fit: cover; + position: absolute; + width: 100%; + height: 100%; + object-position: 50% 50%; + top: 0; + bottom: 0; + left: 0; + -webkit-user-select: none; + -khtml-user-select: none; + -moz-user-select: none; + -o-user-select: none; + user-select: none; + } + + #lbl_inset_left{ + text-align: center; + position: absolute; + top: 384px; + width: 150px; + left: calc(50% - 256px); + z-index: 11; + font-size: 16px; + color: #fff; + margin: 10px; + } + .inset-before { + position: absolute; + width: 150px; + height: 150px; + box-shadow: 3px 3px 3px #333; + border: 1px #333; + border-style: solid; + z-index: 16; + top: 410px; + left: calc(50% - 256px); + margin: 10px; + font-size: 1em; + background-repeat: no-repeat; + pointer-events: none; + } + + #lbl_inset_right{ + text-align: center; + position: absolute; + top: 384px; + width: 150px; + right: calc(50% - 256px); + z-index: 11; + font-size: 16px; + color: #fff; + margin: 10px; + } + .inset-after { + position: absolute; + width: 150px; + height: 150px; + box-shadow: 3px 3px 3px #333; + border: 1px #333; + border-style: solid; + z-index: 16; + top: 410px; + right: calc(50% - 256px); + margin: 10px; + font-size: 1em; + background-repeat: no-repeat; + pointer-events: none; + } + + #lbl_inset_input{ + text-align: center; + position: absolute; + top: 384px; + width: 150px; + left: calc(50% - 256px + 150px + 20px); + z-index: 11; + font-size: 16px; + color: #fff; + margin: 10px; + } + .inset-target { + position: absolute; + width: 150px; + height: 150px; + box-shadow: 3px 3px 3px #333; + border: 1px #333; + border-style: solid; + z-index: 16; + top: 410px; + right: calc(50% - 256px + 150px + 20px); + margin: 10px; + font-size: 1em; + background-repeat: no-repeat; + pointer-events: none; + } + + .slider-beforePosition { + background: #121212; + color: #fff; + left: 0; + pointer-events: 
none; + border-radius: 0.2rem; + padding: 2px 10px; + } + .slider-afterPosition { + background: #121212; + color: #fff; + right: 0; + pointer-events: none; + border-radius: 0.2rem; + padding: 2px 10px; + } + .beforeLabel { + position: absolute; + top: 0; + margin: 1rem; + font-size: 1em; + -webkit-user-select: none; + -khtml-user-select: none; + -moz-user-select: none; + -o-user-select: none; + user-select: none; + } + .afterLabel { + position: absolute; + top: 0; + margin: 1rem; + font-size: 1em; + -webkit-user-select: none; + -khtml-user-select: none; + -moz-user-select: none; + -o-user-select: none; + user-select: none; + } + + .slider-handle { + height: 101px; + width: 41px; + position: absolute; + left: 50%; + top: 50%; + margin-left: -20px; + margin-top: -21px; + border: 2px solid #fff; + border-radius: 1000px; + z-index: 20; + pointer-events: none; + box-shadow: 0 0 10px rgb(12, 12, 12); + } + .handle-left-arrow, + .handle-right-arrow { + width: 0; + height: 0; + border: 6px inset transparent; + position: absolute; + top: 50%; + margin-top: -6px; + } + .handle-left-arrow { + border-right: 6px solid #fff; + left: 50%; + margin-left: -17px; + } + .handle-right-arrow { + border-left: 6px solid #fff; + right: 50%; + margin-right: -17px; + } + .slider-handle::before { + bottom: 50%; + margin-bottom: 20px; + box-shadow: 0 0 10px rgb(12, 12, 12); + } + .slider-handle::after { + top: 50%; + margin-top: 20.5px; + box-shadow: 0 0 5px rgb(12, 12, 12); + } + .slider-handle::before, + .slider-handle::after { + content: " "; + display: block; + width: 2px; + background: #fff; + height: 9999px; + position: absolute; + left: 50%; + margin-left: -1.5px; + } + + + /* + ------------------------------------------------- + The editing results shown below inversion results + ------------------------------------------------- + */ + .edit_labels{ + font-weight:500; + font-size: 24px; + color: #fff; + height: 20px; + margin-left: 20px; + position: relative; + top: 20px; + } + + .open > a:hover { + color: #555; + background-color: red; + } + + #directions { padding-top:30; padding-bottom:0; margin-bottom: 0px; height: 20px; } + #custom_task { padding-top:0; padding-bottom:0; margin-bottom: 0px; height: 20px; } + #slider_ddim {accent-color: #941120;} + #slider_ddim::-webkit-slider-thumb {background-color: #941120;} + #slider_xa {accent-color: #941120;} + #slider_xa::-webkit-slider-thumb {background-color: #941120;} + #slider_edit_mul {accent-color: #941120;} + #slider_edit_mul::-webkit-slider-thumb {background-color: #941120;} + + #input_image [data-testid="image"]{ + height: unset; + } + #input_image_synth [data-testid="image"]{ + height: unset; + } +""" + + +HTML_header = f""" + +
+<h1>Highly Personalized Text Embedding for Image Manipulation by Stable Diffusion</h1>
+<p style="text-align: center;">
+[Project page]
+[Github]
+</p>
+<p>
+We present a simple yet highly effective approach to personalization using highly personalized (HiPer) text embedding by decomposing the CLIP embedding space for personalization and content manipulation. Our method does not require model fine-tuning or identifiers, yet still enables manipulation of background, texture, and motion with just a single image and target text.
+</p>
+"""
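The abstract above maps to a small piece of tensor surgery in `launch_optimize`. Below is a minimal sketch of that decomposition, assuming the CLIP ViT-L/14 layout used by Stable Diffusion v1-4 (1 x 77 x 768); the random tensor stands in for the CLIP text-encoder output, and the placeholder loss stands in for the UNet noise-prediction MSE used in the real training loop:

```python
import torch

# CLIP ViT-L/14 layout used by Stable Diffusion v1-4: 77 token positions, 768 dims.
# n_hiper mirrors the slider in app.py (5 for animals, 10 for humans).
n_hiper = 5
source_embeddings = torch.randn(1, 77, 768)  # stand-in for CLIP_text_encoder(ids)[0]

# Freeze the leading positions; train only the trailing n_hiper positions.
src_embeddings = source_embeddings[:, :-n_hiper].detach()
hiper_embeddings = source_embeddings[:, -n_hiper:].clone().detach().requires_grad_(True)

optimizer = torch.optim.Adam([hiper_embeddings], lr=5e-3)

for step in range(3):  # stand-in for the 1101-step loop in launch_optimize
    cond = torch.cat([src_embeddings, hiper_embeddings], dim=1)  # back to (1, 77, 768)
    loss = cond.square().mean()  # placeholder for F.mse_loss(noise_pred, noise)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```

Only the trailing `n_hiper` positions receive gradients; the rest of the prompt embedding stays frozen, which is what keeps the per-image optimization cheap enough to run in a few minutes.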

+
+HTML_input_header = f"""
+<p style="font-size:20px;"><b>Step 1: select a real input image.</b></p>
+"""
+
+
+HTML_middle_header = f"""
+<p style="font-size:20px;"><b>Step 2: select the editing options.</b></p>
+"""
+
+
+HTML_output_header = f"""
+<p style="font-size:20px;"><b>Step 3: translated image!</b></p>
+"""
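On the generation side (`launch_main` above), the optimized tail is grafted onto the target prompt's embedding before sampling. A minimal sketch, with random tensors standing in for the encoded target prompt and for the checkpoint saved under `./tmp/hiper_embeddings_step{N}.pt`; the 0.8 attenuation mirrors the factor hard-coded in `launch_main`:

```python
import torch

# Stand-ins for CLIP_text_encoder(target_ids)[0] and for the hiper
# checkpoint written by launch_optimize.
target_embedding = torch.randn(1, 77, 768)
hiper_embeddings = torch.randn(1, 5, 768)

n_hiper = hiper_embeddings.shape[1]
# Keep the target prompt's leading positions; append the personalized tail
# at reduced strength (0.8, as in launch_main).
inference_embeddings = torch.cat(
    [target_embedding[:, :-n_hiper], hiper_embeddings * 0.8], dim=1
)
assert inference_embeddings.shape == (1, 77, 768)
```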
+""" + + +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +import cv2 +from typing import Optional, Union, Tuple, List, Callable, Dict +# from tqdm.notebook import tqdm + +#codes for 'show_image' and 'text_under_image' are from +# https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb + +def show_images(images, num_rows=2, offset_ratio=0.02): + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + pil_img = Image.fromarray(image_) + # pil_img.save(name) + return pil_img + + +def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + img[:h] = image + textsize = cv2.getTextSize(text, font, 1, 2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) + return img + + +def inf_save(inf_img, name): + images = [] + for i in range(len(inf_img)): + image = np.array(inf_img[i].resize((256,256))) + image = text_under_image(image, name[i]) + images.append(image) + inf_image = show_images(np.stack(images, axis=0),num_rows=1) + return inf_image + +def temp_save(inf_img,num_rows): + images = [] + for i in range(len(inf_img)): + image = np.array(inf_img[i].resize((256,256))) + images.append(image) + inf_image = show_images(np.stack(images, axis=0),num_rows=num_rows) + return inf_image \ No newline at end of file diff --git a/src_image/1.jpeg b/src_image/1.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..fb7f5c1adfe93758a446edfd6c9e6c92ff0b8f21 Binary files /dev/null and b/src_image/1.jpeg differ diff --git a/src_image/2.png b/src_image/2.png new file mode 100644 index 0000000000000000000000000000000000000000..889af81ab2f7118befd82fd682107a8bf37a8604 Binary files /dev/null and b/src_image/2.png differ diff --git a/src_image/bird.jpeg b/src_image/bird.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..98805f64f6a10f0e23baa17c3de85f6d81b172df Binary files /dev/null and b/src_image/bird.jpeg differ