tolgacangoz commited on
Commit
a93c410
·
verified ·
1 Parent(s): 6bb2088

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +56 -68
matryoshka.py CHANGED
@@ -1,4 +1,23 @@
1
- # #TODO Licensed under the Apache License, Version 2.0 or MIT?
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import inspect
4
  import math
@@ -613,14 +632,14 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
613
  # 4. Clip or threshold "predicted x_0"
614
  if self.config.thresholding:
615
  if len(model_output) > 1:
616
- pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
617
  else:
618
  pred_original_sample = self._threshold_sample(pred_original_sample)
619
  elif self.config.clip_sample:
620
  if len(model_output) > 1:
621
  pred_original_sample = [
622
- p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
623
- for p_o_s in pred_original_sample
624
  ]
625
  else:
626
  pred_original_sample = pred_original_sample.clamp(
@@ -3707,7 +3726,7 @@ class MatryoshkaPipeline(
3707
  FromSingleFileMixin,
3708
  ):
3709
  r"""
3710
- Pipeline for text-to-image generation using Stable Diffusion.
3711
 
3712
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
3713
  implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -3720,21 +3739,17 @@ class MatryoshkaPipeline(
3720
  - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
3721
 
3722
  Args:
3723
- text_encoder ([`~transformers.CLIPTextModel`]):
3724
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
3725
- tokenizer ([`~transformers.CLIPTokenizer`]):
3726
- A `CLIPTokenizer` to tokenize text.
3727
  unet ([`MatryoshkaUNet2DConditionModel`]):
3728
  A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents.
3729
  scheduler ([`SchedulerMixin`]):
3730
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
3731
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
3732
- safety_checker ([`StableDiffusionSafetyChecker`]):
3733
- Classification module that estimates whether generated images could be considered offensive or harmful.
3734
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
3735
- about a model's potential harms.
3736
- feature_extractor ([`~transformers.CLIPImageProcessor`]):
3737
- A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
3738
  """
3739
 
3740
  model_cpu_offload_seq = "text_encoder->image_encoder->unet"
@@ -3755,6 +3770,18 @@ class MatryoshkaPipeline(
3755
  ):
3756
  super().__init__()
3757
 
 
 
 
 
 
 
 
 
 
 
 
 
3758
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3759
  deprecation_message = (
3760
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@@ -3782,10 +3809,10 @@ class MatryoshkaPipeline(
3782
  new_config["clip_sample"] = False
3783
  scheduler._internal_dict = FrozenDict(new_config)
3784
 
3785
- is_unet_version_less_0_9_0 = hasattr(unet[0].config, "_diffusers_version") and version.parse(
3786
- version.parse(unet[0].config._diffusers_version).base_version
3787
  ) < version.parse("0.9.0.dev0")
3788
- is_unet_sample_size_less_64 = hasattr(unet[0].config, "sample_size") and unet[0].config.sample_size < 64
3789
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
3790
  deprecation_message = (
3791
  "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -3803,16 +3830,6 @@ class MatryoshkaPipeline(
3803
  new_config["sample_size"] = 64
3804
  unet._internal_dict = FrozenDict(new_config)
3805
 
3806
- if nesting_level == 0:
3807
- unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3808
- subfolder="unet/nesting_level_0")
3809
- elif nesting_level == 1:
3810
- unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3811
- subfolder="unet/nesting_level_1")
3812
- elif nesting_level == 2:
3813
- unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3814
- subfolder="unet/nesting_level_2")
3815
-
3816
  self.register_modules(
3817
  text_encoder=text_encoder,
3818
  tokenizer=tokenizer,
@@ -3825,38 +3842,6 @@ class MatryoshkaPipeline(
3825
  scheduler.scales = unet.nest_ratio + [1]
3826
  self.image_processor = VaeImageProcessor(do_resize=False)
3827
 
3828
- def _encode_prompt(
3829
- self,
3830
- prompt,
3831
- device,
3832
- num_images_per_prompt,
3833
- do_classifier_free_guidance,
3834
- negative_prompt=None,
3835
- prompt_embeds: Optional[torch.Tensor] = None,
3836
- negative_prompt_embeds: Optional[torch.Tensor] = None,
3837
- lora_scale: Optional[float] = None,
3838
- **kwargs,
3839
- ):
3840
- deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
3841
- deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
3842
-
3843
- prompt_embeds_tuple = self.encode_prompt(
3844
- prompt=prompt,
3845
- device=device,
3846
- num_images_per_prompt=num_images_per_prompt,
3847
- do_classifier_free_guidance=do_classifier_free_guidance,
3848
- negative_prompt=negative_prompt,
3849
- prompt_embeds=prompt_embeds,
3850
- negative_prompt_embeds=negative_prompt_embeds,
3851
- lora_scale=lora_scale,
3852
- **kwargs,
3853
- )
3854
-
3855
- # concatenate for backwards comp
3856
- prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
3857
-
3858
- return prompt_embeds
3859
-
3860
  def encode_prompt(
3861
  self,
3862
  prompt,
@@ -3935,7 +3920,7 @@ class MatryoshkaPipeline(
3935
  untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
3936
  )
3937
  logger.warning(
3938
- "The following part of your input was truncated because CLIP can only handle sequences up to"
3939
  f" {self.tokenizer.model_max_length} tokens: {removed_text}"
3940
  )
3941
 
@@ -4414,8 +4399,8 @@ class MatryoshkaPipeline(
4414
  Examples:
4415
 
4416
  Returns:
4417
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
4418
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
4419
  otherwise a `tuple` is returned where the first element is a list with the generated images and the
4420
  second element is a list of `bool`s indicating whether the corresponding generated image contains
4421
  "not-safe-for-work" (nsfw) content.
@@ -4522,10 +4507,11 @@ class MatryoshkaPipeline(
4522
  timesteps, num_inference_steps = retrieve_timesteps(
4523
  self.scheduler, num_inference_steps, device, timesteps, sigmas
4524
  )
4525
- timesteps = timesteps[:-1]
4526
  else:
4527
  timesteps = self.scheduler.timesteps
4528
 
 
 
4529
  # 5. Prepare latent variables
4530
  num_channels_latents = self.unet.config.in_channels
4531
  latents = self.prepare_latents(
@@ -4637,9 +4623,11 @@ class MatryoshkaPipeline(
4637
  image = latents
4638
 
4639
  if self.scheduler.scales is not None:
4640
- image = image[0]
4641
-
4642
- image = self.image_processor.postprocess(image, output_type=output_type)
 
 
4643
 
4644
  # Offload all models
4645
  self.maybe_free_model_hooks()
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Based on [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111).
16
+ # Authors: Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly
17
+ # Code: https://github.com/apple/ml-mdm with MIT license
18
+ #
19
+ # Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
20
+
21
 
22
  import inspect
23
  import math
 
632
  # 4. Clip or threshold "predicted x_0"
633
  if self.config.thresholding:
634
  if len(model_output) > 1:
635
+ pred_original_sample = [self._threshold_sample(p_o_s * scale) / scale for p_o_s, scale in zip(pred_original_sample, self.scales)]
636
  else:
637
  pred_original_sample = self._threshold_sample(pred_original_sample)
638
  elif self.config.clip_sample:
639
  if len(model_output) > 1:
640
  pred_original_sample = [
641
+ (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
642
+ for p_o_s, scale in zip(pred_original_sample, self.scales)
643
  ]
644
  else:
645
  pred_original_sample = pred_original_sample.clamp(
 
3726
  FromSingleFileMixin,
3727
  ):
3728
  r"""
3729
+ Pipeline for text-to-image generation using Matryoshka Diffusion Models.
3730
 
3731
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
3732
  implemented for all pipelines (downloading, saving, running on a particular device, etc.).
 
3739
  - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
3740
 
3741
  Args:
3742
+ text_encoder ([`~transformers.T5EncoderModel`]):
3743
+ Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)).
3744
+ tokenizer ([`~transformers.T5Tokenizer`]):
3745
+ A `T5Tokenizer` to tokenize text.
3746
  unet ([`MatryoshkaUNet2DConditionModel`]):
3747
  A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents.
3748
  scheduler ([`SchedulerMixin`]):
3749
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
3750
+ [`MatryoshkaDDIMScheduler`] and other schedulers with proper modifications, see an example usage in README.md.
3751
+ feature_extractor ([`~transformers.<AnImageProcessor>`]):
3752
+ A `AnImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
 
 
 
 
3753
  """
3754
 
3755
  model_cpu_offload_seq = "text_encoder->image_encoder->unet"
 
3770
  ):
3771
  super().__init__()
3772
 
3773
+ if nesting_level == 0:
3774
+ unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3775
+ subfolder="unet/nesting_level_0")
3776
+ elif nesting_level == 1:
3777
+ unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3778
+ subfolder="unet/nesting_level_1")
3779
+ elif nesting_level == 2:
3780
+ unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3781
+ subfolder="unet/nesting_level_2")
3782
+ else:
3783
+ raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3784
+
3785
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3786
  deprecation_message = (
3787
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
 
3809
  new_config["clip_sample"] = False
3810
  scheduler._internal_dict = FrozenDict(new_config)
3811
 
3812
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
3813
+ version.parse(unet.config._diffusers_version).base_version
3814
  ) < version.parse("0.9.0.dev0")
3815
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
3816
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
3817
  deprecation_message = (
3818
  "The configuration file of the unet has set the default `sample_size` to smaller than"
 
3830
  new_config["sample_size"] = 64
3831
  unet._internal_dict = FrozenDict(new_config)
3832
 
 
 
 
 
 
 
 
 
 
 
3833
  self.register_modules(
3834
  text_encoder=text_encoder,
3835
  tokenizer=tokenizer,
 
3842
  scheduler.scales = unet.nest_ratio + [1]
3843
  self.image_processor = VaeImageProcessor(do_resize=False)
3844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3845
  def encode_prompt(
3846
  self,
3847
  prompt,
 
3920
  untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
3921
  )
3922
  logger.warning(
3923
+ "The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to"
3924
  f" {self.tokenizer.model_max_length} tokens: {removed_text}"
3925
  )
3926
 
 
4399
  Examples:
4400
 
4401
  Returns:
4402
+ [`~MatryoshkaPipelineOutput`] or `tuple`:
4403
+ If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned,
4404
  otherwise a `tuple` is returned where the first element is a list with the generated images and the
4405
  second element is a list of `bool`s indicating whether the corresponding generated image contains
4406
  "not-safe-for-work" (nsfw) content.
 
4507
  timesteps, num_inference_steps = retrieve_timesteps(
4508
  self.scheduler, num_inference_steps, device, timesteps, sigmas
4509
  )
 
4510
  else:
4511
  timesteps = self.scheduler.timesteps
4512
 
4513
+ timesteps = timesteps[:-1]
4514
+
4515
  # 5. Prepare latent variables
4516
  num_channels_latents = self.unet.config.in_channels
4517
  latents = self.prepare_latents(
 
4623
  image = latents
4624
 
4625
  if self.scheduler.scales is not None:
4626
+ for i in range(len(image)):
4627
+ image[i] = image[i] * self.scheduler.scales[i]
4628
+ image[i] = self.image_processor.postprocess(image[i], output_type=output_type)
4629
+ else:
4630
+ image = self.image_processor.postprocess(image, output_type=output_type)
4631
 
4632
  # Offload all models
4633
  self.maybe_free_model_hooks()