tolgacangoz committed
Upload matryoshka.py

matryoshka.py CHANGED (+56 -68)
@@ -1,4 +1,23 @@
-#
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111).
+# Authors: Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly
+# Code: https://github.com/apple/ml-mdm with MIT license
+#
+# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
+
 
 import inspect
 import math
@@ -613,14 +632,14 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
         # 4. Clip or threshold "predicted x_0"
         if self.config.thresholding:
             if len(model_output) > 1:
-                pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
+                pred_original_sample = [self._threshold_sample(p_o_s * scale) / scale for p_o_s, scale in zip(pred_original_sample, self.scales)]
             else:
                 pred_original_sample = self._threshold_sample(pred_original_sample)
         elif self.config.clip_sample:
             if len(model_output) > 1:
                 pred_original_sample = [
-                    p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
-                    for p_o_s in pred_original_sample
+                    (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
+                    for p_o_s, scale in zip(pred_original_sample, self.scales)
                 ]
             else:
                 pred_original_sample = pred_original_sample.clamp(
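Note on the hunk above: in nested sampling each resolution's prediction is stored divided by its entry in `scheduler.scales`, so thresholding/clipping must be applied in unit scale and then undone. A minimal standalone sketch of the clamping pattern (the helper name `clip_per_scale` is illustrative, not part of the diff):

```python
import torch

def clip_per_scale(preds, scales, clip_range=1.0):
    # Map each prediction back to unit scale, clamp there, then rescale down.
    return [
        (p * s).clamp(-clip_range, clip_range) / s
        for p, s in zip(preds, scales)
    ]

preds = [torch.randn(1, 3, 64, 64), torch.randn(1, 3, 32, 32)]
clipped = clip_per_scale(preds, scales=[2.0, 1.0])
print([round(c.abs().max().item(), 3) for c in clipped])  # bounded by clip_range / scale
```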
@@ -3707,7 +3726,7 @@ class MatryoshkaPipeline(
     FromSingleFileMixin,
 ):
     r"""
-    Pipeline for text-to-image generation using
+    Pipeline for text-to-image generation using Matryoshka Diffusion Models.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -3720,21 +3739,17 @@ class MatryoshkaPipeline(
     - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
-        text_encoder ([`~transformers.
-            Frozen text-encoder ([
-        tokenizer ([`~transformers.
-            A `
+        text_encoder ([`~transformers.T5EncoderModel`]):
+            Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)).
+        tokenizer ([`~transformers.T5Tokenizer`]):
+            A `T5Tokenizer` to tokenize text.
         unet ([`MatryoshkaUNet2DConditionModel`]):
             A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`
-
-
-            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
-            about a model's potential harms.
-        feature_extractor ([`~transformers.CLIPImageProcessor`]):
-            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+            [`MatryoshkaDDIMScheduler`] and other schedulers with proper modifications, see an example usage in README.md.
+        feature_extractor ([`~transformers.<AnImageProcessor>`]):
+            A `AnImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
 
     model_cpu_offload_seq = "text_encoder->image_encoder->unet"
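For reference, the text components named in the updated docstring can be loaded as below; the repo id mirrors the docstring's flan-t5-xl link and is not otherwise pinned by this diff:

```python
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl")
```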
@@ -3755,6 +3770,18 @@ class MatryoshkaPipeline(
     ):
         super().__init__()
 
+        if nesting_level == 0:
+            unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
+                                                                  subfolder="unet/nesting_level_0")
+        elif nesting_level == 1:
+            unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
+                                                              subfolder="unet/nesting_level_1")
+        elif nesting_level == 2:
+            unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
+                                                              subfolder="unet/nesting_level_2")
+        else:
+            raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
+
         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
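With the dispatch moved to the top of `__init__`, the matching (Nested)UNet checkpoint is fetched before any config checks run, and an unsupported `nesting_level` now raises a `ValueError` instead of silently keeping the passed `unet`. A hedged usage sketch; loading this file as a community pipeline named "matryoshka" is an assumption based on the file name, not something this diff states:

```python
from diffusers import DiffusionPipeline

# nesting_level selects the unet/nesting_level_{0,1,2} subfolder.
pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/matryoshka-diffusion-models",
    custom_pipeline="matryoshka",  # assumed community-pipeline name
    nesting_level=1,
)
```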
@@ -3782,10 +3809,10 @@ class MatryoshkaPipeline(
             new_config["clip_sample"] = False
             scheduler._internal_dict = FrozenDict(new_config)
 
-        is_unet_version_less_0_9_0 = hasattr(unet
-            version.parse(unet
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
         ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
             deprecation_message = (
                 "The configuration file of the unet has set the default `sample_size` to smaller than"
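The repaired check reads the version from `unet.config`. `base_version` strips dev/post suffixes, so a value like "0.8.0.dev0" still gates correctly against the "0.9.0.dev0" threshold; a quick sketch:

```python
from packaging import version

cfg_version = "0.8.0.dev0"  # stand-in for unet.config._diffusers_version
is_less_0_9_0 = version.parse(
    version.parse(cfg_version).base_version  # "0.8.0"
) < version.parse("0.9.0.dev0")
print(is_less_0_9_0)  # True
```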
@@ -3803,16 +3830,6 @@ class MatryoshkaPipeline(
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
 
-        if nesting_level == 0:
-            unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
-                                                                  subfolder="unet/nesting_level_0")
-        elif nesting_level == 1:
-            unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
-                                                              subfolder="unet/nesting_level_1")
-        elif nesting_level == 2:
-            unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
-                                                              subfolder="unet/nesting_level_2")
-
         self.register_modules(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
@@ -3825,38 +3842,6 @@ class MatryoshkaPipeline(
         scheduler.scales = unet.nest_ratio + [1]
         self.image_processor = VaeImageProcessor(do_resize=False)
 
-    def _encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        lora_scale: Optional[float] = None,
-        **kwargs,
-    ):
-        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
-        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
-
-        prompt_embeds_tuple = self.encode_prompt(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            negative_prompt=negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=lora_scale,
-            **kwargs,
-        )
-
-        # concatenate for backwards comp
-        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
-
-        return prompt_embeds
-
     def encode_prompt(
         self,
         prompt,
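Callers still relying on the removed `_encode_prompt()` shim can reproduce its concatenated layout from `encode_prompt()` directly (negative embeddings first, as in the removed `torch.cat`). A sketch, reusing the hypothetical `pipe` from the earlier example:

```python
import torch

prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a photo of an astronaut riding a horse",
    device="cpu",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
# The deprecated shim returned torch.cat([negative, positive]).
legacy_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
```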
@@ -3935,7 +3920,7 @@ class MatryoshkaPipeline(
                 untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
             )
             logger.warning(
-                "The following part of your input was truncated because
+                "The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
 
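The reworded warning fires when the tokenizer drops tokens. The surrounding detection pattern (compare a truncated encoding against a `longest`-padded one) looks roughly like this standalone sketch:

```python
from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("google/flan-t5-xl")
text = "a very long prompt " * 200  # exceeds model_max_length (512)
text_input_ids = tok(
    text, padding="max_length", max_length=tok.model_max_length,
    truncation=True, return_tensors="pt",
).input_ids
untruncated_ids = tok(text, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] > text_input_ids.shape[-1]:
    removed_text = tok.batch_decode(untruncated_ids[:, tok.model_max_length - 1 : -1])
    print(f"Truncated: {removed_text}")
```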
@@ -4414,8 +4399,8 @@ class MatryoshkaPipeline(
         Examples:
 
         Returns:
-            [`~
-            If `return_dict` is `True`, [`~
+            [`~MatryoshkaPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned,
             otherwise a `tuple` is returned where the first element is a list with the generated images and the
             second element is a list of `bool`s indicating whether the corresponding generated image contains
             "not-safe-for-work" (nsfw) content.
@@ -4522,10 +4507,11 @@ class MatryoshkaPipeline(
             timesteps, num_inference_steps = retrieve_timesteps(
                 self.scheduler, num_inference_steps, device, timesteps, sigmas
             )
-            timesteps = timesteps[:-1]
         else:
             timesteps = self.scheduler.timesteps
 
+        timesteps = timesteps[:-1]
+
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
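Moving the trim below the `if`/`else` means the final timestep is dropped on both paths, not only when custom timesteps are supplied. A tiny illustration:

```python
import torch

timesteps = torch.linspace(999, 0, 5).long()  # stand-in for scheduler output
timesteps = timesteps[:-1]                    # the last step is always dropped now
print(timesteps.tolist())                     # [999, 749, 499, 249]
```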
@@ -4637,9 +4623,11 @@ class MatryoshkaPipeline(
             image = latents
 
         if self.scheduler.scales is not None:
-
-
-
+            for i in range(len(image)):
+                image[i] = image[i] * self.scheduler.scales[i]
+                image[i] = self.image_processor.postprocess(image[i], output_type=output_type)
+        else:
+            image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
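The new branch postprocesses each nested output after undoing its scale. On dummy data the loop behaves like this sketch, where `scales` stands in for `scheduler.scales`:

```python
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_resize=False)
scales = [2.0, 1.0]
image = [torch.rand(1, 3, 64, 64) / 2.0, torch.rand(1, 3, 32, 32)]
for i in range(len(image)):
    image[i] = image[i] * scales[i]  # undo the per-level downscaling
    image[i] = processor.postprocess(image[i], output_type="pil")
```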