tolgacangoz
commited on
Upload matryoshka.py
Browse files- matryoshka.py +6 -1
matryoshka.py
CHANGED
@@ -654,7 +654,12 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|
654 |
|
655 |
if use_clipped_model_output:
|
656 |
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
657 |
-
|
|
|
|
|
|
|
|
|
|
|
658 |
|
659 |
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
660 |
if len(model_output) > 1:
|
|
|
654 |
|
655 |
if use_clipped_model_output:
|
656 |
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
657 |
+
if len(model_output) > 1:
|
658 |
+
pred_epsilon = []
|
659 |
+
for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t):
|
660 |
+
pred_epsilon.append((s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5))
|
661 |
+
else:
|
662 |
+
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
663 |
|
664 |
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
665 |
if len(model_output) > 1:
|