Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ import os
|
|
11 |
import pickle
|
12 |
from PIL import Image
|
13 |
import numpy as np
|
|
|
14 |
|
15 |
# Custom Scheduler
|
16 |
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
|
@@ -179,6 +180,11 @@ def train_step(state, batch, rng):
|
|
179 |
state = state.apply_gradients(grads=grads)
|
180 |
return state, loss
|
181 |
|
|
|
|
|
|
|
|
|
|
|
182 |
# Initialize training state
|
183 |
learning_rate = 1e-5
|
184 |
optimizer = optax.adam(learning_rate)
|
@@ -186,12 +192,11 @@ float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dty
|
|
186 |
|
187 |
# Create a new UNet with the adjusted parameters
|
188 |
unet_config = dict(unet.config)
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
unet_config.pop(key, None)
|
193 |
|
194 |
-
adjusted_unet = FlaxUNet2DConditionModel(**
|
195 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
196 |
adjusted_params['params'] = float32_params['unet']
|
197 |
|
|
|
11 |
import pickle
|
12 |
from PIL import Image
|
13 |
import numpy as np
|
14 |
+
from inspect import signature
|
15 |
|
16 |
# Custom Scheduler
|
17 |
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
|
|
|
180 |
state = state.apply_gradients(grads=grads)
|
181 |
return state, loss
|
182 |
|
183 |
+
def filter_dict(dict_to_filter, target_callable):
|
184 |
+
"""Filter a dictionary to only include keys that are valid parameters for the target callable."""
|
185 |
+
valid_params = signature(target_callable).parameters.keys()
|
186 |
+
return {k: v for k, v in dict_to_filter.items() if k in valid_params}
|
187 |
+
|
188 |
# Initialize training state
|
189 |
learning_rate = 1e-5
|
190 |
optimizer = optax.adam(learning_rate)
|
|
|
192 |
|
193 |
# Create a new UNet with the adjusted parameters
|
194 |
unet_config = dict(unet.config)
|
195 |
+
filtered_unet_config = filter_dict(unet_config, FlaxUNet2DConditionModel.__init__)
|
196 |
+
|
197 |
+
print("Filtered UNet config keys:", filtered_unet_config.keys())
|
|
|
198 |
|
199 |
+
adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
|
200 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
201 |
adjusted_params['params'] = float32_params['unet']
|
202 |
|