Update app.py
Browse files
app.py
CHANGED
@@ -136,7 +136,7 @@ def train_step(state, batch, rng):
|
|
136 |
)
|
137 |
|
138 |
# Predict noise
|
139 |
-
model_output =
|
140 |
{'params': params["unet"]},
|
141 |
noisy_latents,
|
142 |
timesteps,
|
@@ -160,7 +160,7 @@ learning_rate = 1e-5
|
|
160 |
optimizer = optax.adam(learning_rate)
|
161 |
float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
162 |
state = train_state.TrainState.create(
|
163 |
-
apply_fn=unet.
|
164 |
params=float32_params,
|
165 |
tx=optimizer,
|
166 |
)
|
|
|
136 |
)
|
137 |
|
138 |
# Predict noise
|
139 |
+
model_output = unet.apply(
|
140 |
{'params': params["unet"]},
|
141 |
noisy_latents,
|
142 |
timesteps,
|
|
|
160 |
optimizer = optax.adam(learning_rate)
|
161 |
float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
162 |
state = train_state.TrainState.create(
|
163 |
+
apply_fn=unet.apply,
|
164 |
params=float32_params,
|
165 |
tx=optimizer,
|
166 |
)
|