uruguayai commited on
Commit
ceeeb32
·
verified ·
1 Parent(s): b2ad618

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -8
app.py CHANGED
@@ -138,7 +138,7 @@ def train_step(state, batch, rng):
138
  ).latent_dist.sample(rng)
139
  latents = latents * jnp.float32(0.18215)
140
  print(f"latents shape: {latents.shape}")
141
-
142
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
143
 
144
  timesteps = jax.random.randint(
@@ -172,7 +172,7 @@ def train_step(state, batch, rng):
172
  ).sample
173
 
174
  return jnp.mean((model_output - noise) ** 2)
175
-
176
  grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
177
  rng, step_rng = jax.random.split(rng)
178
 
@@ -181,11 +181,6 @@ def train_step(state, batch, rng):
181
  state = state.apply_gradients(grads=grads)
182
  return state, loss
183
 
184
- def filter_dict(dict_to_filter, target_callable):
185
- """Filter a dictionary to only include keys that are valid parameters for the target callable."""
186
- valid_params = signature(target_callable).parameters.keys()
187
- return {k: v for k, v in dict_to_filter.items() if k in valid_params}
188
-
189
  # Initialize training state
190
  learning_rate = 1e-5
191
  optimizer = optax.adam(learning_rate)
@@ -199,7 +194,7 @@ print("Filtered UNet config keys:", filtered_unet_config.keys())
199
 
200
  adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
201
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
202
- adjusted_params['params'] = float32_params['unet']
203
 
204
  state = train_state.TrainState.create(
205
  apply_fn=adjusted_unet.apply,
 
138
  ).latent_dist.sample(rng)
139
  latents = latents * jnp.float32(0.18215)
140
  print(f"latents shape: {latents.shape}")
141
+
142
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
143
 
144
  timesteps = jax.random.randint(
 
172
  ).sample
173
 
174
  return jnp.mean((model_output - noise) ** 2)
175
+
176
  grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
177
  rng, step_rng = jax.random.split(rng)
178
 
 
181
  state = state.apply_gradients(grads=grads)
182
  return state, loss
183
 
 
 
 
 
 
184
  # Initialize training state
185
  learning_rate = 1e-5
186
  optimizer = optax.adam(learning_rate)
 
194
 
195
  adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
196
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
197
+ adjusted_params = float32_params['unet'] # Use only UNet params
198
 
199
  state = train_state.TrainState.create(
200
  apply_fn=adjusted_unet.apply,