tolgacangoz
commited on
Upload matryoshka.py
Browse files- matryoshka.py +3 -1
matryoshka.py
CHANGED
@@ -3762,7 +3762,7 @@ class MatryoshkaPipeline(
|
|
3762 |
text_encoder: T5EncoderModel,
|
3763 |
tokenizer: T5TokenizerFast,
|
3764 |
scheduler: MatryoshkaDDIMScheduler,
|
3765 |
-
unet: MatryoshkaUNet2DConditionModel,
|
3766 |
feature_extractor: CLIPImageProcessor = None,
|
3767 |
image_encoder: CLIPVisionModelWithProjection = None,
|
3768 |
trust_remote_code: bool = False,
|
@@ -3782,6 +3782,8 @@ class MatryoshkaPipeline(
|
|
3782 |
else:
|
3783 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3784 |
|
|
|
|
|
3785 |
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
3786 |
deprecation_message = (
|
3787 |
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
|
|
3762 |
text_encoder: T5EncoderModel,
|
3763 |
tokenizer: T5TokenizerFast,
|
3764 |
scheduler: MatryoshkaDDIMScheduler,
|
3765 |
+
unet: MatryoshkaUNet2DConditionModel = None,
|
3766 |
feature_extractor: CLIPImageProcessor = None,
|
3767 |
image_encoder: CLIPVisionModelWithProjection = None,
|
3768 |
trust_remote_code: bool = False,
|
|
|
3782 |
else:
|
3783 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3784 |
|
3785 |
+
unet = unet.to(self.device)
|
3786 |
+
|
3787 |
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
3788 |
deprecation_message = (
|
3789 |
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|