Update app.py
Browse files
app.py
CHANGED
@@ -185,7 +185,9 @@ optimizer = optax.adam(learning_rate)
|
|
185 |
float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
186 |
|
187 |
# Create a new UNet with the adjusted parameters
|
188 |
-
|
|
|
|
|
189 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
190 |
adjusted_params['params'] = float32_params['unet']
|
191 |
|
|
|
185 |
float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
186 |
|
187 |
# Create a new UNet with the adjusted parameters
|
188 |
+
unet_config = unet.config.to_dict()
|
189 |
+
unet_config.pop('_use_default_values', None) # Remove the unexpected argument
|
190 |
+
adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
|
191 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
192 |
adjusted_params['params'] = float32_params['unet']
|
193 |
|