tolgacangoz
commited on
Upload matryoshka.py
Browse files- matryoshka.py +7 -2
matryoshka.py
CHANGED
@@ -3830,6 +3830,9 @@ class MatryoshkaPipeline(
|
|
3830 |
new_config["sample_size"] = 64
|
3831 |
unet._internal_dict = FrozenDict(new_config)
|
3832 |
|
|
|
|
|
|
|
3833 |
self.register_modules(
|
3834 |
text_encoder=text_encoder,
|
3835 |
tokenizer=tokenizer,
|
@@ -3839,12 +3842,12 @@ class MatryoshkaPipeline(
|
|
3839 |
image_encoder=image_encoder,
|
3840 |
)
|
3841 |
self.register_to_config(nesting_level=nesting_level)
|
3842 |
-
if hasattr(unet, "nest_ratio"):
|
3843 |
-
scheduler.scales = unet.nest_ratio + [1]
|
3844 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3845 |
|
3846 |
def change_nesting_level(self, nesting_level: int):
|
3847 |
if nesting_level == 0:
|
|
|
|
|
3848 |
self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3849 |
subfolder="unet/nesting_level_0").to(self.device)
|
3850 |
self.config.nesting_level = 0
|
@@ -3852,10 +3855,12 @@ class MatryoshkaPipeline(
|
|
3852 |
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3853 |
subfolder="unet/nesting_level_1").to(self.device)
|
3854 |
self.config.nesting_level = 1
|
|
|
3855 |
elif nesting_level == 2:
|
3856 |
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3857 |
subfolder="unet/nesting_level_2").to(self.device)
|
3858 |
self.config.nesting_level = 2
|
|
|
3859 |
else:
|
3860 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3861 |
|
|
|
3830 |
new_config["sample_size"] = 64
|
3831 |
unet._internal_dict = FrozenDict(new_config)
|
3832 |
|
3833 |
+
if hasattr(unet, "nest_ratio"):
|
3834 |
+
scheduler.scales = unet.nest_ratio + [1]
|
3835 |
+
|
3836 |
self.register_modules(
|
3837 |
text_encoder=text_encoder,
|
3838 |
tokenizer=tokenizer,
|
|
|
3842 |
image_encoder=image_encoder,
|
3843 |
)
|
3844 |
self.register_to_config(nesting_level=nesting_level)
|
|
|
|
|
3845 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3846 |
|
3847 |
def change_nesting_level(self, nesting_level: int):
|
3848 |
if nesting_level == 0:
|
3849 |
+
if hasattr(self.unet, "nest_ratio"):
|
3850 |
+
self.scheduler.scales = None
|
3851 |
self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3852 |
subfolder="unet/nesting_level_0").to(self.device)
|
3853 |
self.config.nesting_level = 0
|
|
|
3855 |
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3856 |
subfolder="unet/nesting_level_1").to(self.device)
|
3857 |
self.config.nesting_level = 1
|
3858 |
+
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3859 |
elif nesting_level == 2:
|
3860 |
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3861 |
subfolder="unet/nesting_level_2").to(self.device)
|
3862 |
self.config.nesting_level = 2
|
3863 |
+
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3864 |
else:
|
3865 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3866 |
|