tolgacangoz commited on
Commit
7f08644
·
verified ·
1 Parent(s): 05fa96d

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +114 -82
matryoshka.py CHANGED
@@ -664,9 +664,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
664
  variance_noise = []
665
  for m_o in model_output:
666
  variance_noise.append(
667
- randn_tensor(
668
- m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype
669
- )
670
  )
671
  else:
672
  variance_noise = randn_tensor(
@@ -1897,6 +1895,8 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
1897
  dim=1, keepdim=True
1898
  )
1899
  cond_emb = self.cond_emb(y)
 
 
1900
 
1901
  if not masked_cross_attention:
1902
  conditioning_mask = None
@@ -1905,11 +1905,8 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
1905
  if micro is not None:
1906
  temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
1907
  temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
1908
- if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
1909
- cond_emb_micro = cond_emb + temb_micro_conditioning
1910
- return cond_emb_micro, conditioning_mask, cond_emb
1911
- else:
1912
- return temb_micro_conditioning, conditioning_mask, None
1913
 
1914
  return cond_emb, conditioning_mask, cond_emb
1915
 
@@ -3035,11 +3032,6 @@ class MatryoshkaUNet2DConditionModel(
3035
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3036
  attention_mask = attention_mask.unsqueeze(1)
3037
 
3038
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
3039
- if encoder_attention_mask is not None:
3040
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0
3041
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3042
-
3043
  # 0. center input if necessary
3044
  if self.config.center_input_sample:
3045
  sample = 2 * sample - 1.0
@@ -3074,6 +3066,11 @@ class MatryoshkaUNet2DConditionModel(
3074
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3075
  )
3076
 
 
 
 
 
 
3077
  if self.config.addition_embed_type == "image_hint":
3078
  aug_emb, hint = aug_emb
3079
  sample = torch.cat([sample, hint], dim=1)
@@ -3484,11 +3481,6 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3484
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3485
  attention_mask = attention_mask.unsqueeze(1)
3486
 
3487
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
3488
- if encoder_attention_mask is not None:
3489
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
3490
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3491
-
3492
  # 0. center input if necessary
3493
  if self.config.center_input_sample:
3494
  sample = 2 * sample - 1.0
@@ -3515,15 +3507,15 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3515
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3516
  )
3517
 
3518
- aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.get_aug_embed(
3519
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3520
  )
3521
-
3522
- aug_emb, cond_mask, _ = self.get_aug_embed(
3523
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3524
  )
3525
  else:
3526
- aug_emb, cond_mask_inner_unet, _ = self.get_aug_embed(
3527
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3528
  )
3529
 
@@ -3537,14 +3529,19 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3537
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3538
  )
3539
 
3540
- aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
3541
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3542
  )
3543
 
3544
- aug_emb, cond_mask, _ = self.get_aug_embed(
3545
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3546
  )
3547
 
 
 
 
 
 
3548
  if self.config.addition_embed_type == "image_hint":
3549
  aug_emb, hint = aug_emb
3550
  sample = torch.cat([sample, hint], dim=1)
@@ -3606,7 +3603,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3606
  encoder_hidden_states=encoder_hidden_states[:bh],
3607
  attention_mask=attention_mask,
3608
  cross_attention_kwargs=cross_attention_kwargs,
3609
- encoder_attention_mask=cond_mask_inner_unet[:bh] if cond_mask_inner_unet is not None else cond_mask_inner_unet,
3610
  **additional_residuals,
3611
  )
3612
  else:
@@ -3626,7 +3623,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3626
  timestep,
3627
  cond_emb=cond_emb,
3628
  encoder_hidden_states=encoder_hidden_states,
3629
- encoder_attention_mask=cond_mask_inner_unet,
3630
  from_nested=True,
3631
  )
3632
  x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
@@ -3914,9 +3911,6 @@ class MatryoshkaPipeline(
3914
 
3915
  text_inputs = self.tokenizer(
3916
  prompt,
3917
- padding="max_length",
3918
- max_length=self.tokenizer.model_max_length,
3919
- truncation=True,
3920
  return_tensors="pt",
3921
  )
3922
  text_input_ids = text_inputs.input_ids
@@ -3934,26 +3928,9 @@ class MatryoshkaPipeline(
3934
  )
3935
 
3936
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
3937
- attention_mask = text_inputs.attention_mask.to(device)
3938
  else:
3939
- attention_mask = None
3940
-
3941
- if clip_skip is None:
3942
- prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
3943
- prompt_embeds = prompt_embeds[0]
3944
- else:
3945
- prompt_embeds = self.text_encoder(
3946
- text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
3947
- )
3948
- # Access the `hidden_states` first, that contains a tuple of
3949
- # all the hidden states from the encoder layers. Then index into
3950
- # the tuple to access the hidden states from the desired layer.
3951
- prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
3952
- # We also need to apply the final LayerNorm here to not mess with the
3953
- # representations. The `last_hidden_states` that we typically use for
3954
- # obtaining the final prompt representations passes through the LayerNorm
3955
- # layer.
3956
- prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
3957
 
3958
  if self.text_encoder is not None:
3959
  prompt_embeds_dtype = self.text_encoder.dtype
@@ -3962,13 +3939,6 @@ class MatryoshkaPipeline(
3962
  else:
3963
  prompt_embeds_dtype = prompt_embeds.dtype
3964
 
3965
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
3966
-
3967
- bs_embed, seq_len, _ = prompt_embeds.shape
3968
- # duplicate text embeddings for each generation per prompt, using mps friendly method
3969
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
3970
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
3971
-
3972
  # get unconditional embeddings for classifier free guidance
3973
  if do_classifier_free_guidance and negative_prompt_embeds is None:
3974
  uncond_tokens: List[str]
@@ -3994,41 +3964,78 @@ class MatryoshkaPipeline(
3994
  if isinstance(self, TextualInversionLoaderMixin):
3995
  uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
3996
 
3997
- max_length = prompt_embeds.shape[1]
3998
  uncond_input = self.tokenizer(
3999
  uncond_tokens,
4000
- padding="max_length",
4001
- max_length=max_length,
4002
- truncation=True,
4003
  return_tensors="pt",
4004
  )
 
4005
 
4006
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
4007
- attention_mask = uncond_input.attention_mask.to(device)
4008
  else:
4009
- attention_mask = None
4010
 
4011
- negative_prompt_embeds = self.text_encoder(
4012
- uncond_input.input_ids.to(device),
4013
- attention_mask=attention_mask,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4014
  )
4015
- negative_prompt_embeds = negative_prompt_embeds[0]
4016
 
4017
- if do_classifier_free_guidance:
4018
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
4019
- seq_len = negative_prompt_embeds.shape[1]
4020
-
4021
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
4022
-
4023
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
4024
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
4025
 
4026
  if self.text_encoder is not None:
4027
  if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
4028
  # Retrieve the original scale by scaling back the LoRA layers
4029
  unscale_lora_layers(self.text_encoder, lora_scale)
4030
 
4031
- return prompt_embeds, negative_prompt_embeds, attention_mask
 
 
4032
 
4033
  def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
4034
  dtype = next(self.image_encoder.parameters()).dtype
@@ -4461,7 +4468,12 @@ class MatryoshkaPipeline(
4461
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
4462
  )
4463
 
4464
- prompt_embeds, negative_prompt_embeds, encoder_attention_mask = self.encode_prompt(
 
 
 
 
 
4465
  prompt,
4466
  device,
4467
  num_images_per_prompt,
@@ -4477,7 +4489,12 @@ class MatryoshkaPipeline(
4477
  # Here we concatenate the unconditional and text embeddings into a single batch
4478
  # to avoid doing two forward passes
4479
  if self.do_classifier_free_guidance:
4480
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
 
 
 
 
4481
 
4482
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
4483
  image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -4489,10 +4506,13 @@ class MatryoshkaPipeline(
4489
  )
4490
 
4491
  # 4. Prepare timesteps
4492
- timesteps, num_inference_steps = retrieve_timesteps(
4493
- self.scheduler, num_inference_steps, device, timesteps, sigmas
4494
- )
4495
- timesteps = timesteps[:-1]
 
 
 
4496
 
4497
  # 5. Prepare latent variables
4498
  num_channels_latents = self.unet.config.in_channels
@@ -4551,7 +4571,7 @@ class MatryoshkaPipeline(
4551
  timestep_cond=timestep_cond,
4552
  cross_attention_kwargs=self.cross_attention_kwargs,
4553
  added_cond_kwargs=added_cond_kwargs,
4554
- encoder_attention_mask=encoder_attention_mask,
4555
  return_dict=False,
4556
  )[0]
4557
 
@@ -4568,7 +4588,19 @@ class MatryoshkaPipeline(
4568
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
4569
 
4570
  # compute the previous noisy sample x_t -> x_t-1
4571
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
 
 
 
 
 
 
 
 
 
 
 
4572
 
4573
  if callback_on_step_end is not None:
4574
  callback_kwargs = {}
 
664
  variance_noise = []
665
  for m_o in model_output:
666
  variance_noise.append(
667
+ randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype)
 
 
668
  )
669
  else:
670
  variance_noise = randn_tensor(
 
1895
  dim=1, keepdim=True
1896
  )
1897
  cond_emb = self.cond_emb(y)
1898
+ else:
1899
+ cond_emb = None
1900
 
1901
  if not masked_cross_attention:
1902
  conditioning_mask = None
 
1905
  if micro is not None:
1906
  temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
1907
  temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
1908
+ # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
1909
+ return temb_micro_conditioning, conditioning_mask, cond_emb
 
 
 
1910
 
1911
  return cond_emb, conditioning_mask, cond_emb
1912
 
 
3032
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3033
  attention_mask = attention_mask.unsqueeze(1)
3034
 
 
 
 
 
 
3035
  # 0. center input if necessary
3036
  if self.config.center_input_sample:
3037
  sample = 2 * sample - 1.0
 
3066
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3067
  )
3068
 
3069
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
3070
+ if encoder_attention_mask is not None:
3071
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0
3072
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3073
+
3074
  if self.config.addition_embed_type == "image_hint":
3075
  aug_emb, hint = aug_emb
3076
  sample = torch.cat([sample, hint], dim=1)
 
3481
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3482
  attention_mask = attention_mask.unsqueeze(1)
3483
 
 
 
 
 
 
3484
  # 0. center input if necessary
3485
  if self.config.center_input_sample:
3486
  sample = 2 * sample - 1.0
 
3507
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3508
  )
3509
 
3510
+ aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.get_aug_embed(
3511
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3512
  )
3513
+ added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
3514
+ aug_emb, __, _ = self.get_aug_embed(
3515
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3516
  )
3517
  else:
3518
+ aug_emb, cond_mask, _ = self.get_aug_embed(
3519
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3520
  )
3521
 
 
3529
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3530
  )
3531
 
3532
+ aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
3533
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3534
  )
3535
 
3536
+ aug_emb, __, _ = self.get_aug_embed(
3537
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3538
  )
3539
 
3540
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
3541
+ if encoder_attention_mask is not None:
3542
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
3543
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3544
+
3545
  if self.config.addition_embed_type == "image_hint":
3546
  aug_emb, hint = aug_emb
3547
  sample = torch.cat([sample, hint], dim=1)
 
3603
  encoder_hidden_states=encoder_hidden_states[:bh],
3604
  attention_mask=attention_mask,
3605
  cross_attention_kwargs=cross_attention_kwargs,
3606
+ encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
3607
  **additional_residuals,
3608
  )
3609
  else:
 
3623
  timestep,
3624
  cond_emb=cond_emb,
3625
  encoder_hidden_states=encoder_hidden_states,
3626
+ encoder_attention_mask=cond_mask,
3627
  from_nested=True,
3628
  )
3629
  x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
 
3911
 
3912
  text_inputs = self.tokenizer(
3913
  prompt,
 
 
 
3914
  return_tensors="pt",
3915
  )
3916
  text_input_ids = text_inputs.input_ids
 
3928
  )
3929
 
3930
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
3931
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
3932
  else:
3933
+ prompt_attention_mask = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3934
 
3935
  if self.text_encoder is not None:
3936
  prompt_embeds_dtype = self.text_encoder.dtype
 
3939
  else:
3940
  prompt_embeds_dtype = prompt_embeds.dtype
3941
 
 
 
 
 
 
 
 
3942
  # get unconditional embeddings for classifier free guidance
3943
  if do_classifier_free_guidance and negative_prompt_embeds is None:
3944
  uncond_tokens: List[str]
 
3964
  if isinstance(self, TextualInversionLoaderMixin):
3965
  uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
3966
 
 
3967
  uncond_input = self.tokenizer(
3968
  uncond_tokens,
 
 
 
3969
  return_tensors="pt",
3970
  )
3971
+ uncond_input_ids = uncond_input.input_ids
3972
 
3973
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
3974
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
3975
  else:
3976
+ negative_prompt_attention_mask = None
3977
 
3978
+ if not do_classifier_free_guidance:
3979
+ if clip_skip is None:
3980
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
3981
+ prompt_embeds = prompt_embeds[0]
3982
+ else:
3983
+ prompt_embeds = self.text_encoder(
3984
+ text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True
3985
+ )
3986
+ # Access the `hidden_states` first, that contains a tuple of
3987
+ # all the hidden states from the encoder layers. Then index into
3988
+ # the tuple to access the hidden states from the desired layer.
3989
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
3990
+ # We also need to apply the final LayerNorm here to not mess with the
3991
+ # representations. The `last_hidden_states` that we typically use for
3992
+ # obtaining the final prompt representations passes through the LayerNorm
3993
+ # layer.
3994
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
3995
+ else:
3996
+ max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0]))
3997
+ if len(text_input_ids[0]) < max_len:
3998
+ text_input_ids = torch.cat(
3999
+ [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)],
4000
+ dim=1,
4001
+ )
4002
+ prompt_attention_mask = torch.cat(
4003
+ [
4004
+ prompt_attention_mask,
4005
+ torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long),
4006
+ ],
4007
+ dim=1,
4008
+ )
4009
+ elif len(uncond_input_ids[0]) < max_len:
4010
+ uncond_input_ids = torch.cat(
4011
+ [uncond_input_ids, torch.zeros(batch_size, max_len - len(uncond_input_ids[0]), dtype=torch.long)],
4012
+ dim=1,
4013
+ )
4014
+ negative_prompt_attention_mask = torch.cat(
4015
+ [
4016
+ negative_prompt_attention_mask,
4017
+ torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long),
4018
+ ],
4019
+ dim=1,
4020
+ )
4021
+ cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0)
4022
+ cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
4023
+ prompt_embeds = self.text_encoder(
4024
+ cfg_input_ids.to(device),
4025
+ attention_mask=cfg_attention_mask,
4026
  )
4027
+ prompt_embeds = prompt_embeds[0]
4028
 
4029
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
 
 
 
 
 
 
4030
 
4031
  if self.text_encoder is not None:
4032
  if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
4033
  # Retrieve the original scale by scaling back the LoRA layers
4034
  unscale_lora_layers(self.text_encoder, lora_scale)
4035
 
4036
+ if not do_classifier_free_guidance:
4037
+ return prompt_embeds, None, prompt_attention_mask, None
4038
+ return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask
4039
 
4040
  def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
4041
  dtype = next(self.image_encoder.parameters()).dtype
 
4468
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
4469
  )
4470
 
4471
+ (
4472
+ prompt_embeds,
4473
+ negative_prompt_embeds,
4474
+ prompt_attention_mask,
4475
+ negative_prompt_attention_mask,
4476
+ ) = self.encode_prompt(
4477
  prompt,
4478
  device,
4479
  num_images_per_prompt,
 
4489
  # Here we concatenate the unconditional and text embeddings into a single batch
4490
  # to avoid doing two forward passes
4491
  if self.do_classifier_free_guidance:
4492
+ prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)])
4493
+ attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
4494
+ else:
4495
+ attention_masks = prompt_attention_mask
4496
+
4497
+ prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1)
4498
 
4499
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
4500
  image_embeds = self.prepare_ip_adapter_image_embeds(
 
4506
  )
4507
 
4508
  # 4. Prepare timesteps
4509
+ if isinstance(self.scheduler, MatryoshkaDDIMScheduler):
4510
+ timesteps, num_inference_steps = retrieve_timesteps(
4511
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
4512
+ )
4513
+ timesteps = timesteps[:-1] # is this correct???
4514
+ else:
4515
+ timesteps = self.scheduler.timesteps
4516
 
4517
  # 5. Prepare latent variables
4518
  num_channels_latents = self.unet.config.in_channels
 
4571
  timestep_cond=timestep_cond,
4572
  cross_attention_kwargs=self.cross_attention_kwargs,
4573
  added_cond_kwargs=added_cond_kwargs,
4574
+ encoder_attention_mask=attention_masks,
4575
  return_dict=False,
4576
  )[0]
4577
 
 
4588
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
4589
 
4590
  # compute the previous noisy sample x_t -> x_t-1
4591
+ if self.scheduler.scales is not None and not isinstance(self.scheduler, MatryoshkaDDIMScheduler):
4592
+ latents[0] = self.scheduler.step(
4593
+ noise_pred[0], t, latents[0], **extra_step_kwargs, return_dict=False
4594
+ )[0]
4595
+ latents[1] = self.scheduler.inner_scheduler.step(
4596
+ noise_pred[1], t, latents[1], **extra_step_kwargs, return_dict=False
4597
+ )[0]
4598
+ if len(latents) > 2:
4599
+ latents[2] = self.scheduler.inner_scheduler.inner_scheduler.step(
4600
+ noise_pred[2], t, latents[2], **extra_step_kwargs, return_dict=False
4601
+ )[0]
4602
+ else:
4603
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
4604
 
4605
  if callback_on_step_end is not None:
4606
  callback_kwargs = {}