tolgacangoz
commited on
Upload matryoshka.py
Browse files- matryoshka.py +12 -1
matryoshka.py
CHANGED
@@ -3746,11 +3746,12 @@ class MatryoshkaPipeline(
|
|
3746 |
self,
|
3747 |
text_encoder: T5EncoderModel,
|
3748 |
tokenizer: T5TokenizerFast,
|
3749 |
-
unet: MatryoshkaUNet2DConditionModel,
|
3750 |
scheduler: MatryoshkaDDIMScheduler,
|
|
|
3751 |
feature_extractor: CLIPImageProcessor = None,
|
3752 |
image_encoder: CLIPVisionModelWithProjection = None,
|
3753 |
trust_remote_code: bool = False,
|
|
|
3754 |
):
|
3755 |
super().__init__()
|
3756 |
|
@@ -3802,6 +3803,16 @@ class MatryoshkaPipeline(
|
|
3802 |
new_config["sample_size"] = 64
|
3803 |
unet._internal_dict = FrozenDict(new_config)
|
3804 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3805 |
self.register_modules(
|
3806 |
text_encoder=text_encoder,
|
3807 |
tokenizer=tokenizer,
|
|
|
3746 |
self,
|
3747 |
text_encoder: T5EncoderModel,
|
3748 |
tokenizer: T5TokenizerFast,
|
|
|
3749 |
scheduler: MatryoshkaDDIMScheduler,
|
3750 |
+
unet: MatryoshkaUNet2DConditionModel = None,
|
3751 |
feature_extractor: CLIPImageProcessor = None,
|
3752 |
image_encoder: CLIPVisionModelWithProjection = None,
|
3753 |
trust_remote_code: bool = False,
|
3754 |
+
nesting_level: int = 0,
|
3755 |
):
|
3756 |
super().__init__()
|
3757 |
|
|
|
3803 |
new_config["sample_size"] = 64
|
3804 |
unet._internal_dict = FrozenDict(new_config)
|
3805 |
|
3806 |
+
if nesting_level == 0:
|
3807 |
+
unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3808 |
+
subfolder="unet/nesting_level_0")
|
3809 |
+
elif nesting_level == 1:
|
3810 |
+
unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3811 |
+
subfolder="unet/nesting_level_1")
|
3812 |
+
elif nesting_level == 2:
|
3813 |
+
unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3814 |
+
subfolder="unet/nesting_level_2")
|
3815 |
+
|
3816 |
self.register_modules(
|
3817 |
text_encoder=text_encoder,
|
3818 |
tokenizer=tokenizer,
|