uruguayai commited on
Commit
8a49030
·
verified ·
1 Parent(s): 1dab2a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -11,6 +11,7 @@ import os
11
  import pickle
12
  from PIL import Image
13
  import numpy as np
 
14
 
15
  # Custom Scheduler
16
  class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
@@ -179,6 +180,11 @@ def train_step(state, batch, rng):
179
  state = state.apply_gradients(grads=grads)
180
  return state, loss
181
 
 
 
 
 
 
182
  # Initialize training state
183
  learning_rate = 1e-5
184
  optimizer = optax.adam(learning_rate)
@@ -186,12 +192,11 @@ float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dty
186
 
187
  # Create a new UNet with the adjusted parameters
188
  unet_config = dict(unet.config)
189
- # Remove unexpected keys
190
- unexpected_keys = ['_class_name', '_diffusers_version', '_use_default_values', '_name_or_path']
191
- for key in unexpected_keys:
192
- unet_config.pop(key, None)
193
 
194
- adjusted_unet = FlaxUNet2DConditionModel(**unet_config)
195
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
196
  adjusted_params['params'] = float32_params['unet']
197
 
 
11
  import pickle
12
  from PIL import Image
13
  import numpy as np
14
+ from inspect import signature
15
 
16
  # Custom Scheduler
17
  class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
 
180
  state = state.apply_gradients(grads=grads)
181
  return state, loss
182
 
183
+ def filter_dict(dict_to_filter, target_callable):
184
+ """Filter a dictionary to only include keys that are valid parameters for the target callable."""
185
+ valid_params = signature(target_callable).parameters.keys()
186
+ return {k: v for k, v in dict_to_filter.items() if k in valid_params}
187
+
188
  # Initialize training state
189
  learning_rate = 1e-5
190
  optimizer = optax.adam(learning_rate)
 
192
 
193
  # Create a new UNet with the adjusted parameters
194
  unet_config = dict(unet.config)
195
+ filtered_unet_config = filter_dict(unet_config, FlaxUNet2DConditionModel.__init__)
196
+
197
+ print("Filtered UNet config keys:", filtered_unet_config.keys())
 
198
 
199
+ adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
200
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
201
  adjusted_params['params'] = float32_params['unet']
202