tolgacangoz
committed on
Upload matryoshka.py
Browse files- matryoshka.py +39 -26
matryoshka.py
CHANGED
@@ -19,8 +19,8 @@
|
|
19 |
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
20 |
|
21 |
|
22 |
-
import inspect
|
23 |
import gc
|
|
|
24 |
import math
|
25 |
from dataclasses import dataclass
|
26 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
@@ -633,7 +633,10 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|
633 |
# 4. Clip or threshold "predicted x_0"
|
634 |
if self.config.thresholding:
|
635 |
if len(model_output) > 1:
|
636 |
-
pred_original_sample = [
|
|
|
|
|
|
|
637 |
else:
|
638 |
pred_original_sample = self._threshold_sample(pred_original_sample)
|
639 |
elif self.config.clip_sample:
|
@@ -3777,14 +3780,17 @@ class MatryoshkaPipeline(
|
|
3777 |
super().__init__()
|
3778 |
|
3779 |
if nesting_level == 0:
|
3780 |
-
unet = MatryoshkaUNet2DConditionModel.from_pretrained(
|
3781 |
-
|
|
|
3782 |
elif nesting_level == 1:
|
3783 |
-
unet = NestedUNet2DConditionModel.from_pretrained(
|
3784 |
-
|
|
|
3785 |
elif nesting_level == 2:
|
3786 |
-
unet = NestedUNet2DConditionModel.from_pretrained(
|
3787 |
-
|
|
|
3788 |
else:
|
3789 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3790 |
|
@@ -3854,17 +3860,20 @@ class MatryoshkaPipeline(
|
|
3854 |
if nesting_level == 0:
|
3855 |
if hasattr(self.unet, "nest_ratio"):
|
3856 |
self.scheduler.scales = None
|
3857 |
-
self.unet = MatryoshkaUNet2DConditionModel.from_pretrained(
|
3858 |
-
|
|
|
3859 |
self.config.nesting_level = 0
|
3860 |
elif nesting_level == 1:
|
3861 |
-
self.unet = NestedUNet2DConditionModel.from_pretrained(
|
3862 |
-
|
|
|
3863 |
self.config.nesting_level = 1
|
3864 |
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3865 |
elif nesting_level == 2:
|
3866 |
-
self.unet = NestedUNet2DConditionModel.from_pretrained(
|
3867 |
-
|
|
|
3868 |
self.config.nesting_level = 2
|
3869 |
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3870 |
else:
|
@@ -4030,7 +4039,9 @@ class MatryoshkaPipeline(
|
|
4030 |
prompt_attention_mask = torch.cat(
|
4031 |
[
|
4032 |
prompt_attention_mask,
|
4033 |
-
torch.zeros(
|
|
|
|
|
4034 |
],
|
4035 |
dim=1,
|
4036 |
)
|
@@ -4042,7 +4053,12 @@ class MatryoshkaPipeline(
|
|
4042 |
negative_prompt_attention_mask = torch.cat(
|
4043 |
[
|
4044 |
negative_prompt_attention_mask,
|
4045 |
-
torch.zeros(
|
|
|
|
|
|
|
|
|
|
|
4046 |
],
|
4047 |
dim=1,
|
4048 |
)
|
@@ -4533,7 +4549,6 @@ class MatryoshkaPipeline(
|
|
4533 |
self.do_classifier_free_guidance,
|
4534 |
)
|
4535 |
|
4536 |
-
|
4537 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
4538 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
4539 |
# 4. Prepare timesteps
|
@@ -4656,17 +4671,14 @@ class MatryoshkaPipeline(
|
|
4656 |
image = latents
|
4657 |
|
4658 |
if self.scheduler.scales is not None:
|
4659 |
-
scales = [
|
4660 |
-
image[i].size(-1) / image[-1].size(-1)
|
4661 |
-
for i in range(len(image))
|
4662 |
-
]
|
4663 |
for i, (img, scale) in enumerate(zip(image, scales)):
|
4664 |
img = torch.clip(img * scale, -1, 1)
|
4665 |
-
img = torch.clamp(img * 0.5 + 0.5, min=0, max=1).cpu()
|
4666 |
-
img = img.squeeze(0).permute(1, 2, 0).numpy()
|
4667 |
# img = self.image_processor.pt_to_numpy(img)
|
4668 |
-
image[i] = numpy_to_pil(img)[0]
|
4669 |
-
|
4670 |
else:
|
4671 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
4672 |
|
@@ -4678,6 +4690,7 @@ class MatryoshkaPipeline(
|
|
4678 |
|
4679 |
return MatryoshkaPipelineOutput(images=image)
|
4680 |
|
|
|
4681 |
def numpy_to_pil(images: np.ndarray) -> List[Image.Image]:
|
4682 |
"""
|
4683 |
Convert a numpy image or a batch of images to a PIL image.
|
@@ -4691,4 +4704,4 @@ def numpy_to_pil(images: np.ndarray) -> List[Image.Image]:
|
|
4691 |
else:
|
4692 |
pil_images = [Image.fromarray(image) for image in images]
|
4693 |
|
4694 |
-
return pil_images
|
|
|
19 |
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
20 |
|
21 |
|
|
|
22 |
import gc
|
23 |
+
import inspect
|
24 |
import math
|
25 |
from dataclasses import dataclass
|
26 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
633 |
# 4. Clip or threshold "predicted x_0"
|
634 |
if self.config.thresholding:
|
635 |
if len(model_output) > 1:
|
636 |
+
pred_original_sample = [
|
637 |
+
self._threshold_sample(p_o_s * scale) / scale
|
638 |
+
for p_o_s, scale in zip(pred_original_sample, self.scales)
|
639 |
+
]
|
640 |
else:
|
641 |
pred_original_sample = self._threshold_sample(pred_original_sample)
|
642 |
elif self.config.clip_sample:
|
|
|
3780 |
super().__init__()
|
3781 |
|
3782 |
if nesting_level == 0:
|
3783 |
+
unet = MatryoshkaUNet2DConditionModel.from_pretrained(
|
3784 |
+
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_0"
|
3785 |
+
)
|
3786 |
elif nesting_level == 1:
|
3787 |
+
unet = NestedUNet2DConditionModel.from_pretrained(
|
3788 |
+
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_1"
|
3789 |
+
)
|
3790 |
elif nesting_level == 2:
|
3791 |
+
unet = NestedUNet2DConditionModel.from_pretrained(
|
3792 |
+
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
|
3793 |
+
)
|
3794 |
else:
|
3795 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3796 |
|
|
|
3860 |
if nesting_level == 0:
|
3861 |
if hasattr(self.unet, "nest_ratio"):
|
3862 |
self.scheduler.scales = None
|
3863 |
+
self.unet = MatryoshkaUNet2DConditionModel.from_pretrained(
|
3864 |
+
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_0"
|
3865 |
+
).to(self.device)
|
3866 |
self.config.nesting_level = 0
|
3867 |
elif nesting_level == 1:
|
3868 |
+
self.unet = NestedUNet2DConditionModel.from_pretrained(
|
3869 |
+
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_1"
|
3870 |
+
).to(self.device)
|
3871 |
self.config.nesting_level = 1
|
3872 |
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3873 |
elif nesting_level == 2:
|
3874 |
+
self.unet = NestedUNet2DConditionModel.from_pretrained(
|
3875 |
+
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
|
3876 |
+
).to(self.device)
|
3877 |
self.config.nesting_level = 2
|
3878 |
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3879 |
else:
|
|
|
4039 |
prompt_attention_mask = torch.cat(
|
4040 |
[
|
4041 |
prompt_attention_mask,
|
4042 |
+
torch.zeros(
|
4043 |
+
batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device
|
4044 |
+
),
|
4045 |
],
|
4046 |
dim=1,
|
4047 |
)
|
|
|
4053 |
negative_prompt_attention_mask = torch.cat(
|
4054 |
[
|
4055 |
negative_prompt_attention_mask,
|
4056 |
+
torch.zeros(
|
4057 |
+
batch_size,
|
4058 |
+
max_len - len(negative_prompt_attention_mask[0]),
|
4059 |
+
dtype=torch.long,
|
4060 |
+
device=device,
|
4061 |
+
),
|
4062 |
],
|
4063 |
dim=1,
|
4064 |
)
|
|
|
4549 |
self.do_classifier_free_guidance,
|
4550 |
)
|
4551 |
|
|
|
4552 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
4553 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
4554 |
# 4. Prepare timesteps
|
|
|
4671 |
image = latents
|
4672 |
|
4673 |
if self.scheduler.scales is not None:
|
4674 |
+
scales = [image[i].size(-1) / image[-1].size(-1) for i in range(len(image))]
|
|
|
|
|
|
|
4675 |
for i, (img, scale) in enumerate(zip(image, scales)):
|
4676 |
img = torch.clip(img * scale, -1, 1)
|
4677 |
+
# img = torch.clamp(img * 0.5 + 0.5, min=0, max=1).cpu()
|
4678 |
+
# img = img.squeeze(0).permute(1, 2, 0).numpy()
|
4679 |
# img = self.image_processor.pt_to_numpy(img)
|
4680 |
+
# image[i] = numpy_to_pil(img)[0]
|
4681 |
+
image[i] = self.image_processor.postprocess(img, output_type=output_type)[0]
|
4682 |
else:
|
4683 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
4684 |
|
|
|
4690 |
|
4691 |
return MatryoshkaPipelineOutput(images=image)
|
4692 |
|
4693 |
+
|
4694 |
def numpy_to_pil(images: np.ndarray) -> List[Image.Image]:
|
4695 |
"""
|
4696 |
Convert a numpy image or a batch of images to a PIL image.
|
|
|
4704 |
else:
|
4705 |
pil_images = [Image.fromarray(image) for image in images]
|
4706 |
|
4707 |
+
return pil_images
|