Update app.py
Browse files
app.py
CHANGED
@@ -116,6 +116,9 @@ def train_step(state, batch, rng):
|
|
116 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
117 |
)
|
118 |
|
|
|
|
|
|
|
119 |
# Add noise to latents
|
120 |
noisy_latents = pipeline.scheduler.add_noise(
|
121 |
pipeline.scheduler.create_state(),
|
|
|
116 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
117 |
)
|
118 |
|
119 |
+
# Explicitly cast timesteps to int32
|
120 |
+
timesteps = timesteps.astype(jnp.int32)
|
121 |
+
|
122 |
# Add noise to latents
|
123 |
noisy_latents = pipeline.scheduler.add_noise(
|
124 |
pipeline.scheduler.create_state(),
|