uruguayai commited on
Commit
cf50961
·
verified ·
1 Parent(s): 571b479

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -116,6 +116,9 @@ def train_step(state, batch, rng):
116
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
117
  )
118
 
 
 
 
119
  # Add noise to latents
120
  noisy_latents = pipeline.scheduler.add_noise(
121
  pipeline.scheduler.create_state(),
 
116
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
117
  )
118
 
119
+ # Explicitly cast timesteps to int32
120
+ timesteps = timesteps.astype(jnp.int32)
121
+
122
  # Add noise to latents
123
  noisy_latents = pipeline.scheduler.add_noise(
124
  pipeline.scheduler.create_state(),