uruguayai committed on
Commit
00f4326
·
verified ·
1 Parent(s): 4434e29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -155,23 +155,27 @@ state = train_state.TrainState.create(
155
  )
156
 
157
  # Modify the train_step function
158
- def train_step(state, batch, rng):
159
  def compute_loss(params, pixel_values, rng):
 
 
 
160
  # Encode images to latent space
161
  latents = pipeline.vae.apply(
162
  {"params": params["vae"]},
163
  pixel_values,
164
  method=pipeline.vae.encode
165
  ).latent_dist.sample(rng)
166
- latents = latents * 0.18215
167
 
168
  # Generate random noise
169
- noise = jax.random.normal(rng, latents.shape)
170
 
171
  # Sample random timesteps
172
  timesteps = jax.random.randint(
173
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
174
  )
 
175
 
176
  # Add noise to latents
177
  noisy_latents = pipeline.scheduler.add_noise(
@@ -184,7 +188,8 @@ def train_step(state, batch, rng):
184
  # Generate random encoder hidden states (simulating text embeddings)
185
  encoder_hidden_states = jax.random.normal(
186
  rng,
187
- (latents.shape[0], pipeline.text_encoder.config.hidden_size)
 
188
  )
189
 
190
  # Predict noise
@@ -199,12 +204,16 @@ def train_step(state, batch, rng):
199
  # Compute loss
200
  return jnp.mean((model_output - noise) ** 2)
201
 
202
- grad_fn = jax.value_and_grad(compute_loss)
203
  rng, step_rng = jax.random.split(rng)
204
- loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
 
 
205
  state = state.apply_gradients(grads=grads)
206
  return state, loss
207
 
 
 
208
  # Training loop (remains the same)
209
  num_epochs = 3
210
  batch_size = 1
 
155
  )
156
 
157
  # Modify the train_step function
158
+ def train_step(state, batch, rng):
159
  def compute_loss(params, pixel_values, rng):
160
+ # Ensure pixel_values are float32
161
+ pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
162
+
163
  # Encode images to latent space
164
  latents = pipeline.vae.apply(
165
  {"params": params["vae"]},
166
  pixel_values,
167
  method=pipeline.vae.encode
168
  ).latent_dist.sample(rng)
169
+ latents = latents * jnp.float32(0.18215)
170
 
171
  # Generate random noise
172
+ noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
173
 
174
  # Sample random timesteps
175
  timesteps = jax.random.randint(
176
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
177
  )
178
+ timesteps = jnp.array(timesteps, dtype=jnp.float32)
179
 
180
  # Add noise to latents
181
  noisy_latents = pipeline.scheduler.add_noise(
 
188
  # Generate random encoder hidden states (simulating text embeddings)
189
  encoder_hidden_states = jax.random.normal(
190
  rng,
191
+ (latents.shape[0], pipeline.text_encoder.config.hidden_size),
192
+ dtype=jnp.float32
193
  )
194
 
195
  # Predict noise
 
204
  # Compute loss
205
  return jnp.mean((model_output - noise) ** 2)
206
 
207
+ grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
208
  rng, step_rng = jax.random.split(rng)
209
+
210
+ grads = grad_fn(state.params, batch["pixel_values"], step_rng)
211
+ loss = compute_loss(state.params, batch["pixel_values"], step_rng)
212
  state = state.apply_gradients(grads=grads)
213
  return state, loss
214
 
215
+
216
+
217
  # Training loop (remains the same)
218
  num_epochs = 3
219
  batch_size = 1