uruguayai commited on
Commit
1dab2a9
·
verified ·
1 Parent(s): 8dd6063

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -186,7 +186,11 @@ float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dty
186
 
187
  # Create a new UNet with the adjusted parameters
188
  unet_config = dict(unet.config)
189
- unet_config.pop('_use_default_values', None) # Remove the unexpected argument
 
 
 
 
190
  adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
191
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
192
  adjusted_params['params'] = float32_params['unet']
 
186
 
187
  # Create a new UNet with the adjusted parameters
188
  unet_config = dict(unet.config)
189
+ # Remove unexpected keys
190
+ unexpected_keys = ['_class_name', '_diffusers_version', '_use_default_values', '_name_or_path']
191
+ for key in unexpected_keys:
192
+ unet_config.pop(key, None)
193
+
194
  adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
195
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
196
  adjusted_params['params'] = float32_params['unet']