tolgacangoz
commited on
Upload matryoshka.py
Browse files- matryoshka.py +17 -2
matryoshka.py
CHANGED
@@ -3782,8 +3782,6 @@ class MatryoshkaPipeline(
|
|
3782 |
else:
|
3783 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3784 |
|
3785 |
-
unet = unet.to(self.device)
|
3786 |
-
|
3787 |
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
3788 |
deprecation_message = (
|
3789 |
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
@@ -3840,10 +3838,27 @@ class MatryoshkaPipeline(
|
|
3840 |
feature_extractor=feature_extractor,
|
3841 |
image_encoder=image_encoder,
|
3842 |
)
|
|
|
3843 |
if hasattr(unet, "nest_ratio"):
|
3844 |
scheduler.scales = unet.nest_ratio + [1]
|
3845 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3846 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3847 |
def encode_prompt(
|
3848 |
self,
|
3849 |
prompt,
|
|
|
3782 |
else:
|
3783 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3784 |
|
|
|
|
|
3785 |
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
3786 |
deprecation_message = (
|
3787 |
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
|
|
3838 |
feature_extractor=feature_extractor,
|
3839 |
image_encoder=image_encoder,
|
3840 |
)
|
3841 |
+
self.register_to_config(nesting_level=nesting_level)
|
3842 |
if hasattr(unet, "nest_ratio"):
|
3843 |
scheduler.scales = unet.nest_ratio + [1]
|
3844 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3845 |
|
3846 |
+
def change_nesting_level(self, nesting_level: int):
|
3847 |
+
if nesting_level == 0:
|
3848 |
+
self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3849 |
+
subfolder="unet/nesting_level_0")
|
3850 |
+
self.config.nesting_level = 0
|
3851 |
+
elif nesting_level == 1:
|
3852 |
+
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3853 |
+
subfolder="unet/nesting_level_1")
|
3854 |
+
self.config.nesting_level = 1
|
3855 |
+
elif nesting_level == 2:
|
3856 |
+
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3857 |
+
subfolder="unet/nesting_level_2")
|
3858 |
+
self.config.nesting_level = 2
|
3859 |
+
else:
|
3860 |
+
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3861 |
+
|
3862 |
def encode_prompt(
|
3863 |
self,
|
3864 |
prompt,
|