AlanB commited on
Commit
cb9f0d2
·
1 Parent(s): 37940d4

Added freeu functions to unet

Browse files
Files changed (1) hide show
  1. pipeline.py +26 -0
pipeline.py CHANGED
@@ -864,6 +864,32 @@ class StableDiffusionLongPromptWeightingPipeline(
864
  latents = init_latents
865
  return latents, init_latents_orig, noise
866
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867
  @torch.no_grad()
868
  def __call__(
869
  self,
 
864
  latents = init_latents
865
  return latents, init_latents_orig, noise
866
 
867
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
868
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
869
+
870
+ The suffixes after the scaling factors represent the stages where they are being applied.
871
+
872
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
873
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
874
+
875
+ Args:
876
+ s1 (`float`):
877
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
878
+ mitigate "oversmoothing effect" in the enhanced denoising process.
879
+ s2 (`float`):
880
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
881
+ mitigate "oversmoothing effect" in the enhanced denoising process.
882
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
883
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
884
+ """
885
+ if not hasattr(self, "unet"):
886
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
887
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
888
+
889
+ def disable_freeu(self):
890
+ """Disables the FreeU mechanism if enabled."""
891
+ self.unet.disable_freeu()
892
+
893
  @torch.no_grad()
894
  def __call__(
895
  self,