Update app.py
app.py CHANGED
@@ -138,7 +138,7 @@ def train_step(state, batch, rng):
         ).latent_dist.sample(rng)
         latents = latents * jnp.float32(0.18215)
         print(f"latents shape: {latents.shape}")
-
+
         noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
 
         timesteps = jax.random.randint(
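Side note on this hunk: the same rng key is reused for the VAE latent sample, the Gaussian noise, and the timestep draw below, so those draws are deterministically correlated. A minimal sketch of the usual fix, splitting the key once per random call; vae_outputs and the split names are mine, and 1000 stands in for the scheduler's num_train_timesteps, which this diff does not show:

rng, sample_rng, noise_rng, timestep_rng = jax.random.split(rng, 4)
latents = vae_outputs.latent_dist.sample(sample_rng)
latents = latents * jnp.float32(0.18215)  # Stable Diffusion VAE scaling factor
noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
timesteps = jax.random.randint(timestep_rng, (latents.shape[0],), 0, 1000)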
@@ -172,7 +172,7 @@ def train_step(state, batch, rng):
         ).sample
 
         return jnp.mean((model_output - noise) ** 2)
-
+
     grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
     rng, step_rng = jax.random.split(rng)
 
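Note on the grad_fn line: allow_int=True lets jax.grad accept integer-dtype leaves among the differentiated inputs instead of raising a TypeError. Because train_step also returns the loss (next hunk), jax.value_and_grad is the common one-pass variant; a sketch assuming compute_loss takes (params, batch, rng), a signature the diff does not show:

grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
loss, grads = grad_fn(state.params, batch, step_rng)  # loss and grads in one pass
state = state.apply_gradients(grads=grads)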
@@ -181,11 +181,6 @@ def train_step(state, batch, rng):
     state = state.apply_gradients(grads=grads)
     return state, loss
 
-def filter_dict(dict_to_filter, target_callable):
-    """Filter a dictionary to only include keys that are valid parameters for the target callable."""
-    valid_params = signature(target_callable).parameters.keys()
-    return {k: v for k, v in dict_to_filter.items() if k in valid_params}
-
 # Initialize training state
 learning_rate = 1e-5
 optimizer = optax.adam(learning_rate)
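The five deleted lines are the filter_dict helper, reproduced below for reference together with the inspect import it depended on; if nothing else in app.py uses signature, that import can be dropped too. The next hunk's context line still prints filtered_unet_config, so presumably the filtering now happens elsewhere (an assumption; the usage line below is hypothetical).

from inspect import signature  # the helper's only external dependency

def filter_dict(dict_to_filter, target_callable):
    """Filter a dictionary to only include keys that are valid parameters for the target callable."""
    valid_params = signature(target_callable).parameters.keys()
    return {k: v for k, v in dict_to_filter.items() if k in valid_params}

# Hypothetical usage; only the print of filtered_unet_config.keys() is visible in this diff:
# filtered_unet_config = filter_dict(unet_config, FlaxUNet2DConditionModel)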
@@ -199,7 +194,7 @@ print("Filtered UNet config keys:", filtered_unet_config.keys())
 
 adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
 adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
-adjusted_params
+adjusted_params = float32_params['unet']  # Use only UNet params
 
 state = train_state.TrainState.create(
     apply_fn=adjusted_unet.apply,
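The changed line in the last hunk is the substance of this commit: the bare adjusted_params expression, a no-op outside a notebook, becomes an assignment that swaps in pretrained UNet weights, leaving the adjusted_unet.init(...) result on the previous line unused except as a shape sanity check. A sketch of the TrainState creation this feeds into, assuming float32_params holds pretrained pipeline parameters keyed by component and that params= and tx= are wired as below; only apply_fn= is visible in the hunk:

state = train_state.TrainState.create(
    apply_fn=adjusted_unet.apply,
    params=adjusted_params,  # pretrained 'unet' weights, not the fresh init
    tx=optimizer,
)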