tolgacangoz commited on
Commit
6f3dc35
·
verified ·
1 Parent(s): f2aed2b

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +6 -1
matryoshka.py CHANGED
@@ -654,7 +654,12 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
654
 
655
  if use_clipped_model_output:
656
  # the pred_epsilon is always re-derived from the clipped x_0 in Glide
657
- pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
 
 
 
 
 
658
 
659
  # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
660
  if len(model_output) > 1:
 
654
 
655
  if use_clipped_model_output:
656
  # the pred_epsilon is always re-derived from the clipped x_0 in Glide
657
+ if len(model_output) > 1:
658
+ pred_epsilon = []
659
+ for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t):
660
+ pred_epsilon.append((s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5))
661
+ else:
662
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
663
 
664
  # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
665
  if len(model_output) > 1: