uruguayai committed (verified)
Commit 6d5f395 · Parent(s): f17ed04

Update app.py

Files changed (1): app.py (+12 -5)
app.py CHANGED
@@ -188,7 +188,7 @@ def train_step(state, batch, rng):
 
     # Use the state's apply_fn (which should be the adjusted UNet)
     model_output = state.apply_fn(
-        {'params': unet_params},
+        {"params": unet_params},
         noisy_latents,
         jnp.array(timesteps, dtype=jnp.int32),
         encoder_hidden_states,
@@ -200,9 +200,15 @@ def train_step(state, batch, rng):
     grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
     rng, step_rng = jax.random.split(rng)
 
-    grads = grad_fn(state.params, batch["pixel_values"], step_rng)
-    loss = compute_loss(state.params, batch["pixel_values"], step_rng)
-    state = state.apply_gradients(grads=grads)
+    # Ensure we're passing the correct structure to grad_fn and compute_loss
+    unet_params = state.params["params"] if "params" in state.params else state.params
+    grads = grad_fn(unet_params, batch["pixel_values"], step_rng)
+    loss = compute_loss(unet_params, batch["pixel_values"], step_rng)
+
+    # Update the state with the correct structure
+    new_params = optax.apply_updates(state.params, grads)
+    state = state.replace(params=new_params)
+
     return state, loss
 
 # Initialize training state
@@ -220,9 +226,10 @@ adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
 adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
 adjusted_params = adjust_unet_input_layer(adjusted_params)  # Adjust the input layer
 
+# Adjust the state creation
 state = train_state.TrainState.create(
     apply_fn=adjusted_unet.apply,
-    params=adjusted_params,
+    params={"params": adjusted_params},  # Wrap params in a dict with "params" key
     tx=optimizer,
 )
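
Context on the param-tree handling this commit juggles: in flax.linen, Module.init returns a variables dict keyed by collection ({"params": ...}) and Module.apply expects that same wrapper back, while gradient and optimizer code usually traverses the inner tree. If adjusted_params already carries the wrapper, the extra wrap in TrainState.create double-nests it, which is presumably why train_step now probes for the key defensively. A minimal self-contained sketch of the convention, using a hypothetical toy module rather than the app's UNet:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical toy module standing in for the adjusted UNet.
class Tiny(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

module = Tiny()
x = jnp.ones((1, 3))

# init returns the full variables dict, already keyed by collection.
variables = module.init(jax.random.PRNGKey(0), x)  # {'params': {...}}

# apply expects the same {'params': ...} wrapper back.
y = module.apply(variables, x)

# The inner tree is what gradient/optimizer code typically traverses,
# hence the defensive unwrap added in train_step.
inner_params = variables["params"]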
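On the update step itself: in stock Optax, optax.apply_updates(params, updates) just adds the updates to the params, so calling it on raw gradients (as the new train_step does) skips the optimizer transform (state.tx.update), never advances opt_state, and steps along +grads rather than descending. For comparison, a sketch of the conventional Flax pattern, assuming compute_loss keeps the signature used in the diff; this is not the commit's code:

import jax

def train_step(state, batch, rng):
    rng, step_rng = jax.random.split(rng)

    # One forward/backward pass yields both loss and grads, instead of the
    # separate grad_fn and compute_loss calls above.
    loss, grads = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)(
        state.params, batch["pixel_values"], step_rng
    )

    # apply_gradients runs state.tx.update (learning rate, momentum, ...)
    # and advances opt_state; the grads tree must match state.params.
    state = state.apply_gradients(grads=grads)
    return state, loss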