Added freeu functions to unet
Browse files- pipeline.py +26 -0
pipeline.py
CHANGED
@@ -864,6 +864,32 @@ class StableDiffusionLongPromptWeightingPipeline(
|
|
864 |
latents = init_latents
|
865 |
return latents, init_latents_orig, noise
|
866 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
867 |
@torch.no_grad()
|
868 |
def __call__(
|
869 |
self,
|
|
|
864 |
latents = init_latents
|
865 |
return latents, init_latents_orig, noise
|
866 |
|
867 |
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
868 |
+
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
869 |
+
|
870 |
+
The suffixes after the scaling factors represent the stages where they are being applied.
|
871 |
+
|
872 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
873 |
+
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
874 |
+
|
875 |
+
Args:
|
876 |
+
s1 (`float`):
|
877 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
878 |
+
mitigate "oversmoothing effect" in the enhanced denoising process.
|
879 |
+
s2 (`float`):
|
880 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
881 |
+
mitigate "oversmoothing effect" in the enhanced denoising process.
|
882 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
883 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
884 |
+
"""
|
885 |
+
if not hasattr(self, "unet"):
|
886 |
+
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
887 |
+
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
888 |
+
|
889 |
+
def disable_freeu(self):
|
890 |
+
"""Disables the FreeU mechanism if enabled."""
|
891 |
+
self.unet.disable_freeu()
|
892 |
+
|
893 |
@torch.no_grad()
|
894 |
def __call__(
|
895 |
self,
|