Update app.py
Browse files
app.py
CHANGED
@@ -185,7 +185,7 @@ optimizer = optax.adam(learning_rate)
|
|
185 |
float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
186 |
|
187 |
# Create a new UNet with the adjusted parameters
|
188 |
-
unet_config = unet.config
|
189 |
unet_config.pop('_use_default_values', None) # Remove the unexpected argument
|
190 |
adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
|
191 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
|
|
185 |
float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
186 |
|
187 |
# Create a new UNet with the adjusted parameters
|
188 |
+
unet_config = dict(unet.config)
|
189 |
unet_config.pop('_use_default_values', None) # Remove the unexpected argument
|
190 |
adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
|
191 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|