uruguayai commited on
Commit
8dd6063
·
verified ·
1 Parent(s): 9d9591c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -185,7 +185,7 @@ optimizer = optax.adam(learning_rate)
185
  float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
186
 
187
  # Create a new UNet with the adjusted parameters
188
- unet_config = unet.config.to_dict()
189
  unet_config.pop('_use_default_values', None) # Remove the unexpected argument
190
  adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
191
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
 
185
  float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
186
 
187
  # Create a new UNet with the adjusted parameters
188
+ unet_config = dict(unet.config)
189
  unet_config.pop('_use_default_values', None) # Remove the unexpected argument
190
  adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
191
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))