Update app.py
Browse files
app.py
CHANGED
@@ -188,7 +188,7 @@ def train_step(state, batch, rng):
|
|
188 |
|
189 |
# Use the state's apply_fn (which should be the adjusted UNet)
|
190 |
model_output = state.apply_fn(
|
191 |
-
{
|
192 |
noisy_latents,
|
193 |
jnp.array(timesteps, dtype=jnp.int32),
|
194 |
encoder_hidden_states,
|
@@ -200,9 +200,15 @@ def train_step(state, batch, rng):
|
|
200 |
grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
|
201 |
rng, step_rng = jax.random.split(rng)
|
202 |
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
return state, loss
|
207 |
|
208 |
# Initialize training state
|
@@ -220,9 +226,10 @@ adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
|
|
220 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
221 |
adjusted_params = adjust_unet_input_layer(adjusted_params) # Adjust the input layer
|
222 |
|
|
|
223 |
state = train_state.TrainState.create(
|
224 |
apply_fn=adjusted_unet.apply,
|
225 |
-
params=adjusted_params,
|
226 |
tx=optimizer,
|
227 |
)
|
228 |
|
|
|
188 |
|
189 |
# Use the state's apply_fn (which should be the adjusted UNet)
|
190 |
model_output = state.apply_fn(
|
191 |
+
{"params": unet_params},
|
192 |
noisy_latents,
|
193 |
jnp.array(timesteps, dtype=jnp.int32),
|
194 |
encoder_hidden_states,
|
|
|
200 |
grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
|
201 |
rng, step_rng = jax.random.split(rng)
|
202 |
|
203 |
+
# Unwrap the outer "params" level (if present) so grad_fn and compute_loss receive the bare UNet parameter tree
|
204 |
+
unet_params = state.params["params"] if "params" in state.params else state.params
|
205 |
+
grads = grad_fn(unet_params, batch["pixel_values"], step_rng)
|
206 |
+
loss = compute_loss(unet_params, batch["pixel_values"], step_rng)
|
207 |
+
|
208 |
+
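# Update the state. NOTE(review): grads mirror unet_params (the unwrapped tree) while state.params may still carry the outer "params" key, and optax.apply_updates adds the raw grads without going through the optimizer — confirm whether state.apply_gradients was intended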
|
209 |
+
new_params = optax.apply_updates(state.params, grads)
|
210 |
+
state = state.replace(params=new_params)
|
211 |
+
|
212 |
return state, loss
|
213 |
|
214 |
# Initialize training state
|
|
|
226 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
227 |
adjusted_params = adjust_unet_input_layer(adjusted_params) # Adjust the input layer
|
228 |
|
229 |
+
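# Create the training state with the parameters nested under a "params" key (train_step unwraps this key before calling apply_fn)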
|
230 |
state = train_state.TrainState.create(
|
231 |
apply_fn=adjusted_unet.apply,
|
232 |
+
params={"params": adjusted_params}, # Wrap params in a dict with "params" key
|
233 |
tx=optimizer,
|
234 |
)
|
235 |
|