Update app.py
Browse files
app.py
CHANGED
@@ -186,7 +186,11 @@ float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dty
|
|
186 |
|
187 |
# Create a new UNet with the adjusted parameters
|
188 |
unet_config = dict(unet.config)
|
189 |
-
|
|
|
|
|
|
|
|
|
190 |
adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
|
191 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
192 |
adjusted_params['params'] = float32_params['unet']
|
|
|
186 |
|
187 |
# Create a new UNet with the adjusted parameters
|
188 |
unet_config = dict(unet.config)
|
189 |
+
# Remove unexpected keys
|
190 |
+
unexpected_keys = ['_class_name', '_diffusers_version', '_use_default_values', '_name_or_path']
|
191 |
+
for key in unexpected_keys:
|
192 |
+
unet_config.pop(key, None)
|
193 |
+
|
194 |
adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
|
195 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
196 |
adjusted_params['params'] = float32_params['unet']
|