nalin0503 committed on
Commit fc843fe
1 parent: 6bbe574

Add Image-Morpher as folder, without assets.

Files changed (44)
  1. Image-Morpher/LICENSE.txt +10 -0
  2. Image-Morpher/README.md +133 -0
  3. Image-Morpher/__pycache__/model.cpython-310.pyc +0 -0
  4. Image-Morpher/lcm_lora/lcm_accelerated_cat3.png +0 -0
  5. Image-Morpher/lcm_lora/lcm_lora_test.py +67 -0
  6. Image-Morpher/lcm_lora/lcm_schedule.py +654 -0
  7. Image-Morpher/logs/71seconds.log +14 -0
  8. Image-Morpher/logs/97 seconds legitimate.log +23 -0
  9. Image-Morpher/logs/SD2-1 compatible, confirm on main.py.log +10 -0
  10. Image-Morpher/logs/SD2-1,vae&atnnslicing,benchmark,lcm.log +11 -0
  11. Image-Morpher/logs/best config, matmul precision /"high/" ++.log +17 -0
  12. Image-Morpher/logs/execution_20250302_224812.log +9 -0
  13. Image-Morpher/logs/execution_20250302_225132.log +9 -0
  14. Image-Morpher/logs/execution_20250302_230601.log +9 -0
  15. Image-Morpher/logs/execution_20250302_231103.log +9 -0
  16. Image-Morpher/logs/execution_20250303_000722.log +11 -0
  17. Image-Morpher/logs/execution_20250303_150757.log +9 -0
  18. Image-Morpher/logs/execution_20250309_171159.log +12 -0
  19. Image-Morpher/logs/execution_20250309_172206.log +12 -0
  20. Image-Morpher/logs/execution_20250309_173119.log +12 -0
  21. Image-Morpher/logs/execution_20250309_173628.log +13 -0
  22. Image-Morpher/logs/execution_20250309_175711.log +1 -0
  23. Image-Morpher/logs/execution_20250309_184725.log +10 -0
  24. Image-Morpher/logs/execution_20250309_215757.log +15 -0
  25. Image-Morpher/logs/execution_20250309_220805.log +16 -0
  26. Image-Morpher/logs/execution_20250309_221759.log +26 -0
  27. Image-Morpher/logs/execution_20250309_224623.log +17 -0
  28. Image-Morpher/logs/execution_20250309_224915.log +10 -0
  29. Image-Morpher/logs/execution_20250309_225023.log +0 -0
  30. Image-Morpher/logs/execution_20250309_225432.log +0 -0
  31. Image-Morpher/logs/execution_20250309_232833.log +17 -0
  32. Image-Morpher/logs/execution_20250309_235119.log +16 -0
  33. Image-Morpher/logs/execution_20250309_235339.log +15 -0
  34. Image-Morpher/logs/slight_saving_cudnn_benchmark.log +9 -0
  35. Image-Morpher/logs/strange_test.log +12 -0
  36. Image-Morpher/logs/successful_memory_optimisation1.log +17 -0
  37. Image-Morpher/main.py +162 -0
  38. Image-Morpher/model.py +647 -0
  39. Image-Morpher/requirements_diffmorpher.txt +14 -0
  40. Image-Morpher/run.sh +21 -0
  41. Image-Morpher/utils/__init__.py +0 -0
  42. Image-Morpher/utils/alpha_scheduler.py +54 -0
  43. Image-Morpher/utils/lora_utils.py +284 -0
  44. Image-Morpher/utils/model_utils.py +87 -0
Image-Morpher/LICENSE.txt ADDED
@@ -0,0 +1,10 @@
+ S-Lab License 1.0
+
+ Copyright 2023 S-Lab
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
+
Image-Morpher/README.md ADDED
@@ -0,0 +1,133 @@
+ # to be edited.
+
+ <p align="center">
+ <h1 align="center">DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing</h1>
+ <h3 align="center">CVPR 2024</h3>
+ <p align="center">
+ <a href="https://kevin-thu.github.io/homepage/"><strong>Kaiwen Zhang</strong></a>
+ &nbsp;&nbsp;
+ <a href="https://zhouyifan.net/about/"><strong>Yifan Zhou</strong></a>
+ &nbsp;&nbsp;
+ <a href="https://sheldontsui.github.io/"><strong>Xudong Xu</strong></a>
+ &nbsp;&nbsp;
+ <a href="https://xingangpan.github.io/"><strong>Xingang Pan<sup>✉</sup></strong></a>
+ &nbsp;&nbsp;
+ <a href="http://daibo.info/"><strong>Bo Dai</strong></a>
+ </p>
+ <br>
+
+ <p align="center">
+ <sup>✉</sup>Corresponding Author
+ </p>
+
+ <div align="center">
+ <img src="./assets/teaser.gif" width="500">
+ </div>
+
+ <p align="center">
+ <a href="https://arxiv.org/abs/2312.07409"><img alt='arXiv' src="https://img.shields.io/badge/arXiv-2312.07409-b31b1b.svg"></a>
+ <a href="https://kevin-thu.github.io/DiffMorpher_page/"><img alt='page' src="https://img.shields.io/badge/Project-Website-orange"></a>
+ <a href="https://twitter.com/sze68zkw"><img alt='Twitter' src="https://img.shields.io/twitter/follow/sze68zkw?label=%40KaiwenZhang"></a>
+ <a href="https://twitter.com/XingangP"><img alt='Twitter' src="https://img.shields.io/twitter/follow/XingangP?label=%40XingangPan"></a>
+ </p>
+ <br>
+ </p>
+
+ ## Web Demos
+
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/KaiwenZhang/DiffMorpher)
+
+ <p align="left">
+ <a href="https://huggingface.co/spaces/Kevin-thu/DiffMorpher"><img alt="Huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DiffMorpher-orange"></a>
+ </p>
+
+ <!-- Great thanks to [OpenXLab](https://openxlab.org.cn/home) for the NVIDIA A100 GPU support! -->
+
+ ## Requirements
+ To install the requirements, run the following in your environment first:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ To run the code with CUDA properly, you can comment out `torch` and `torchvision` in `requirements.txt`, and install the appropriate versions of `torch` and `torchvision` according to the instructions on [PyTorch](https://pytorch.org/get-started/locally/).
+
+ You can also download the pretrained model *Stable Diffusion v2.1-base* from [Huggingface](https://huggingface.co/stabilityai/stable-diffusion-2-1-base), and point `model_path` to your local directory.
+
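For illustration, a CUDA 11.8 setup might look like the following (a sketch, not part of the committed README; pick the index URL that matches your CUDA toolkit, and the local directory name is only an example):

```bash
# Install torch/torchvision wheels built for your CUDA version
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
# Download the base model locally, then pass the directory as --model_path
huggingface-cli download stabilityai/stable-diffusion-2-1-base --local-dir ./stable-diffusion-2-1-base
```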
+ ## Run Gradio UI
+ To start the Gradio UI of DiffMorpher, run the following in your environment:
+ ```bash
+ python app.py
+ ```
+ Then, by default, you can access the UI at [http://127.0.0.1:7860](http://127.0.0.1:7860).
+
+ ## Run the code
+ You can also run the code with the following command:
+ ```bash
+ python main.py \
+  --image_path_0 [image_path_0] --image_path_1 [image_path_1] \
+  --prompt_0 [prompt_0] --prompt_1 [prompt_1] \
+  --output_path [output_path] \
+  --use_adain --use_reschedule --save_inter
+ ```
+ The script also supports the following options:
+
+ - `--image_path_0`: Path of the first image (default: "")
+ - `--prompt_0`: Prompt of the first image (default: "")
+ - `--image_path_1`: Path of the second image (default: "")
+ - `--prompt_1`: Prompt of the second image (default: "")
+ - `--model_path`: Pretrained model path (default: "stabilityai/stable-diffusion-2-1-base")
+ - `--output_path`: Path of the output image (default: "")
+ - `--save_lora_dir`: Path of the output LoRA directory (default: "./lora")
+ - `--load_lora_path_0`: Path of the LoRA directory of the first image (default: "")
+ - `--load_lora_path_1`: Path of the LoRA directory of the second image (default: "")
+ - `--use_adain`: Use AdaIN (default: False)
+ - `--use_reschedule`: Use reschedule sampling (default: False)
+ - `--lamb`: Hyperparameter $\lambda \in [0,1]$ for self-attention replacement, where a larger $\lambda$ indicates more replacements (default: 0.6)
+ - `--fix_lora_value`: Fix the LoRA value (default: LoRA interpolation, not fixed)
+ - `--save_inter`: Save intermediate results (default: False)
+ - `--num_frames`: Number of frames to generate (default: 50)
+ - `--duration`: Duration of each frame (default: 50)
+
+ Examples:
+ ```bash
+ python main.py \
+  --image_path_0 ./assets/Trump.jpg --image_path_1 ./assets/Biden.jpg \
+  --prompt_0 "A photo of an American man" --prompt_1 "A photo of an American man" \
+  --output_path "./results/Trump_Biden" \
+  --use_adain --use_reschedule --save_inter
+ ```
+
+ ```bash
+ python main.py \
+  --image_path_0 ./assets/vangogh.jpg --image_path_1 ./assets/pearlgirl.jpg \
+  --prompt_0 "An oil painting of a man" --prompt_1 "An oil painting of a woman" \
+  --output_path "./results/vangogh_pearlgirl" \
+  --use_adain --use_reschedule --save_inter
+ ```
+
+ ```bash
+ python main.py \
+  --image_path_0 ./assets/lion.png --image_path_1 ./assets/tiger.png \
+  --prompt_0 "A photo of a lion" --prompt_1 "A photo of a tiger" \
+  --output_path "./results/lion_tiger" \
+  --use_adain --use_reschedule --save_inter
+ ```
+
+ ## MorphBench
+ To evaluate the effectiveness of our methods, we present *MorphBench*, the first benchmark dataset for assessing image morphing of general objects. You can download the dataset from [Google Drive](https://drive.google.com/file/d/1NWPzJhOgP-udP_wYbd0selRG4cu8xsu4/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1J3xE3OJdEhKyoc1QObyYaA?pwd=putk).
+
+ ## License
+ The code related to the DiffMorpher algorithm is licensed under [LICENSE](LICENSE.txt).
+
+ However, this project is mostly built on the open-source library [diffusers](https://github.com/huggingface/diffusers), which is under separate license terms ([Apache License 2.0](https://github.com/huggingface/diffusers/blob/main/LICENSE)). (Cheers to the community as well!)
+
+ ## Citation
+
+ ```bibtex
+ @article{zhang2023diffmorpher,
+   title={DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing},
+   author={Zhang, Kaiwen and Zhou, Yifan and Xu, Xudong and Pan, Xingang and Dai, Bo},
+   journal={arXiv preprint arXiv:2312.07409},
+   year={2023}
+ }
+ ```
Image-Morpher/__pycache__/model.cpython-310.pyc ADDED
Binary file (13.7 kB)
 
Image-Morpher/lcm_lora/lcm_accelerated_cat3.png ADDED
Image-Morpher/lcm_lora/lcm_lora_test.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ # from diffusers import AutoPipelineForText2Image
+ from lcm_schedule import LCMScheduler
+ from diffusers import DiffusionPipeline
+
+ # Load the base model (Dreamshaper-7 is a fine-tuned SD 1.5 checkpoint and
+ # seems to produce better results than the vanilla weights)
+ pipe = DiffusionPipeline.from_pretrained(
+     # "runwayml/stable-diffusion-v1-5",
+     "Lykon/dreamshaper-7",
+     torch_dtype=torch.float16,
+     # variant="fp16",  # removed: torch_dtype already sets the precision
+ ).to("cuda")
+
+ # Swap in the LCMScheduler
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+ # Load the LCM LoRA for SD v1.5
+ pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+ # pipe.fuse_lora()  # this version of the Stable Diffusion pipeline does not expose fuse_lora
+
+ # Ensure the pipeline is on CUDA with the proper dtype
+ # pipe.to(device="cuda", dtype=torch.float16)  # removed: handled by .to("cuda") above
+
+ prompt = "A hyperrealistic portrait of a cat, highly detailed, cute pet"
+ negative_prompt = "blurry, low quality"
+ generator = torch.manual_seed(0)  # for reproducibility
+
+ # Run inference (using very few steps, as is typical for LCM-accelerated inference).
+ # The negative prompt works well for Dreamshaper; the best output so far used
+ # num_inference_steps=8 and guidance_scale=0 (guidance essentially disabled).
+ image = pipe(prompt,
+              negative_prompt=negative_prompt,
+              num_inference_steps=8,  # 8 seems to work best for both models, but worth tuning
+              guidance_scale=0,
+              generator=generator).images[0]
+
+ # Save the resulting image to disk
+ image.save("lcm_accelerated_cat3.png")
+
+ # print(pipe.scheduler)
+
+ # Earlier variant using AutoPipelineForText2Image with a fused LoRA:
+ # model_id = "Lykon/dreamshaper-7"
+ # adapter_id = "latent-consistency/lcm-lora-sdv1-5"
+
+ # pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
+ # pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+ # pipe.to("cuda")
+
+ # # load and fuse the lcm lora
+ # pipe.load_lora_weights(adapter_id)
+ # pipe.fuse_lora()
+
+ # prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+ # # disable guidance_scale by passing 0
+ # image = pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0).images[0]
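Since the logs below mostly compare wall-clock timings, a minimal timing harness for a pipeline call could look like this (a sketch, not part of the commit; `pipe` is the pipeline constructed in the script above):

```python
import time

import torch

def timed_call(pipe, prompt, **kwargs):
    # Synchronize around the call so queued GPU work is fully counted
    torch.cuda.synchronize()
    start = time.perf_counter()
    image = pipe(prompt, **kwargs).images[0]
    torch.cuda.synchronize()
    return image, time.perf_counter() - start

image, seconds = timed_call(pipe, "A hyperrealistic portrait of a cat",
                            num_inference_steps=8, guidance_scale=0)
print(f"inference took {seconds:.2f} s")
```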
Image-Morpher/lcm_lora/lcm_schedule.py ADDED
@@ -0,0 +1,654 @@
+ # Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+ # and https://github.com/hojonathanho/diffusion
+
+ # import math
+ # from dataclasses import dataclass
+ # from typing import List, Optional, Tuple, Union
+
+ # import numpy as np
+ # import torch
+
+ # from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput, logging
+ # from diffusers.utils.torch_utils import randn_tensor
+ # from diffusers.scheduling_utils import SchedulerMixin
+ from diffusers.schedulers.scheduling_ddim import *
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class LCMSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's `step` function output.
+
+     Args:
+         prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+         pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+             `pred_original_sample` can be used to preview progress or for guidance.
+     """
+
+     prev_sample: torch.Tensor
+     denoised: Optional[torch.Tensor] = None
+
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+ def betas_for_alpha_bar(
+     num_diffusion_timesteps,
+     max_beta=0.999,
+     alpha_transform_type="cosine",
+ ):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+     (1-beta) over time from t = [0,1].
+
+     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+     to that part of the diffusion process.
+
+     Args:
+         num_diffusion_timesteps (`int`): the number of betas to produce.
+         max_beta (`float`): the maximum beta to use; use values lower than 1 to
+             prevent singularities.
+         alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+             Choose from `cosine` or `exp`
+
+     Returns:
+         betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+     """
+     if alpha_transform_type == "cosine":
+
+         def alpha_bar_fn(t):
+             return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+     elif alpha_transform_type == "exp":
+
+         def alpha_bar_fn(t):
+             return math.exp(t * -12.0)
+
+     else:
+         raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+     return torch.tensor(betas, dtype=torch.float32)
+
+
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
+     """
+     Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+     Args:
+         betas (`torch.Tensor`):
+             the betas that the scheduler is being initialized with.
+
+     Returns:
+         `torch.Tensor`: rescaled betas with zero terminal SNR
+     """
+     # Convert betas to alphas_bar_sqrt
+     alphas = 1.0 - betas
+     alphas_cumprod = torch.cumprod(alphas, dim=0)
+     alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+     # Store old values.
+     alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+     alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+     # Shift so the last timestep is zero.
+     alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+     # Scale so the first timestep is back to the old value.
+     alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+     # Convert alphas_bar_sqrt to betas
+     alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+     alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+     alphas = torch.cat([alphas_bar[0:1], alphas])
+     betas = 1 - alphas
+
+     return betas
+
+
+ class LCMScheduler(SchedulerMixin, ConfigMixin):
+     """
+     `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+     non-Markovian guidance.
+
+     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
+     attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
+     accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
+     functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
+
+     Args:
+         num_train_timesteps (`int`, defaults to 1000):
+             The number of diffusion steps to train the model.
+         beta_start (`float`, defaults to 0.0001):
+             The starting `beta` value of inference.
+         beta_end (`float`, defaults to 0.02):
+             The final `beta` value.
+         beta_schedule (`str`, defaults to `"linear"`):
+             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+         trained_betas (`np.ndarray`, *optional*):
+             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+         original_inference_steps (`int`, *optional*, defaults to 50):
+             The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
+             will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
+         clip_sample (`bool`, defaults to `True`):
+             Clip the predicted sample for numerical stability.
+         clip_sample_range (`float`, defaults to 1.0):
+             The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+         set_alpha_to_one (`bool`, defaults to `True`):
+             Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+             there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+             otherwise it uses the alpha value at step 0.
+         steps_offset (`int`, defaults to 0):
+             An offset added to the inference steps, as required by some model families.
+         prediction_type (`str`, defaults to `epsilon`, *optional*):
+             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+             Video](https://imagen.research.google/video/paper.pdf) paper).
+         thresholding (`bool`, defaults to `False`):
+             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+             as Stable Diffusion.
+         dynamic_thresholding_ratio (`float`, defaults to 0.995):
+             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+         sample_max_value (`float`, defaults to 1.0):
+             The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+         timestep_spacing (`str`, defaults to `"leading"`):
+             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+         timestep_scaling (`float`, defaults to 10.0):
+             The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+             `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+             error at the default of `10.0` is already pretty small).
+         rescale_betas_zero_snr (`bool`, defaults to `False`):
+             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+             dark samples instead of limiting it to samples with medium brightness. Loosely related to
+             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+     """
+
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         beta_start: float = 0.00085,
+         beta_end: float = 0.012,
+         beta_schedule: str = "scaled_linear",
+         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+         original_inference_steps: int = 50,
+         clip_sample: bool = False,
+         clip_sample_range: float = 1.0,
+         set_alpha_to_one: bool = True,
+         steps_offset: int = 0,
+         prediction_type: str = "epsilon",
+         thresholding: bool = False,
+         dynamic_thresholding_ratio: float = 0.995,
+         sample_max_value: float = 1.0,
+         timestep_spacing: str = "leading",
+         timestep_scaling: float = 10.0,
+         rescale_betas_zero_snr: bool = False,
+     ):
+         if trained_betas is not None:
+             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+         elif beta_schedule == "linear":
+             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+         elif beta_schedule == "scaled_linear":
+             # this schedule is very specific to the latent diffusion model.
+             self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+         elif beta_schedule == "squaredcos_cap_v2":
+             # Glide cosine schedule
+             self.betas = betas_for_alpha_bar(num_train_timesteps)
+         else:
+             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+         # Rescale for zero SNR
+         if rescale_betas_zero_snr:
+             self.betas = rescale_zero_terminal_snr(self.betas)
+
+         self.alphas = 1.0 - self.betas
+         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+         # At every step in ddim, we are looking into the previous alphas_cumprod
+         # For the final step, there is no previous alphas_cumprod because we are already at 0
+         # `set_alpha_to_one` decides whether we set this parameter simply to one or
+         # whether we use the final alpha of the "non-previous" one.
+         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+         # standard deviation of the initial noise distribution
+         self.init_noise_sigma = 1.0
+
+         # setable values
+         self.num_inference_steps = None
+         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+         self.custom_timesteps = False
+
+         self._step_index = None
+         self._begin_index = None
+
+     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self.timesteps
+
+         indices = (schedule_timesteps == timestep).nonzero()
+
+         # The sigma index that is taken for the **very** first `step`
+         # is always the second index (or the last index if there is only 1)
+         # This way we can ensure we don't accidentally skip a sigma in
+         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+         pos = 1 if len(indices) > 1 else 0
+
+         return indices[pos].item()
+
+     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+     def _init_step_index(self, timestep):
+         if self.begin_index is None:
+             if isinstance(timestep, torch.Tensor):
+                 timestep = timestep.to(self.timesteps.device)
+             self._step_index = self.index_for_timestep(timestep)
+         else:
+             self._step_index = self._begin_index
+
+     @property
+     def step_index(self):
+         return self._step_index
+
+     @property
+     def begin_index(self):
+         """
+         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+         """
+         return self._begin_index
+
+     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+     def set_begin_index(self, begin_index: int = 0):
+         """
+         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+         Args:
+             begin_index (`int`):
+                 The begin index for the scheduler.
+         """
+         self._begin_index = begin_index
+
+     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+         """
+         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+         current timestep.
+
+         Args:
+             sample (`torch.Tensor`):
+                 The input sample.
+             timestep (`int`, *optional*):
+                 The current timestep in the diffusion chain.
+         Returns:
+             `torch.Tensor`:
+                 A scaled input sample.
+         """
+         return sample
+
+     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+     def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+         """
+         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+         s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+         pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+         photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+         https://arxiv.org/abs/2205.11487
+         """
+         dtype = sample.dtype
+         batch_size, channels, *remaining_dims = sample.shape
+
+         if dtype not in (torch.float32, torch.float64):
+             sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+         # Flatten sample for doing quantile calculation along each image
+         sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+         abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+         s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+         s = torch.clamp(
+             s, min=1, max=self.config.sample_max_value
+         )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+         sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+         sample = sample.reshape(batch_size, channels, *remaining_dims)
+         sample = sample.to(dtype)
+
+         return sample
+
+     def set_timesteps(
+         self,
+         num_inference_steps: Optional[int] = None,
+         device: Union[str, torch.device] = None,
+         original_inference_steps: Optional[int] = None,
+         timesteps: Optional[List[int]] = None,
+         strength: int = 1.0,
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+         Args:
+             num_inference_steps (`int`, *optional*):
+                 The number of diffusion steps used when generating samples with a pre-trained model. If used,
+                 `timesteps` must be `None`.
+             device (`str` or `torch.device`, *optional*):
+                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+             original_inference_steps (`int`, *optional*):
+                 The original number of inference steps, which will be used to generate a linearly-spaced timestep
+                 schedule (which is different from the standard `diffusers` implementation). We will then take
+                 `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
+                 our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
+             timesteps (`List[int]`, *optional*):
+                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                 timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
+                 schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
+         """
+         # 0. Check inputs
+         if num_inference_steps is None and timesteps is None:
+             raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
+
+         if num_inference_steps is not None and timesteps is not None:
+             raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+         # 1. Calculate the LCM original training/distillation timestep schedule.
+         original_steps = (
+             original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
+         )
+
+         if original_steps > self.config.num_train_timesteps:
+             raise ValueError(
+                 f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
+                 f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                 f" maximal {self.config.num_train_timesteps} timesteps."
+             )
+
+         # LCM Timesteps Setting
+         # The skipping step parameter k from the paper.
+         k = self.config.num_train_timesteps // original_steps
+         # LCM Training/Distillation Steps Schedule
+         # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
+         lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
+
+         # 2. Calculate the LCM inference timestep schedule.
+         if timesteps is not None:
+             # 2.1 Handle custom timestep schedules.
+             train_timesteps = set(lcm_origin_timesteps)
+             non_train_timesteps = []
+             for i in range(1, len(timesteps)):
+                 if timesteps[i] >= timesteps[i - 1]:
+                     raise ValueError("`custom_timesteps` must be in descending order.")
+
+                 if timesteps[i] not in train_timesteps:
+                     non_train_timesteps.append(timesteps[i])
+
+             if timesteps[0] >= self.config.num_train_timesteps:
+                 raise ValueError(
+                     f"`timesteps` must start before `self.config.train_timesteps`:"
+                     f" {self.config.num_train_timesteps}."
+                 )
+
+             # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
+             if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
+                 logger.warning(
+                     f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
+                     f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
+                     f" unexpected results when using this timestep schedule."
+                 )
+
+             # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
+             if non_train_timesteps:
+                 logger.warning(
+                     f"The custom timestep schedule contains the following timesteps which are not on the original"
+                     f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
+                     f" when using this timestep schedule."
+                 )
+
+             # Raise warning if custom timestep schedule is longer than original_steps
+             if len(timesteps) > original_steps:
+                 logger.warning(
+                     f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
+                     f" the length of the timestep schedule used for training: {original_steps}. You may get some"
+                     f" unexpected results when using this timestep schedule."
+                 )
+
+             timesteps = np.array(timesteps, dtype=np.int64)
+             self.num_inference_steps = len(timesteps)
+             self.custom_timesteps = True
+
+             # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
+             init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
+             t_start = max(self.num_inference_steps - init_timestep, 0)
+             timesteps = timesteps[t_start * self.order :]
+             # TODO: also reset self.num_inference_steps?
+         else:
+             # 2.2 Create the "standard" LCM inference timestep schedule.
+             if num_inference_steps > self.config.num_train_timesteps:
+                 raise ValueError(
+                     f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                     f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                     f" maximal {self.config.num_train_timesteps} timesteps."
+                 )
+
+             skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+
+             if skipping_step < 1:
+                 raise ValueError(
+                     f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
+                 )
+
+             self.num_inference_steps = num_inference_steps
+
+             if num_inference_steps > original_steps:
+                 raise ValueError(
+                     f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
+                     f" {original_steps} because the final timestep schedule will be a subset of the"
+                     f" `original_inference_steps`-sized initial timestep schedule."
+                 )
+
+             # LCM Inference Steps Schedule
+             lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
+             # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
+             inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
+             inference_indices = np.floor(inference_indices).astype(np.int64)
+             timesteps = lcm_origin_timesteps[inference_indices]
+
+         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
+
+         self._step_index = None
+         self._begin_index = None
+
+     def get_scalings_for_boundary_condition_discrete(self, timestep):
+         self.sigma_data = 0.5  # Default: 0.5
+         scaled_timestep = timestep * self.config.timestep_scaling
+
+         c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+         c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
+         return c_skip, c_out
+
+     def step(
+         self,
+         model_output: torch.Tensor,
+         timestep: int,
+         sample: torch.Tensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[LCMSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.Tensor`):
+                 The direct output from learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.Tensor`):
+                 A current instance of a sample created by the diffusion process.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+         Returns:
+             [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+                 If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+                 tuple is returned where the first element is the sample tensor.
+         """
+         if self.num_inference_steps is None:
+             raise ValueError(
+                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+             )
+
+         if self.step_index is None:
+             self._init_step_index(timestep)
+
+         # 1. get previous step value
+         prev_step_index = self.step_index + 1
+         if prev_step_index < len(self.timesteps):
+             prev_timestep = self.timesteps[prev_step_index]
+         else:
+             prev_timestep = timestep
+
+         # 2. compute alphas, betas
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         # 3. Get scalings for boundary conditions
+         c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+         # 4. Compute the predicted original sample x_0 based on the model parameterization
+         if self.config.prediction_type == "epsilon":  # noise-prediction
+             predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+         elif self.config.prediction_type == "sample":  # x-prediction
+             predicted_original_sample = model_output
+         elif self.config.prediction_type == "v_prediction":  # v-prediction
+             predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+         else:
+             raise ValueError(
+                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+                 " `v_prediction` for `LCMScheduler`."
+             )
+
+         # 5. Clip or threshold "predicted x_0"
+         if self.config.thresholding:
+             predicted_original_sample = self._threshold_sample(predicted_original_sample)
+         elif self.config.clip_sample:
+             predicted_original_sample = predicted_original_sample.clamp(
+                 -self.config.clip_sample_range, self.config.clip_sample_range
+             )
+
+         # 6. Denoise model output using boundary conditions
+         denoised = c_out * predicted_original_sample + c_skip * sample
+
+         # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
+         # Noise is not used on the final timestep of the timestep schedule.
+         # This also means that noise is not used for one-step sampling.
+         if self.step_index != self.num_inference_steps - 1:
+             noise = randn_tensor(
+                 model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+             )
+             prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+         else:
+             prev_sample = denoised
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample, denoised)
+
+         return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
+
+     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+     def add_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.IntTensor,
+     ) -> torch.Tensor:
+         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+         # for the subsequent add_noise calls
+         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+         alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+         timesteps = timesteps.to(original_samples.device)
+
+         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+         return noisy_samples
+
+     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+         # Make sure alphas_cumprod and timestep have same device and dtype as sample
+         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
+         timesteps = timesteps.to(sample.device)
+
+         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+         while len(sqrt_alpha_prod.shape) < len(sample.shape):
+             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+         while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+         return velocity
+
+     def __len__(self):
+         return self.config.num_train_timesteps
+
+     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
+     def previous_timestep(self, timestep):
+         if self.custom_timesteps or self.num_inference_steps:
+             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+             if index == self.timesteps.shape[0] - 1:
+                 prev_t = torch.tensor(-1)
+             else:
+                 prev_t = self.timesteps[index + 1]
+         else:
+             prev_t = timestep - 1
+         return prev_t
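The subtle part of `set_timesteps` above is the two-stage selection: it first builds the training/distillation schedule `k*i - 1` with `k = num_train_timesteps // original_inference_steps`, then keeps `num_inference_steps` (approximately) evenly spaced entries of its reversal. A quick standalone check (a sketch using this module's defaults):

```python
from lcm_schedule import LCMScheduler

scheduler = LCMScheduler()  # defaults: 1000 training timesteps, original_inference_steps=50
scheduler.set_timesteps(num_inference_steps=8)
# k = 1000 // 50 = 20, so the distillation schedule is 19, 39, ..., 999;
# reversing it and taking 8 evenly spaced entries gives:
print(scheduler.timesteps)  # tensor([999, 879, 759, 639, 499, 379, 259, 139])
```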
Image-Morpher/logs/71seconds.log ADDED
@@ -0,0 +1,14 @@
+ 2025-03-09 17:59:17,957 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
+ 2025-03-09 18:00:29,214 - INFO - Total execution time: 71.52 seconds
+ 2025-03-09 18:00:29,215 - INFO - Model Path: DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
+ 2025-03-09 18:00:29,215 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 18:00:29,215 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 18:00:29,215 - INFO - Use LCM: True
+ 2025-03-09 18:00:29,215 - INFO - Number of inference steps: 8
+ 2025-03-09 18:00:29,215 - INFO - Guidance scale: 1
+
+ 71 seconds with standard 2-1, with VAE/attention slicing and cudnn benchmark enabled.
+
+ So lcm-lora works with the stabilityai/sd2-1 model card?
+
+ I think the LoRA for this run was already partially trained, so it was picked up and only the remainder was trained, which made the run quicker.
Image-Morpher/logs/97 seconds legitimate.log ADDED
@@ -0,0 +1,23 @@
+ 2025-03-09 23:23:30,516 - INFO - Total execution time: 97.10 seconds
+ 2025-03-09 23:23:30,516 - INFO - Model Path: stabilityai/stable-diffusion-2-1-base
+ 2025-03-09 23:23:30,516 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 23:23:30,516 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 23:23:30,516 - INFO - Use LCM: True
+ 2025-03-09 23:23:30,516 - INFO - Number of inference steps: 8
+ 2025-03-09 23:23:30,516 - INFO - Guidance scale: 1
+
+ So: first time under 100 seconds.
+
+ This was with
+ # Configure compiler settings (safe defaults)
+ inductor.config.force_fuse_int_mm_with_mul = True  # Better for diffusion models
+
+ and adding
+ # These go AFTER device movement, instead of before, so convolutions are optimised for the specific GPU / CUDA kernels
+ torch.backends.cudnn.benchmark = True
+ torch.set_float32_matmul_precision("high")  # Better for modern GPUs
+
+ In all, I was able to shave off about 20-30 seconds from these optimisations; the lcm-lora saved about a minute (?), so roughly 1.5 minutes saved in total.
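Putting the notes above together, the configuration could look like this (a sketch, not the exact main.py code; `inductor` refers to `torch._inductor.config`, and the placement after `.to("cuda")` follows the comment above):

```python
import torch
import torch._inductor.config as inductor
from diffusers import DiffusionPipeline

# Compiler setting from the log above (a safe default for diffusion models)
inductor.force_fuse_int_mm_with_mul = True

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
).to("cuda")

# Set after device movement so cuDNN autotunes convolutions for this specific GPU
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")  # allow TF32-style float32 matmuls
```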
Image-Morpher/logs/SD2-1 compatible, confirm on main.py.log ADDED
@@ -0,0 +1,10 @@
+ 2025-03-09 18:44:19,754 - INFO - Total execution time: 104.91 seconds
+ 2025-03-09 18:44:19,754 - INFO - Model Path: stabilityai/stable-diffusion-2-1-base
+ 2025-03-09 18:44:19,754 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 18:44:19,755 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 18:44:19,755 - INFO - Use LCM: True
+ 2025-03-09 18:44:19,755 - INFO - Number of inference steps: 8
+ 2025-03-09 18:44:19,755 - INFO - Guidance scale: 1
+
+ Confirms that SD 2-1 is faster than SD v1-5 in training; saved about 15-20 seconds here.
Image-Morpher/logs/SD2-1,vae&atnnslicing,benchmark,lcm.log ADDED
@@ -0,0 +1,11 @@
+ 2025-03-09 18:05:43,787 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
+ 2025-03-09 18:07:28,182 - INFO - Total execution time: 104.66 seconds
+ 2025-03-09 18:07:28,182 - INFO - Model Path: DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
+ 2025-03-09 18:07:28,182 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 18:07:28,182 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 18:07:28,182 - INFO - Use LCM: True
+ 2025-03-09 18:07:28,182 - INFO - Number of inference steps: 8
+ 2025-03-09 18:07:28,182 - INFO - Guidance scale: 1
+
+ Yes, it did switch to 2-1, which is compatible with the 1-5 lcm-lora; VAE and attention slicing were ENABLED.
+ Note that cudnn benchmark was also ENABLED.
Image-Morpher/logs/best config, matmul precision /"high/" ++.log ADDED
@@ -0,0 +1,17 @@
+ 2025-03-09 23:57:53,227 - INFO - Total execution time: 97.25 seconds
+ 2025-03-09 23:57:53,227 - INFO - Model Path: stabilityai/stable-diffusion-2-1-base
+ 2025-03-09 23:57:53,227 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 23:57:53,227 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 23:57:53,227 - INFO - Use LCM: True
+ 2025-03-09 23:57:53,228 - INFO - Number of inference steps: 8
+ 2025-03-09 23:57:53,228 - INFO - Guidance scale: 1
+
+ Best result so far; this was with
+ torch.set_float32_matmul_precision("high")  # Better for modern GPUs
+
+ It instructs CUDA not to be so precise in intermediate computations: approximations and rounding (e.g. TF32 matmuls) that are not crucial in a
+ deep learning workload. The stricter default ("highest") keeps the extra steps that guarantee full float32 precision.
+
+ And this was particularly useful after moving pipeline.to("cuda").
Image-Morpher/logs/execution_20250302_224812.log ADDED
@@ -0,0 +1,9 @@
+ 2025-03-02 22:50:16,961 - INFO - Total execution time: 124.71 seconds
+ 2025-03-02 22:50:16,961 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-02 22:50:16,961 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-02 22:50:16,961 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-02 22:50:16,962 - INFO - Number of inference steps: 8
+ 2025-03-02 22:50:16,962 - INFO - Guidance scale: 1
+ 2025-03-02 22:50:16,962 - INFO - Use LCM: True
+
+ base
Image-Morpher/logs/execution_20250302_225132.log ADDED
@@ -0,0 +1,9 @@
+ 2025-03-02 22:53:49,437 - INFO - Total execution time: 136.99 seconds
+ 2025-03-02 22:53:49,438 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-02 22:53:49,438 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-02 22:53:49,438 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-02 22:53:49,438 - INFO - Number of inference steps: 8
+ 2025-03-02 22:53:49,438 - INFO - Guidance scale: 1
+ 2025-03-02 22:53:49,438 - INFO - Use LCM: True
+
+ cudnn benchmark and float32 matmul precision
Image-Morpher/logs/execution_20250302_230601.log ADDED
@@ -0,0 +1,9 @@
+ 2025-03-02 23:08:10,022 - INFO - Total execution time: 128.30 seconds
+ 2025-03-02 23:08:10,022 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-02 23:08:10,022 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-02 23:08:10,022 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-02 23:08:10,022 - INFO - Number of inference steps: 8
+ 2025-03-02 23:08:10,022 - INFO - Guidance scale: 1
+ 2025-03-02 23:08:10,022 - INFO - Use LCM: True
+
+ pipeline memory-format changes, bad results
Image-Morpher/logs/execution_20250302_231103.log ADDED
@@ -0,0 +1,9 @@
+ 2025-03-02 23:13:22,730 - INFO - Total execution time: 139.31 seconds
+ 2025-03-02 23:13:22,758 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-02 23:13:22,758 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-02 23:13:22,758 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-02 23:13:22,758 - INFO - Number of inference steps: 8
+ 2025-03-02 23:13:22,758 - INFO - Guidance scale: 1
+ 2025-03-02 23:13:22,758 - INFO - Use LCM: True
+
+ cudnn benchmark... doesn't help either
Image-Morpher/logs/execution_20250303_000722.log ADDED
@@ -0,0 +1,11 @@
+ 2025-03-03 00:08:10,065 - INFO - Total execution time: 47.30 seconds
+ 2025-03-03 00:08:10,065 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-03 00:08:10,065 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-03 00:08:10,065 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-03 00:08:10,065 - INFO - Use LCM: True
+ 2025-03-03 00:08:10,065 - INFO - Number of inference steps: 8
+ 2025-03-03 00:08:10,065 - INFO - Guidance scale: 1
+
+ This was with the other torch flags, but I think the LoRAs were already loaded in.
Image-Morpher/logs/execution_20250303_150757.log ADDED
@@ -0,0 +1,9 @@
+ 2025-03-03 15:10:09,361 - INFO - Total execution time: 131.64 seconds
+ 2025-03-03 15:10:09,362 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-03 15:10:09,362 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-03 15:10:09,362 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-03 15:10:09,362 - INFO - Use LCM: True
+ 2025-03-03 15:10:09,362 - INFO - Number of inference steps: 8
+ 2025-03-03 15:10:09,362 - INFO - Guidance scale: 1
+
+ This was with the flags; no difference.
Image-Morpher/logs/execution_20250309_171159.log ADDED
@@ -0,0 +1,12 @@
+ 2025-03-09 17:14:01,299 - INFO - Total execution time: 121.91 seconds
+ 2025-03-09 17:14:01,299 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-09 17:14:01,299 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 17:14:01,299 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 17:14:01,299 - INFO - Use LCM: True
+ 2025-03-09 17:14:01,299 - INFO - Number of inference steps: 8
+ 2025-03-09 17:14:01,299 - INFO - Guidance scale: 1
+
+ pipeline.enable_attention_slicing(1)
+
+ Extreme memory saving; not required, to be honest.
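For reference, the memory-saving toggles on a diffusers pipeline look like this (a sketch; `enable_attention_slicing(1)` is the extreme setting tried above, and `enable_vae_slicing` is the VAE-side counterpart mentioned in the other logs):

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "sd-legacy/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Compute attention one slice at a time: minimal memory, but slower
pipeline.enable_attention_slicing(1)
# Decode the VAE in slices as well (cheap to enable)
pipeline.enable_vae_slicing()
```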
Image-Morpher/logs/execution_20250309_172206.log ADDED
@@ -0,0 +1,12 @@
+ 2025-03-09 17:28:03,402 - INFO - Total execution time: 357.21 seconds
+ 2025-03-09 17:28:03,402 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-09 17:28:03,402 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 17:28:03,402 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 17:28:03,403 - INFO - Use LCM: True
+ 2025-03-09 17:28:03,403 - INFO - Number of inference steps: 8
+ 2025-03-09 17:28:03,403 - INFO - Guidance scale: 1
+
+ pipeline.unet = torch.compile(pipeline.unet)  # Optimize further
+ torch.cuda.empty_cache()
+
+ Applied on the args.use_lcm path.
Image-Morpher/logs/execution_20250309_173119.log ADDED
@@ -0,0 +1,12 @@
+ 2025-03-09 17:35:58,495 - INFO - Total execution time: 279.30 seconds
+ 2025-03-09 17:35:58,495 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-09 17:35:58,495 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 17:35:58,495 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 17:35:58,495 - INFO - Use LCM: True
+ 2025-03-09 17:35:58,495 - INFO - Number of inference steps: 8
+ 2025-03-09 17:35:58,496 - INFO - Guidance scale: 1
+
+ pipeline.unet = torch.compile(pipeline.unet)  # Optimize further
+ torch.cuda.empty_cache()
+
+ Applied after the initial pipeline (DiffMorpherPipeline) was loaded in.
Image-Morpher/logs/execution_20250309_173628.log ADDED
@@ -0,0 +1,13 @@
+ 2025-03-09 17:41:02,525 - INFO - Total execution time: 273.81 seconds
+ 2025-03-09 17:41:02,526 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-09 17:41:02,526 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 17:41:02,526 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 17:41:02,526 - INFO - Use LCM: True
+ 2025-03-09 17:41:02,526 - INFO - Number of inference steps: 8
+ 2025-03-09 17:41:02,526 - INFO - Guidance scale: 1
+
+
+ pipeline.unet = torch.compile(pipeline.unet)  # Optimize further
+ # torch.cuda.empty_cache()
+
+ Without empty_cache: no help.
Image-Morpher/logs/execution_20250309_175711.log ADDED
@@ -0,0 +1 @@
+ 2025-03-09 17:57:11,357 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
Image-Morpher/logs/execution_20250309_184725.log ADDED
@@ -0,0 +1,10 @@
+ 2025-03-09 18:49:26,905 - INFO - Total execution time: 121.72 seconds
+ 2025-03-09 18:49:26,906 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-09 18:49:26,906 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 18:49:26,906 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 18:49:26,906 - INFO - Use LCM: True
+ 2025-03-09 18:49:26,906 - INFO - Number of inference steps: 8
+ 2025-03-09 18:49:26,906 - INFO - Guidance scale: 1
+
+
+ Confirmed: SD v1-5 is slower and gives worse results (Biden is missing his teeth).
Image-Morpher/logs/execution_20250309_215757.log ADDED
@@ -0,0 +1,15 @@
+ 2025-03-09 21:57:58,103 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
+ 2025-03-09 22:01:16,108 - INFO - Total execution time: 198.35 seconds
+ 2025-03-09 22:01:16,108 - INFO - Model Path: DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
+ 2025-03-09 22:01:16,108 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 22:01:16,108 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 22:01:16,108 - INFO - Use LCM: False
+ 2025-03-09 22:01:16,108 - INFO - Number of inference steps: 50
+ 2025-03-09 22:01:16,108 - INFO - Guidance scale: 1
+
+
+ This was with LCM disabled and the following inductor flag optimisations enabled:
+ inductor.config.conv_1x1_as_mm = True
+ inductor.config.coordinate_descent_tuning = True
+ inductor.config.epilogue_fusion = False
+ inductor.config.coordinate_descent_check_all_directions = True
Image-Morpher/logs/execution_20250309_220805.log ADDED
@@ -0,0 +1,16 @@
+ 2025-03-09 22:08:06,169 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
+ 2025-03-09 22:11:24,031 - INFO - Total execution time: 198.23 seconds
+ 2025-03-09 22:11:24,031 - INFO - Model Path: DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
+ 2025-03-09 22:11:24,031 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 22:11:24,031 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 22:11:24,031 - INFO - Use LCM: False
+ 2025-03-09 22:11:24,031 - INFO - Number of inference steps: 50
+ 2025-03-09 22:11:24,031 - INFO - Guidance scale: 1
+
+
+ inductor.config.conv_1x1_as_mm = True
+ inductor.config.coordinate_descent_tuning = True
+ inductor.config.epilogue_fusion = False
+ inductor.config.coordinate_descent_check_all_directions = True
+
+ These make no difference at all.
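For context, a sketch of how these flags are typically enabled (assumed import path; they only take effect once the UNet is actually wrapped in torch.compile, which may explain the null result here):

import torch._inductor as inductor

inductor.config.conv_1x1_as_mm = True
inductor.config.coordinate_descent_tuning = True
inductor.config.epilogue_fusion = False
inductor.config.coordinate_descent_check_all_directions = True
pipeline.unet = torch.compile(pipeline.unet, mode="max-autotune", fullgraph=True)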
Image-Morpher/logs/execution_20250309_221759.log ADDED
@@ -0,0 +1,26 @@
+ 2025-03-09 22:17:59,860 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
+
+ torch.compile() doesn't really work here; the LoRA training becomes very slow.
+
+ Change:
+
+ def compile_components(pipe):
+     # Compile with maximum optimization
+     pipe.unet = torch.compile(
+         pipe.unet,
+         mode="max-autotune",
+         fullgraph=True,
+         dynamic=False
+     )
+
+     # Special compilation for VAE decoder
+     pipe.vae.decode = torch.compile(
+         pipe.vae.decode,
+         mode="reduce-overhead",
+         fullgraph=True
+     )
+
+     return pipe
+
+ if not args.use_lcm:  # LCM may conflict with full compilation
+     pipeline = compile_components(pipeline)
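A plausible reason torch.compile does not pay off in this pipeline: model.py swaps attention processors and re-blends LoRA weights for every frame, which invalidates the compiled graph and forces repeated recompilation. One way to check the suspicion (assumes torch >= 2.1; a diagnostic sketch, not part of this run):

import torch._logging
torch._logging.set_logs(recompiles=True)  # prints a reason each time dynamo recompiles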
Image-Morpher/logs/execution_20250309_224623.log ADDED
@@ -0,0 +1,17 @@
+ 2025-03-09 22:46:23,911 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
+ 2025-03-09 22:48:08,545 - INFO - Total execution time: 104.98 seconds
+ 2025-03-09 22:48:08,545 - INFO - Model Path: DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
+ 2025-03-09 22:48:08,545 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 22:48:08,545 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 22:48:08,545 - INFO - Use LCM: True
+ 2025-03-09 22:48:08,545 - INFO - Number of inference steps: 8
+ 2025-03-09 22:48:08,545 - INFO - Guidance scale: 1
+
+
+ Trying some LCM-LoRA scheduler optimisations; this does not seem to help much. It was with:
+
+ pipeline.scheduler = LCMScheduler.from_config(
+     pipeline.scheduler.config,
+     timestep_spacing="trailing",  # Optimized for LCM
+     prediction_type="epsilon"     # this is the default anyway
+ )
Image-Morpher/logs/execution_20250309_224915.log ADDED
@@ -0,0 +1,10 @@
+ 2025-03-09 22:49:15,959 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main', switching to standard 2-1 base.
+ 2025-03-09 22:49:50,993 - INFO - Total execution time: 35.29 seconds
+ 2025-03-09 22:49:50,993 - INFO - Model Path: DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
+ 2025-03-09 22:49:50,993 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 22:49:50,993 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 22:49:50,993 - INFO - Use LCM: True
+ 2025-03-09 22:49:50,993 - INFO - Number of inference steps: 8
+ 2025-03-09 22:49:50,993 - INFO - Guidance scale: 1
+
+ Fails with prediction_type="v_prediction".
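That failure is expected: the 512px SD 2.1-base checkpoint is an epsilon-prediction model (only the 768px SD 2.1 uses v-prediction), so forcing v_prediction mismatches the UNet's training objective. A quick check:

print(pipeline.scheduler.config.prediction_type)  # "epsilon" for stabilityai/stable-diffusion-2-1-base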
Image-Morpher/logs/execution_20250309_225023.log ADDED
Binary file (747 Bytes)
Image-Morpher/logs/execution_20250309_225432.log ADDED
Binary file (1.11 kB)
Image-Morpher/logs/execution_20250309_232833.log ADDED
@@ -0,0 +1,17 @@
+ 2025-03-09 23:30:11,054 - INFO - Total execution time: 97.80 seconds
+ 2025-03-09 23:30:11,054 - INFO - Model Path: stabilityai/stable-diffusion-2-1-base
+ 2025-03-09 23:30:11,054 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 23:30:11,054 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 23:30:11,055 - INFO - Use LCM: True
+ 2025-03-09 23:30:11,055 - INFO - Number of inference steps: 8
+ 2025-03-09 23:30:11,055 - INFO - Guidance scale: 1
+
+
+ This was without the following:
+
+ # Configure compiler settings (safe defaults)
+ # inductor.config.force_fuse_int_mm_with_mul = True  # Better for diffusion models
+
+ It was only 0.7 seconds off, so the real benefit (from cudnn benchmark OR, I now realise, from set_float32_matmul_precision) was
+ shifting the cudnn benchmark flag to AFTER pipeline.to("cuda").
+
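A minimal sketch of the ordering that ended up mattering (this mirrors what main.py below now does):

pipeline.to("cuda")
torch.backends.cudnn.benchmark = True       # set after the model is on the GPU
torch.set_float32_matmul_precision("high")  # allows TF32 matmuls on Ampere and newer GPUs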
Image-Morpher/logs/execution_20250309_235119.log ADDED
@@ -0,0 +1,16 @@
+ 2025-03-09 23:53:03,895 - INFO - Total execution time: 104.19 seconds
+ 2025-03-09 23:53:03,895 - INFO - Model Path: stabilityai/stable-diffusion-2-1-base
+ 2025-03-09 23:53:03,896 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 23:53:03,896 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 23:53:03,896 - INFO - Use LCM: True
+ 2025-03-09 23:53:03,896 - INFO - Number of inference steps: 8
+ 2025-03-09 23:53:03,896 - INFO - Guidance scale: 1
+
+
+ This was the base run, with the following commented OUT:
+
+ torch.backends.cudnn.benchmark = True
+ torch.set_float32_matmul_precision("high")  # Better for modern GPUs
+
+ Hence the flags above do make a difference.
+
Image-Morpher/logs/execution_20250309_235339.log ADDED
@@ -0,0 +1,15 @@
+ 2025-03-09 23:55:24,160 - INFO - Total execution time: 104.42 seconds
+ 2025-03-09 23:55:24,160 - INFO - Model Path: stabilityai/stable-diffusion-2-1-base
+ 2025-03-09 23:55:24,160 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 23:55:24,160 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 23:55:24,160 - INFO - Use LCM: True
+ 2025-03-09 23:55:24,160 - INFO - Number of inference steps: 8
+ 2025-03-09 23:55:24,160 - INFO - Guidance scale: 1
+
+
+ This was with only:
+
+ torch.backends.cudnn.benchmark = True
+
+ Not much speed-up! In fact it is a bit slower (sometimes a tiny bit faster).
+
Image-Morpher/logs/slight_saving_cudnn_benchmark.log ADDED
@@ -0,0 +1,9 @@
+ 2025-03-02 22:57:31,314 - INFO - Total execution time: 123.84 seconds
+ 2025-03-02 22:57:31,314 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-02 22:57:31,314 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-02 22:57:31,314 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-02 22:57:31,315 - INFO - Number of inference steps: 8
+ 2025-03-02 22:57:31,315 - INFO - Guidance scale: 1
+ 2025-03-02 22:57:31,315 - INFO - Use LCM: True
+
+ cudnn benchmark enabled.
Image-Morpher/logs/strange_test.log ADDED
@@ -0,0 +1,12 @@
+ 2025-03-09 17:45:32,554 - ERROR - Model loading failed: [Errno 2] No such file or directory: '/mnt/slurm_home/nalin/.cache/huggingface/hub/models--DarkFlameUniverse--Stable-Diffusion-2-1-Base-8bit/refs/main'
+ 2025-03-09 17:47:22,395 - INFO - Total execution time: 110.10 seconds
+ 2025-03-09 17:47:22,396 - INFO - Model Path: DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit  # I think this was NOT the model path that was actually used.
+ 2025-03-09 17:47:22,396 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 17:47:22,396 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 17:47:22,396 - INFO - Use LCM: True
+ 2025-03-09 17:47:22,396 - INFO - Number of inference steps: 8
+ 2025-03-09 17:47:22,396 - INFO - Guidance scale: 1
+
+
+ So it worked with the quantized 8-bit model (no, it did not)! It could not be loaded at first; I assumed it then used the local cache from an earlier failed run (it did not).
+
Image-Morpher/logs/successful_memory_optimisation1.log ADDED
@@ -0,0 +1,17 @@
+ 2025-03-09 17:01:38,747 - INFO - Total execution time: 121.97 seconds
+ 2025-03-09 17:01:38,747 - INFO - Model Path: sd-legacy/stable-diffusion-v1-5
+ 2025-03-09 17:01:38,747 - INFO - Image Path 0: ./assets/Trump.jpg
+ 2025-03-09 17:01:38,747 - INFO - Image Path 1: ./assets/Biden.jpg
+ 2025-03-09 17:01:38,747 - INFO - Use LCM: True
+ 2025-03-09 17:01:38,747 - INFO - Number of inference steps: 8
+ 2025-03-09 17:01:38,747 - INFO - Guidance scale: 1
+
+
+ Fastest one yet, with the optimisations:
+
+ pipeline.enable_vae_slicing()
+ pipeline.enable_attention_slicing()
+
+ About 5-8 seconds saved on average.
+
+ And the cudnn benchmark MAYBE helps, but not really.
Image-Morpher/main.py ADDED
@@ -0,0 +1,162 @@
+ """
+ Modified main.py for DiffMorpher: LCM-LoRA support, parameter additions, logging, and speed optimizations.
+ """
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ from argparse import ArgumentParser
+ from model import DiffMorpherPipeline
+ import time
+ import logging
+
+ logs_folder = "logs"
+ os.makedirs(logs_folder, exist_ok=True)
+
+ # Create a unique log filename using the current time
+ log_filename = os.path.join(logs_folder, f"execution_{time.strftime('%Y%m%d_%H%M%S')}.log")
+ logging.basicConfig(
+     filename=log_filename,
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+
+ start_time = time.time()
+
+ parser = ArgumentParser()
+ parser.add_argument(
+     "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base",
+     help="Pretrained model to use (default: %(default)s)"
+ )
+ # Available SD v1-5 versions:
+ #   sd-legacy/stable-diffusion-v1-5
+ #   lykon/dreamshaper-7
+
+ # Original DiffMorpher SD:
+ #   stabilityai/stable-diffusion-2-1-base
+
+ # Quantized models to try (non-functional, possible extension for the future):
+ #   DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
+ #   Xerox32/SD2.1-base-Int8
+
+ parser.add_argument(
+     "--image_path_0", type=str, default="",
+     help="Path of the first image (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--prompt_0", type=str, default="",
+     help="Prompt of the first image (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--image_path_1", type=str, default="",
+     help="Path of the second image (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--prompt_1", type=str, default="",
+     help="Prompt of the second image (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--output_path", type=str, default="./results",
+     help="Path of the output image (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--save_lora_dir", type=str, default="./lora",
+     help="Path for saving LoRA weights (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--load_lora_path_0", type=str, default="",
+     help="Path of the LoRA weights for the first image (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--load_lora_path_1", type=str, default="",
+     help="Path of the LoRA weights for the second image (default: %(default)s)"
+ )
+ parser.add_argument(
+     "--num_inference_steps", type=int, default=50,
+     help="Number of inference steps (default: %(default)s)")
+ parser.add_argument(
+     "--guidance_scale", type=float, default=1,  # to match current DiffMorpher
+     help="Guidance scale for classifier-free guidance (default: %(default)s)"
+ )
+
+ parser.add_argument("--use_adain", action="store_true", help="Use AdaIN (default: %(default)s)")
+ parser.add_argument("--use_reschedule", action="store_true", help="Use reschedule sampling (default: %(default)s)")
+ parser.add_argument("--lamb", type=float, default=0.6, help="Lambda for self-attention replacement (default: %(default)s)")
+ parser.add_argument("--fix_lora_value", type=float, default=None, help="Fix LoRA value (default: LoRA interpolation, not fixed)")
+ parser.add_argument("--save_inter", action="store_true", help="Save intermediate results (default: %(default)s)")
+ parser.add_argument("--num_frames", type=int, default=16, help="Number of frames to generate (default: %(default)s)")
+ parser.add_argument("--duration", type=int, default=100, help="Duration of each frame (default: %(default)s ms)")
+ parser.add_argument("--no_lora", action="store_true", help="Disable style LoRA (default: %(default)s)")
+
+ # New argument for LCM-LoRA acceleration
+ parser.add_argument("--use_lcm", action="store_true", help="Enable LCM-LoRA acceleration for faster sampling")
+
+ args = parser.parse_args()
+ os.makedirs(args.output_path, exist_ok=True)
+
+ # Create the pipeline from the given model path
+ pipeline = DiffMorpherPipeline.from_pretrained(args.model_path, torch_dtype=torch.float32)
+
+ # Memory optimisations: VAE and attention slicing break computations into smaller
+ # chunks, which fit better in memory and can improve caching and memory locality.
+ # Particularly helpful on GPUs with limited VRAM.
+ pipeline.enable_vae_slicing()
+ pipeline.enable_attention_slicing()
+
+ pipeline.to("cuda")
+
+ # Add these AFTER device movement
+ torch.backends.cudnn.benchmark = True  # finds an efficient convolution algorithm via a short benchmark; minimal speed-up.
+ torch.set_float32_matmul_precision("high")  # better for modern GPUs; saves about 7 seconds of inference time.
+
+ # Integrate LCM-LoRA if flagged, OUTSIDE any of the style LoRA loading / training steps.
+ if args.use_lcm:
+     from lcm_lora.lcm_schedule import LCMScheduler
+     # Replace the scheduler using LCM's configuration
+     pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
+     # Load the LCM-LoRA weights (LCM provides an add-on network)
+     pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+     # Override with the LCM-recommended number of inference steps
+     args.num_inference_steps = 8
+     # Set CFG (range allowed by the legacy code: 0 to 1; 1 performs best)
+     args.guidance_scale = 1
+
+ # Run the pipeline inference using the existing parameters
+ images = pipeline(
+     img_path_0=args.image_path_0,
+     img_path_1=args.image_path_1,
+     prompt_0=args.prompt_0,
+     prompt_1=args.prompt_1,
+     save_lora_dir=args.save_lora_dir,
+     load_lora_path_0=args.load_lora_path_0,
+     load_lora_path_1=args.load_lora_path_1,
+     use_adain=args.use_adain,
+     use_reschedule=args.use_reschedule,
+     lamd=args.lamb,
+     output_path=args.output_path,
+     num_frames=args.num_frames,
+     num_inference_steps=args.num_inference_steps,  # enforced when LCM is enabled
+     fix_lora=args.fix_lora_value,
+     save_intermediates=args.save_inter,
+     use_lora=not args.no_lora,
+     use_lcm=args.use_lcm,
+     guidance_scale=args.guidance_scale,  # enforced when LCM is enabled
+ )
+
+ # Save the resulting GIF output from the sequence of images
+ images[0].save(f"{args.output_path}/output.gif", save_all=True,
+                append_images=images[1:], duration=args.duration, loop=0)
+
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+
+ # Log the execution details and parameters
+ logging.info(f"Total execution time: {elapsed_time:.2f} seconds")
+ logging.info(f"Model Path: {args.model_path}")
+ logging.info(f"Image Path 0: {args.image_path_0}")
+ logging.info(f"Image Path 1: {args.image_path_1}")
+ logging.info(f"Use LCM: {args.use_lcm}")
+ logging.info(f"Number of inference steps: {args.num_inference_steps}")
+ logging.info(f"Guidance scale: {args.guidance_scale}")
+
+ print(f"Total execution time: {elapsed_time:.2f} seconds, log file saved as {log_filename}")
Image-Morpher/model.py ADDED
@@ -0,0 +1,647 @@
+ # model.py
+ import os
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.models.attention_processor import AttnProcessor
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ import numpy as np
+ import safetensors
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+ from diffusers import StableDiffusionPipeline
+ import inspect
+
+ from utils.model_utils import get_img, slerp, do_replace_attn
+ from utils.lora_utils import train_lora, load_lora
+ from utils.alpha_scheduler import AlphaScheduler
+
+
+ class StoreProcessor():
+     def __init__(self, original_processor, value_dict, name):
+         self.original_processor = original_processor
+         self.value_dict = value_dict
+         self.name = name
+         self.value_dict[self.name] = dict()
+         self.id = 0
+
+     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
+         # Self-attention only (no encoder hidden states): store the hidden states
+         if encoder_hidden_states is None:
+             self.value_dict[self.name][self.id] = hidden_states.detach()
+             self.id += 1
+         res = self.original_processor(attn, hidden_states, *args,
+                                       encoder_hidden_states=encoder_hidden_states,
+                                       attention_mask=attention_mask,
+                                       **kwargs)
+
+         return res
+
+
+ class LoadProcessor():
+     def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6):
+         super().__init__()
+         self.original_processor = original_processor
+         self.name = name
+         self.img0_dict = img0_dict
+         self.img1_dict = img1_dict
+         self.alpha = alpha
+         self.beta = beta
+         self.lamd = lamd
+         self.id = 0
+
+     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
+         # Self-attention only: replace with the stored, blended attention maps
+         if encoder_hidden_states is None:
+             if self.id < 50 * self.lamd:  # 50 assumes the default DDIM step count; replace only in the first lamd fraction of steps
+                 map0 = self.img0_dict[self.name][self.id]
+                 map1 = self.img1_dict[self.name][self.id]
+                 cross_map = self.beta * hidden_states + \
+                     (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
+                 # cross_map = self.beta * hidden_states + \
+                 #     (1 - self.beta) * slerp(map0, map1, self.alpha)
+                 # cross_map = slerp(slerp(map0, map1, self.alpha),
+                 #                   hidden_states, self.beta)
+                 # cross_map = hidden_states
+                 # cross_map = torch.cat(
+                 #     ((1 - self.alpha) * map0, self.alpha * map1), dim=1)
+
+                 res = self.original_processor(attn, hidden_states, *args,
+                                               encoder_hidden_states=cross_map,
+                                               attention_mask=attention_mask,
+                                               **kwargs)
+             else:
+                 res = self.original_processor(attn, hidden_states, *args,
+                                               encoder_hidden_states=encoder_hidden_states,
+                                               attention_mask=attention_mask,
+                                               **kwargs)
+
+             self.id += 1
+             if self.id == len(self.img0_dict[self.name]):
+                 self.id = 0
+         else:
+             res = self.original_processor(attn, hidden_states, *args,
+                                           encoder_hidden_states=encoder_hidden_states,
+                                           attention_mask=attention_mask,
+                                           **kwargs)
+
+         return res
+
+
+ class DiffMorpherPipeline(StableDiffusionPipeline):
+
+     def __init__(self,
+                  vae: AutoencoderKL,
+                  text_encoder: CLIPTextModel,
+                  tokenizer: CLIPTokenizer,
+                  unet: UNet2DConditionModel,
+                  scheduler: KarrasDiffusionSchedulers,
+                  safety_checker: StableDiffusionSafetyChecker,
+                  feature_extractor: CLIPImageProcessor,
+                  image_encoder=None,
+                  requires_safety_checker: bool = True,
+                  ):
+         # Support diffusers versions with and without the image_encoder argument
+         sig = inspect.signature(super().__init__)
+         params = sig.parameters
+         if 'image_encoder' in params:
+             super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
+                              safety_checker, feature_extractor, image_encoder, requires_safety_checker)
+         else:
+             super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
+                              safety_checker, feature_extractor, requires_safety_checker)
+         self.img0_dict = dict()
+         self.img1_dict = dict()
+
+     def inv_step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: int,
+         x: torch.FloatTensor,
+         eta=0.,
+         verbose=False
+     ):
+         """
+         Inverse sampling step for DDIM inversion.
+         """
+         if verbose:
+             print("timestep: ", timestep)
+         next_step = timestep
+         timestep = min(timestep - self.scheduler.config.num_train_timesteps //
+                        self.scheduler.num_inference_steps, 999)
+         alpha_prod_t = self.scheduler.alphas_cumprod[
+             timestep] if timestep >= 0 else self.scheduler.alphas_cumprod[0]
+         alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+         beta_prod_t = 1 - alpha_prod_t
+         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+         pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+         x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+         return x_next, pred_x0
+
+     @torch.no_grad()
+     def invert(
+         self,
+         image: torch.Tensor,
+         prompt,
+         num_inference_steps=50,
+         num_actual_inference_steps=None,
+         guidance_scale=1.,
+         eta=0.0,
+         **kwds):
+         """
+         Invert a real image into a noise map with deterministic DDIM inversion.
+         """
+         DEVICE = torch.device(
+             "cuda") if torch.cuda.is_available() else torch.device("cpu")
+         batch_size = image.shape[0]
+         if isinstance(prompt, list):
+             if batch_size == 1:
+                 image = image.expand(len(prompt), -1, -1, -1)
+         elif isinstance(prompt, str):
+             if batch_size > 1:
+                 prompt = [prompt] * batch_size
+
+         # text embeddings
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=77,
+             return_tensors="pt"
+         )
+         text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
+         print("input text embeddings :", text_embeddings.shape)
+         # define initial latents
+         latents = self.image2latent(image)
+
+         # unconditional embedding for classifier-free guidance
+         if guidance_scale > 1.:
+             unconditional_input = self.tokenizer(
+                 [""] * batch_size,
+                 padding="max_length",
+                 max_length=77,
+                 return_tensors="pt"
+             )
+             unconditional_embeddings = self.text_encoder(
+                 unconditional_input.input_ids.to(DEVICE))[0]
+             text_embeddings = torch.cat(
+                 [unconditional_embeddings, text_embeddings], dim=0)
+
+         print("latents shape: ", latents.shape)
+         # iterative sampling
+         self.scheduler.set_timesteps(num_inference_steps)
+         print("Valid timesteps: ", reversed(self.scheduler.timesteps))
+         # print("attributes: ", self.scheduler.__dict__)
+         latents_list = [latents]
+         pred_x0_list = [latents]
+         for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+             if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
+                 continue
+
+             if guidance_scale > 1.:
+                 model_inputs = torch.cat([latents] * 2)
+             else:
+                 model_inputs = latents
+
+             # predict the noise
+             noise_pred = self.unet(
+                 model_inputs, t, encoder_hidden_states=text_embeddings).sample
+             if guidance_scale > 1.:
+                 noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
+                 noise_pred = noise_pred_uncon + guidance_scale * \
+                     (noise_pred_con - noise_pred_uncon)
+             # compute the previous noise sample x_t-1 -> x_t
+             latents, pred_x0 = self.inv_step(noise_pred, t, latents)
+             latents_list.append(latents)
+             pred_x0_list.append(pred_x0)
+
+         return latents
+
+     @torch.no_grad()
+     def ddim_inversion(self, latent, cond):
+         timesteps = reversed(self.scheduler.timesteps)
+         with torch.autocast(device_type='cuda', dtype=torch.float32):
+             for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")):
+                 cond_batch = cond.repeat(latent.shape[0], 1, 1)
+
+                 alpha_prod_t = self.scheduler.alphas_cumprod[t]
+                 alpha_prod_t_prev = (
+                     self.scheduler.alphas_cumprod[timesteps[i - 1]]
+                     if i > 0 else self.scheduler.alphas_cumprod[0]
+                 )
+
+                 mu = alpha_prod_t ** 0.5
+                 mu_prev = alpha_prod_t_prev ** 0.5
+                 sigma = (1 - alpha_prod_t) ** 0.5
+                 sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+
+                 eps = self.unet(
+                     latent, t, encoder_hidden_states=cond_batch).sample
+
+                 pred_x0 = (latent - sigma_prev * eps) / mu_prev
+                 latent = mu * pred_x0 + sigma * eps
+                 # if save_latents:
+                 #     torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))
+         return latent
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: int,
+         x: torch.FloatTensor,
+     ):
+         """
+         Predict the sample of the next step in the denoising process.
+         """
+         prev_timestep = timestep - \
+             self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.scheduler.alphas_cumprod[
+             prev_timestep] if prev_timestep > 0 else self.scheduler.alphas_cumprod[0]
+         beta_prod_t = 1 - alpha_prod_t
+         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+         pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
+         x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
+         return x_prev, pred_x0
+
+     @torch.no_grad()
+     def image2latent(self, image):
+         DEVICE = torch.device(
+             "cuda") if torch.cuda.is_available() else torch.device("cpu")
+         if isinstance(image, Image.Image):  # was `type(image) is Image`, which never matches a PIL image
+             image = np.array(image)
+             image = torch.from_numpy(image).float() / 127.5 - 1
+             image = image.permute(2, 0, 1).unsqueeze(0)
+         # input image density range [-1, 1]
+         latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean
+         latents = latents * 0.18215
+         return latents
+
+     @torch.no_grad()
+     def latent2image(self, latents, return_type='np'):
+         latents = 1 / 0.18215 * latents.detach()
+         image = self.vae.decode(latents)['sample']
+         if return_type == 'np':
+             image = (image / 2 + 0.5).clamp(0, 1)
+             image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+             image = (image * 255).astype(np.uint8)
+         elif return_type == "pt":
+             image = (image / 2 + 0.5).clamp(0, 1)
+
+         return image
+
+     def latent2image_grad(self, latents):
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents)['sample']
+
+         return image  # range [-1, 1]
+
+     @torch.no_grad()
+     def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, use_lcm, fix_lora=None):
+         # latents = torch.cos(alpha * torch.pi / 2) * img_noise_0 + \
+         #     torch.sin(alpha * torch.pi / 2) * img_noise_1
+         # latents = (1 - alpha) * img_noise_0 + alpha * img_noise_1
+         # latents = latents / ((1 - alpha) ** 2 + alpha ** 2)
+         latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain)
+         text_embeddings = (1 - alpha) * text_embeddings_0 + \
+             alpha * text_embeddings_1
+
+         self.scheduler.set_timesteps(num_inference_steps)
+         if use_lora:
+             if fix_lora is not None:
+                 self.unet = load_lora(self.unet, lora_0, lora_1, fix_lora)
+             else:
+                 self.unet = load_lora(self.unet, lora_0, lora_1, alpha)
+
+         if use_lcm:
+             sampler_desc = "LCM multi-step sampler"
+         else:
+             sampler_desc = "DDIM Sampler"  # currently defaults to this
+
+         for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"{sampler_desc}, alpha={alpha}")):
+
+             if guidance_scale > 1.:
+                 model_inputs = torch.cat([latents] * 2)
+             else:
+                 model_inputs = latents
+             if unconditioning is not None and isinstance(unconditioning, list):
+                 _, text_embeddings = text_embeddings.chunk(2)
+                 text_embeddings = torch.cat(
+                     [unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
+             # predict the noise
+             noise_pred = self.unet(
+                 model_inputs, t, encoder_hidden_states=text_embeddings).sample
+             if guidance_scale > 1.0:
+                 noise_pred_uncon, noise_pred_con = noise_pred.chunk(
+                     2, dim=0)
+                 noise_pred = noise_pred_uncon + guidance_scale * \
+                     (noise_pred_con - noise_pred_uncon)
+             # compute the previous noise sample x_t -> x_t-1
+             latents = self.scheduler.step(
+                 noise_pred, t, latents, return_dict=False)[0]
+         return latents
+
+     @torch.no_grad()
+     def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size):
+         DEVICE = torch.device(
+             "cuda") if torch.cuda.is_available() else torch.device("cpu")
+         # text embeddings
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=77,
+             return_tensors="pt"
+         )
+         text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
+
+         if guidance_scale > 1.:
+             if neg_prompt:
+                 uc_text = neg_prompt
+             else:
+                 uc_text = ""
+             unconditional_input = self.tokenizer(
+                 [uc_text] * batch_size,
+                 padding="max_length",
+                 max_length=77,
+                 return_tensors="pt"
+             )
+             unconditional_embeddings = self.text_encoder(
+                 unconditional_input.input_ids.to(DEVICE))[0]
+             text_embeddings = torch.cat(
+                 [unconditional_embeddings, text_embeddings], dim=0)
+
+         return text_embeddings
+
+     def __call__(
+         self,
+         img_0=None,
+         img_1=None,
+         img_path_0=None,
+         img_path_1=None,
+         prompt_0="",
+         prompt_1="",
+         save_lora_dir="./lora",
+         load_lora_path_0=None,
+         load_lora_path_1=None,
+         lora_steps=200,
+         lora_lr=2e-4,
+         lora_rank=16,
+         batch_size=1,
+         height=512,
+         width=512,
+         num_inference_steps=50,
+         num_actual_inference_steps=None,
+         guidance_scale=1,
+         attn_beta=0,
+         lamd=0.6,
+         use_lora=True,
+         use_lcm=False,
+         use_adain=True,
+         use_reschedule=True,
+         output_path="./results",
+         num_frames=50,
+         fix_lora=None,
+         progress=tqdm,
+         unconditioning=None,
+         neg_prompt=None,
+         save_intermediates=False,
+         **kwds):
+
+         self.scheduler.set_timesteps(num_inference_steps)
+         self.use_lora = use_lora
+         self.use_adain = use_adain
+         self.use_reschedule = use_reschedule
+         self.output_path = output_path
+         self.use_lcm = use_lcm
+
+         if img_0 is None:
+             img_0 = Image.open(img_path_0).convert("RGB")
+         if img_1 is None:
+             img_1 = Image.open(img_path_1).convert("RGB")
+
+         if self.use_lora:
+             print("Loading lora...")
+             if not load_lora_path_0:
+                 weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
+                 load_lora_path_0 = save_lora_dir + "/" + weight_name
+                 if not os.path.exists(load_lora_path_0):
+                     train_lora(img_0, prompt_0, save_lora_dir, None, self.tokenizer, self.text_encoder,
+                                self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
+             print(f"Load from {load_lora_path_0}.")
+             if load_lora_path_0.endswith(".safetensors"):
+                 lora_0 = safetensors.torch.load_file(
+                     load_lora_path_0, device="cpu")
+             else:
+                 lora_0 = torch.load(load_lora_path_0, map_location="cpu")
+
+             if not load_lora_path_1:
+                 weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
+                 load_lora_path_1 = save_lora_dir + "/" + weight_name
+                 if not os.path.exists(load_lora_path_1):
+                     train_lora(img_1, prompt_1, save_lora_dir, None, self.tokenizer, self.text_encoder,
+                                self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
+             print(f"Load from {load_lora_path_1}.")
+             if load_lora_path_1.endswith(".safetensors"):
+                 lora_1 = safetensors.torch.load_file(
+                     load_lora_path_1, device="cpu")
+             else:
+                 lora_1 = torch.load(load_lora_path_1, map_location="cpu")
+         else:
+             lora_0 = lora_1 = None
+
+         text_embeddings_0 = self.get_text_embeddings(
+             prompt_0, guidance_scale, neg_prompt, batch_size)
+         text_embeddings_1 = self.get_text_embeddings(
+             prompt_1, guidance_scale, neg_prompt, batch_size)
+         img_0 = get_img(img_0)
+         img_1 = get_img(img_1)
+         if self.use_lora:
+             self.unet = load_lora(self.unet, lora_0, lora_1, 0)
+         img_noise_0 = self.ddim_inversion(
+             self.image2latent(img_0), text_embeddings_0)
+         if self.use_lora:
+             self.unet = load_lora(self.unet, lora_0, lora_1, 1)
+         img_noise_1 = self.ddim_inversion(
+             self.image2latent(img_1), text_embeddings_1)
+
+         print("latents shape: ", img_noise_0.shape)
+
+         original_processor = list(self.unet.attn_processors.values())[0]
+
+         def morph(alpha_list, progress, desc):
+             images = []
+             if attn_beta is not None:
+                 if self.use_lora:
+                     self.unet = load_lora(
+                         self.unet, lora_0, lora_1, 0 if fix_lora is None else fix_lora)
+
+                 attn_processor_dict = {}
+                 for k in self.unet.attn_processors.keys():
+                     if do_replace_attn(k):
+                         if self.use_lora:
+                             attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
+                                                                     self.img0_dict, k)
+                         else:
+                             attn_processor_dict[k] = StoreProcessor(original_processor,
+                                                                     self.img0_dict, k)
+                     else:
+                         attn_processor_dict[k] = self.unet.attn_processors[k]
+                 self.unet.set_attn_processor(attn_processor_dict)
+
+                 # LoRA is blended above, so use_lora=False here; pass use_lcm and
+                 # fix_lora by keyword so fix_lora does not land in the use_lcm slot.
+                 latents = self.cal_latent(
+                     num_inference_steps,
+                     guidance_scale,
+                     unconditioning,
+                     img_noise_0,
+                     img_noise_1,
+                     text_embeddings_0,
+                     text_embeddings_1,
+                     lora_0,
+                     lora_1,
+                     alpha_list[0],
+                     False,
+                     use_lcm=self.use_lcm,
+                     fix_lora=fix_lora
+                 )
+                 first_image = self.latent2image(latents)
+                 first_image = Image.fromarray(first_image)
+                 if save_intermediates:
+                     first_image.save(f"{self.output_path}/{0:02d}.png")
+
+                 if self.use_lora:
+                     self.unet = load_lora(
+                         self.unet, lora_0, lora_1, 1 if fix_lora is None else fix_lora)
+                 attn_processor_dict = {}
+                 for k in self.unet.attn_processors.keys():
+                     if do_replace_attn(k):
+                         if self.use_lora:
+                             attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
+                                                                     self.img1_dict, k)
+                         else:
+                             attn_processor_dict[k] = StoreProcessor(original_processor,
+                                                                     self.img1_dict, k)
+                     else:
+                         attn_processor_dict[k] = self.unet.attn_processors[k]
+
+                 self.unet.set_attn_processor(attn_processor_dict)
+
+                 latents = self.cal_latent(
+                     num_inference_steps,
+                     guidance_scale,
+                     unconditioning,
+                     img_noise_0,
+                     img_noise_1,
+                     text_embeddings_0,
+                     text_embeddings_1,
+                     lora_0,
+                     lora_1,
+                     alpha_list[-1],
+                     False,
+                     use_lcm=self.use_lcm,
+                     fix_lora=fix_lora
+                 )
+                 last_image = self.latent2image(latents)
+                 last_image = Image.fromarray(last_image)
+                 if save_intermediates:
+                     last_image.save(
+                         f"{self.output_path}/{num_frames - 1:02d}.png")
+
+                 for i in progress.tqdm(range(1, num_frames - 1), desc=desc):
+                     alpha = alpha_list[i]
+                     if self.use_lora:
+                         self.unet = load_lora(
+                             self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora)
+
+                     attn_processor_dict = {}
+                     for k in self.unet.attn_processors.keys():
+                         if do_replace_attn(k):
+                             if self.use_lora:
+                                 attn_processor_dict[k] = LoadProcessor(
+                                     self.unet.attn_processors[k], k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd)
+                             else:
+                                 attn_processor_dict[k] = LoadProcessor(
+                                     original_processor, k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd)
+                         else:
+                             attn_processor_dict[k] = self.unet.attn_processors[k]
+
+                     self.unet.set_attn_processor(attn_processor_dict)
+
+                     latents = self.cal_latent(
+                         num_inference_steps,
+                         guidance_scale,
+                         unconditioning,
+                         img_noise_0,
+                         img_noise_1,
+                         text_embeddings_0,
+                         text_embeddings_1,
+                         lora_0,
+                         lora_1,
+                         alpha_list[i],
+                         False,
+                         use_lcm=self.use_lcm,
+                         fix_lora=fix_lora
+                     )
+                     image = self.latent2image(latents)
+                     image = Image.fromarray(image)
+                     if save_intermediates:
+                         image.save(f"{self.output_path}/{i:02d}.png")
+                     images.append(image)
+
+                 images = [first_image] + images + [last_image]
+
+             else:
+                 for k, alpha in enumerate(alpha_list):
+
+                     latents = self.cal_latent(
+                         num_inference_steps,
+                         guidance_scale,
+                         unconditioning,
+                         img_noise_0,
+                         img_noise_1,
+                         text_embeddings_0,
+                         text_embeddings_1,
+                         lora_0,
+                         lora_1,
+                         alpha_list[k],
+                         self.use_lora,
+                         use_lcm=self.use_lcm,
+                         fix_lora=fix_lora
+                     )
+                     image = self.latent2image(latents)
+                     image = Image.fromarray(image)
+                     if save_intermediates:
+                         image.save(f"{self.output_path}/{k:02d}.png")
+                     images.append(image)
+
+             return images
+
+         with torch.no_grad():
+             if self.use_reschedule:
+                 alpha_scheduler = AlphaScheduler()
+                 alpha_list = list(torch.linspace(0, 1, num_frames))
+                 images_pt = morph(alpha_list, progress, "Sampling...")
+                 images_pt = [transforms.ToTensor()(img).unsqueeze(0)
+                              for img in images_pt]
+                 alpha_scheduler.from_imgs(images_pt)
+                 alpha_list = alpha_scheduler.get_list()
+                 print(alpha_list)
+                 images = morph(alpha_list, progress, "Reschedule...")
+             else:
+                 alpha_list = list(torch.linspace(0, 1, num_frames))
+                 print(alpha_list)
+                 images = morph(alpha_list, progress, "Sampling...")
+
+         return images
Image-Morpher/requirements_diffmorpher.txt ADDED
@@ -0,0 +1,14 @@
+ accelerate==0.23.0
+ diffusers==0.17.1
+ einops==0.7.0
+ gradio==4.7.1
+ numpy==1.26.1
+ opencv_python==4.5.5.64
+ packaging==23.2
+ Pillow==10.1.0
+ safetensors==0.4.0
+ tqdm==4.65.0
+ transformers==4.34.1
+ torch
+ torchvision
+ lpips
Image-Morpher/run.sh ADDED
@@ -0,0 +1,21 @@
+ # srun -p rtx3090_slab -w slabgpu08 --gres=gpu:1 \
+ #     --job-name=test --kill-on-bad-exit=1 python3 main.py \
+ #     --image_path_0 ./assets/vangogh.jpg --image_path_1 ./assets/pearlgirl.jpg \
+ #     --prompt_0 "An oil painting of a man" --prompt_1 "An oil painting of a woman" \
+ #     --output_path "./results/vangogh_pearlgirl" --use_adain --use_reschedule \
+ #     --save_inter --use_lcm
+
+ # srun -p rtx3090_slab -w slabgpu08 --gres=gpu:1 \
+ #     --job-name=test --kill-on-bad-exit=1 python3 main.py \
+ #     --image_path_0 ./assets/lion.png --image_path_1 ./assets/tiger.png \
+ #     --prompt_0 "A photo of a lion" --prompt_1 "A photo of a tiger" \
+ #     --output_path "./results/lion_tiger" --use_adain --use_reschedule \
+ #     --save_inter --use_lcm
+
+ srun -p rtx3090_slab -w slabgpu05 --gres=gpu:1 \
+     --job-name=test --kill-on-bad-exit=1 python3 main.py \
+     --image_path_0 ./assets/Trump.jpg --image_path_1 ./assets/Biden.jpg \
+     --prompt_0 "A photo of an American man" --prompt_1 "A photo of an American man" \
+     --output_path "./results/Trump_Biden" --use_adain --use_reschedule \
+     --save_inter --use_lcm
Image-Morpher/utils/__init__.py ADDED
File without changes
Image-Morpher/utils/alpha_scheduler.py ADDED
@@ -0,0 +1,54 @@
+ import bisect
+ import torch
+ import torch.nn.functional as F
+ import lpips
+
+ perceptual_loss = lpips.LPIPS()
+
+
+ def distance(img_a, img_b):
+     return perceptual_loss(img_a, img_b).item()
+     # return F.mse_loss(img_a, img_b).item()
+
+
+ class AlphaScheduler:
+     def __init__(self):
+         ...
+
+     def from_imgs(self, imgs):
+         self.__num_values = len(imgs)
+         self.__values = [0]
+         for i in range(self.__num_values - 1):
+             dis = distance(imgs[i], imgs[i + 1])
+             self.__values.append(dis)
+             self.__values[i + 1] += self.__values[i]
+         for i in range(self.__num_values):
+             self.__values[i] /= self.__values[-1]
+
+     def save(self, filename):
+         torch.save(torch.tensor(self.__values), filename)
+
+     def load(self, filename):
+         self.__values = torch.load(filename).tolist()
+         self.__num_values = len(self.__values)
+
+     def get_x(self, y):
+         assert y >= 0 and y <= 1
+         id = bisect.bisect_left(self.__values, y)
+         id -= 1
+         if id < 0:
+             id = 0
+         yl = self.__values[id]
+         yr = self.__values[id + 1]
+         xl = id * (1 / (self.__num_values - 1))
+         xr = (id + 1) * (1 / (self.__num_values - 1))
+         x = (y - yl) / (yr - yl) * (xr - xl) + xl
+         return x
+
+     def get_list(self, length=None):  # renamed from `len` to avoid shadowing the builtin
+         if length is None:
+             length = self.__num_values
+
+         ys = torch.linspace(0, 1, length)
+         res = [self.get_x(y) for y in ys]
+         return res
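A short usage sketch (this mirrors how model.py drives the scheduler; images_pt is assumed to be a list of 1x3xHxW tensors produced with transforms.ToTensor()):

scheduler = AlphaScheduler()
scheduler.from_imgs(images_pt)     # cumulative LPIPS distances between consecutive frames
alpha_list = scheduler.get_list()  # alphas respaced so perceptual change per frame is even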
Image-Morpher/utils/lora_utils.py ADDED
@@ -0,0 +1,284 @@
1
+ # utils / lora_utils.py
2
+ from timeit import default_timer as timer
3
+ from datetime import timedelta
4
+ from PIL import Image
5
+ import os
6
+ import numpy as np
7
+ from einops import rearrange
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ import transformers
12
+ from accelerate import Accelerator
13
+ from accelerate.utils import set_seed
14
+ from packaging import version
15
+ from PIL import Image
16
+ import tqdm
17
+
18
+ from transformers import AutoTokenizer, PretrainedConfig
19
+
20
+ import diffusers
21
+ from diffusers import (
22
+ AutoencoderKL,
23
+ DDPMScheduler,
24
+ DiffusionPipeline,
25
+ DPMSolverMultistepScheduler,
26
+ StableDiffusionPipeline,
27
+ UNet2DConditionModel,
28
+ )
29
+ from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
30
+ from diffusers.models.attention_processor import (
31
+ AttnAddedKVProcessor,
32
+ AttnAddedKVProcessor2_0,
33
+ LoRAAttnAddedKVProcessor,
34
+ LoRAAttnProcessor,
35
+ LoRAAttnProcessor2_0,
36
+ SlicedAttnAddedKVProcessor,
37
+ )
38
+ from diffusers.optimization import get_scheduler
39
+ from diffusers.utils import check_min_version
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+
42
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
43
+ check_min_version("0.17.0")
44
+
45
+
46
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
47
+ text_encoder_config = PretrainedConfig.from_pretrained(
48
+ pretrained_model_name_or_path,
49
+ subfolder="text_encoder",
50
+ revision=revision,
51
+ )
52
+ model_class = text_encoder_config.architectures[0]
53
+
54
+ if model_class == "CLIPTextModel":
55
+ from transformers import CLIPTextModel
56
+
57
+ return CLIPTextModel
58
+ elif model_class == "RobertaSeriesModelWithTransformation":
59
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
60
+
61
+ return RobertaSeriesModelWithTransformation
62
+ elif model_class == "T5EncoderModel":
63
+ from transformers import T5EncoderModel
64
+
65
+ return T5EncoderModel
66
+ else:
67
+ raise ValueError(f"{model_class} is not supported.")
68
+
69
+ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
70
+ if tokenizer_max_length is not None:
71
+ max_length = tokenizer_max_length
72
+ else:
73
+ max_length = tokenizer.model_max_length
74
+
75
+ text_inputs = tokenizer(
76
+ prompt,
77
+ truncation=True,
78
+ padding="max_length",
79
+ max_length=max_length,
80
+ return_tensors="pt",
81
+ )
82
+
83
+ return text_inputs
84
+
85
+ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
86
+ text_input_ids = input_ids.to(text_encoder.device)
87
+
88
+ if text_encoder_use_attention_mask:
89
+ attention_mask = attention_mask.to(text_encoder.device)
90
+ else:
91
+ attention_mask = None
92
+
93
+ prompt_embeds = text_encoder(
94
+ text_input_ids,
95
+ attention_mask=attention_mask,
96
+ )
97
+ prompt_embeds = prompt_embeds[0]
98
+
99
+ return prompt_embeds
100
+
101
+ # model_path: path of the model
102
+ # image: input image, have not been pre-processed
103
+ # save_lora_dir: the path to save the lora
104
+ # prompt: the user input prompt
105
+ # lora_steps: number of lora training step
106
+ # lora_lr: learning rate of lora training
107
+ # lora_rank: the rank of lora
108
+ def train_lora(image, prompt, save_lora_dir, model_path=None, tokenizer=None, text_encoder=None, vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
109
+ # initialize accelerator
110
+ accelerator = Accelerator(
111
+ gradient_accumulation_steps=1,
112
+ # mixed_precision='fp16'
113
+ )
114
+ set_seed(0)
115
+
116
+ # Load the tokenizer
117
+ if tokenizer is None:
118
+ tokenizer = AutoTokenizer.from_pretrained(
119
+ model_path,
120
+ subfolder="tokenizer",
121
+ revision=None,
122
+ use_fast=False,
123
+ )
124
+ # initialize the model
125
+ if noise_scheduler is None:
126
+ noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
127
+ if text_encoder is None:
128
+ text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
129
+ text_encoder = text_encoder_cls.from_pretrained(
130
+ model_path, subfolder="text_encoder", revision=None
131
+ )
132
+ if vae is None:
133
+ vae = AutoencoderKL.from_pretrained(
134
+ model_path, subfolder="vae", revision=None
135
+ )
136
+ if unet is None:
137
+ unet = UNet2DConditionModel.from_pretrained(
138
+ model_path, subfolder="unet", revision=None
139
+ )
140
+
141
+ # set device and dtype
142
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
143
+
144
+ vae.requires_grad_(False)
145
+ text_encoder.requires_grad_(False)
146
+ unet.requires_grad_(False)
147
+
148
+ unet.to(device)
149
+ vae.to(device)
150
+ text_encoder.to(device)
151
+
152
+ # initialize UNet LoRA
153
+ unet_lora_attn_procs = {}
154
+ for name, attn_processor in unet.attn_processors.items():
155
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
156
+ if name.startswith("mid_block"):
157
+ hidden_size = unet.config.block_out_channels[-1]
158
+ elif name.startswith("up_blocks"):
159
+ block_id = int(name[len("up_blocks.")])
160
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
161
+ elif name.startswith("down_blocks"):
162
+ block_id = int(name[len("down_blocks.")])
163
+ hidden_size = unet.config.block_out_channels[block_id]
164
+ else:
165
+ raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")
166
+
167
+         if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
+             lora_attn_processor_class = LoRAAttnAddedKVProcessor
+         else:
+             lora_attn_processor_class = (
+                 LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+             )
+         unet_lora_attn_procs[name] = lora_attn_processor_class(
+             hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
+         )
+     unet.set_attn_processor(unet_lora_attn_procs)
+     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
+
+     # Optimizer creation
+     params_to_optimize = unet_lora_layers.parameters()
+     optimizer = torch.optim.AdamW(
+         params_to_optimize,
+         lr=lora_lr,
+         betas=(0.9, 0.999),
+         weight_decay=1e-2,
+         eps=1e-08,
+     )
+
+     lr_scheduler = get_scheduler(
+         "constant",
+         optimizer=optimizer,
+         num_warmup_steps=0,
+         num_training_steps=lora_steps,
+         num_cycles=1,
+         power=1.0,
+     )
+
+     # Prepare the trainable pieces with accelerate
+     unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
+     optimizer = accelerator.prepare_optimizer(optimizer)
+     lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
+
+     # Initialize text embeddings
+     with torch.no_grad():
+         text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
+         text_embedding = encode_prompt(
+             text_encoder,
+             text_inputs.input_ids,
+             text_inputs.attention_mask,
+             text_encoder_use_attention_mask=False
+         )
+
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     # Initialize latent distribution
+     image_transforms = transforms.Compose(
+         [
+             transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
+             # transforms.RandomCrop(512),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5]),
+         ]
+     )
+
+     image = image_transforms(image).to(device)
+     image = image.unsqueeze(dim=0)
+
+     latents_dist = vae.encode(image).latent_dist
+     for _ in progress.tqdm(range(lora_steps), desc="Training LoRA..."):
+         unet.train()
+         model_input = latents_dist.sample() * vae.config.scaling_factor
+         # Sample noise that we'll add to the latents
+         noise = torch.randn_like(model_input)
+         bsz, channels, height, width = model_input.shape
+         # Sample a random timestep for each image
+         timesteps = torch.randint(
+             0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+         )
+         timesteps = timesteps.long()
+
+         # Add noise to the model input according to the noise magnitude at each timestep
+         # (this is the forward diffusion process)
+         noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+         # Predict the noise residual
+         model_pred = unet(noisy_model_input, timesteps, text_embedding).sample
+
+         # Get the target for loss depending on the prediction type
+         if noise_scheduler.config.prediction_type == "epsilon":
+             target = noise
+         elif noise_scheduler.config.prediction_type == "v_prediction":
+             target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+         else:
+             raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+         loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+         accelerator.backward(loss)
+         optimizer.step()
+         lr_scheduler.step()
+         optimizer.zero_grad()
+
+     # Save the trained LoRA
+     # unet = unet.to(torch.float32)
+     # vae = vae.to(torch.float32)
+     # text_encoder = text_encoder.to(torch.float32)
+
+     # unwrap_model removes the wrapper modules added for distributed training,
+     # so there is no need to call it here
+     # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+     LoraLoaderMixin.save_lora_weights(
+         save_directory=save_lora_dir,
+         unet_lora_layers=unet_lora_layers,
+         text_encoder_lora_layers=None,
+         weight_name=weight_name,
+         safe_serialization=safe_serialization
+     )
+
+ def load_lora(unet, lora_0, lora_1, alpha):
+     lora = {}
+     for key in lora_0:
+         lora[key] = (1 - alpha) * lora_0[key] + alpha * lora_1[key]
+     unet.load_attn_procs(lora)
+     return unet
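
For reference, load_lora blends two LoRA attention-processor state dicts key by key with linear weight alpha, then loads the blend into the UNet, so alpha=0 reproduces the first fine-tuned endpoint and alpha=1 the second. A minimal sketch of driving it across a morph sequence follows; the checkpoint names, the torch.load call, and the pre-existing `unet` are illustrative assumptions, not part of this commit.

    # Hypothetical driver for load_lora. Assumes lora_0.ckpt / lora_1.ckpt were
    # produced by train_lora above and that `unet` is an already-loaded UNet. If
    # the weights were saved with safe_serialization=True, use
    # safetensors.torch.load_file instead of torch.load.
    import torch

    lora_0 = torch.load("lora_0.ckpt", map_location="cpu")
    lora_1 = torch.load("lora_1.ckpt", map_location="cpu")

    for i, alpha in enumerate([0.0, 0.25, 0.5, 0.75, 1.0]):
        unet = load_lora(unet, lora_0, lora_1, alpha)  # blended weights now active
        # ... render frame i with the diffusion pipeline that wraps this unet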
Image-Morpher/utils/model_utils.py ADDED
@@ -0,0 +1,87 @@
+ # utils/model_utils.py
+ import torch
+ import torch.nn.functional as F
+ from torchvision import transforms
+
+ def calc_mean_std(feat, eps=1e-5):
+     # eps is a small value added to the variance to avoid divide-by-zero.
+     size = feat.size()
+
+     N, C = size[:2]
+     feat_var = feat.view(N, C, -1).var(dim=2) + eps
+     if len(size) == 3:
+         feat_std = feat_var.sqrt().view(N, C, 1)
+         feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1)
+     else:
+         feat_std = feat_var.sqrt().view(N, C, 1, 1)
+         feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+     return feat_mean, feat_std
+
+
+ def get_img(img, resolution=512):
+     norm_mean = [0.5, 0.5, 0.5]
+     norm_std = [0.5, 0.5, 0.5]
+     transform = transforms.Compose([
+         transforms.Resize((resolution, resolution)),
+         transforms.ToTensor(),
+         transforms.Normalize(norm_mean, norm_std)
+     ])
+     img = transform(img)
+     return img.unsqueeze(0)
+
+ @torch.no_grad()
+ def slerp(p0, p1, fract_mixing: float, adain=True):
+     r""" Copied from lunarring/latentblending
+     Helper function to correctly mix two random variables using spherical interpolation.
+     The function will always cast up to float64 for the sake of extra precision.
+     Args:
+         p0:
+             First tensor for interpolation
+         p1:
+             Second tensor for interpolation
+         fract_mixing: float
+             Mixing coefficient of interval [0, 1].
+             0 will return p0
+             1 will return p1
+             0.x will return a mix between both, preserving angular velocity.
+     """
+     if p0.dtype == torch.float16:
+         recast_to = 'fp16'
+     else:
+         recast_to = 'fp32'
+
+     p0 = p0.double()
+     p1 = p1.double()
+
+     if adain:
+         mean1, std1 = calc_mean_std(p0)
+         mean2, std2 = calc_mean_std(p1)
+         mean = mean1 * (1 - fract_mixing) + mean2 * fract_mixing
+         std = std1 * (1 - fract_mixing) + std2 * fract_mixing
+
+     norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
+     epsilon = 1e-7
+     dot = torch.sum(p0 * p1) / norm
+     dot = dot.clamp(-1 + epsilon, 1 - epsilon)
+
+     theta_0 = torch.arccos(dot)
+     sin_theta_0 = torch.sin(theta_0)
+     theta_t = theta_0 * fract_mixing
+     s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
+     s1 = torch.sin(theta_t) / sin_theta_0
+     interp = p0 * s0 + p1 * s1
+
+     if adain:
+         interp = F.instance_norm(interp) * std + mean
+
+     if recast_to == 'fp16':
+         interp = interp.half()
+     elif recast_to == 'fp32':
+         interp = interp.float()
+
+     return interp
+
+
+ def do_replace_attn(key: str):
+     # return key.startswith('up_blocks.2') or key.startswith('up_blocks.3')
+     return key.startswith('up')
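
slerp here is spherical linear interpolation: with theta the angle between the flattened tensors, it returns sin((1 - t) * theta) / sin(theta) * p0 + sin(t * theta) / sin(theta) * p1, which preserves the norm of Gaussian latents along the path where a plain linear mix would shrink it; the optional adain branch additionally blends per-channel mean and std via calc_mean_std. A minimal usage sketch, assuming the (1, 4, 64, 64) latent shape of a 512x512 Stable Diffusion image:

    # Interpolate two Gaussian noise latents with slerp. The latent shape is an
    # assumption for illustration; any two tensors of matching shape work.
    import torch

    latent_a = torch.randn(1, 4, 64, 64)
    latent_b = torch.randn(1, 4, 64, 64)

    frames = [slerp(latent_a, latent_b, t / 15) for t in range(16)]
    # Each intermediate keeps roughly unit-Gaussian statistics; a linear mix
    # (1 - t) * a + t * b would dip in norm around t = 0.5.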