uruguayai commited on
Commit
9d9591c
·
verified ·
1 Parent(s): e0e747f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -185,7 +185,9 @@ optimizer = optax.adam(learning_rate)
185
  float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
186
 
187
  # Create a new UNet with the adjusted parameters
188
- adjusted_unet = FlaxUNet2DConditionModel(**unet.config)
 
 
189
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
190
  adjusted_params['params'] = float32_params['unet']
191
 
 
185
  float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
186
 
187
  # Create a new UNet with the adjusted parameters
188
+ unet_config = unet.config.to_dict()
189
+ unet_config.pop('_use_default_values', None) # Remove the unexpected argument
190
+ adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
191
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
192
  adjusted_params['params'] = float32_params['unet']
193