tolgacangoz commited on
Commit
54b3ffb
·
verified ·
1 Parent(s): 42aafad

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +3 -1
matryoshka.py CHANGED
@@ -3762,7 +3762,7 @@ class MatryoshkaPipeline(
3762
  text_encoder: T5EncoderModel,
3763
  tokenizer: T5TokenizerFast,
3764
  scheduler: MatryoshkaDDIMScheduler,
3765
- unet: MatryoshkaUNet2DConditionModel,
3766
  feature_extractor: CLIPImageProcessor = None,
3767
  image_encoder: CLIPVisionModelWithProjection = None,
3768
  trust_remote_code: bool = False,
@@ -3782,6 +3782,8 @@ class MatryoshkaPipeline(
3782
  else:
3783
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3784
 
 
 
3785
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3786
  deprecation_message = (
3787
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
 
3762
  text_encoder: T5EncoderModel,
3763
  tokenizer: T5TokenizerFast,
3764
  scheduler: MatryoshkaDDIMScheduler,
3765
+ unet: MatryoshkaUNet2DConditionModel = None,
3766
  feature_extractor: CLIPImageProcessor = None,
3767
  image_encoder: CLIPVisionModelWithProjection = None,
3768
  trust_remote_code: bool = False,
 
3782
  else:
3783
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3784
 
3785
+ unet = unet.to(self.device)
3786
+
3787
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3788
  deprecation_message = (
3789
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"