tolgacangoz commited on
Commit
57f9d0f
·
verified ·
1 Parent(s): e30e6fa

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +12 -1
matryoshka.py CHANGED
@@ -3746,11 +3746,12 @@ class MatryoshkaPipeline(
3746
  self,
3747
  text_encoder: T5EncoderModel,
3748
  tokenizer: T5TokenizerFast,
3749
- unet: MatryoshkaUNet2DConditionModel,
3750
  scheduler: MatryoshkaDDIMScheduler,
 
3751
  feature_extractor: CLIPImageProcessor = None,
3752
  image_encoder: CLIPVisionModelWithProjection = None,
3753
  trust_remote_code: bool = False,
 
3754
  ):
3755
  super().__init__()
3756
 
@@ -3802,6 +3803,16 @@ class MatryoshkaPipeline(
3802
  new_config["sample_size"] = 64
3803
  unet._internal_dict = FrozenDict(new_config)
3804
 
 
 
 
 
 
 
 
 
 
 
3805
  self.register_modules(
3806
  text_encoder=text_encoder,
3807
  tokenizer=tokenizer,
 
3746
  self,
3747
  text_encoder: T5EncoderModel,
3748
  tokenizer: T5TokenizerFast,
 
3749
  scheduler: MatryoshkaDDIMScheduler,
3750
+ unet: MatryoshkaUNet2DConditionModel = None,
3751
  feature_extractor: CLIPImageProcessor = None,
3752
  image_encoder: CLIPVisionModelWithProjection = None,
3753
  trust_remote_code: bool = False,
3754
+ nesting_level: int = 0,
3755
  ):
3756
  super().__init__()
3757
 
 
3803
  new_config["sample_size"] = 64
3804
  unet._internal_dict = FrozenDict(new_config)
3805
 
3806
+ if nesting_level == 0:
3807
+ unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3808
+ subfolder="unet/nesting_level_0")
3809
+ elif nesting_level == 1:
3810
+ unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3811
+ subfolder="unet/nesting_level_1")
3812
+ elif nesting_level == 2:
3813
+ unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3814
+ subfolder="unet/nesting_level_2")
3815
+
3816
  self.register_modules(
3817
  text_encoder=text_encoder,
3818
  tokenizer=tokenizer,