Update app.py
Browse files
app.py
CHANGED
@@ -161,7 +161,7 @@ def train_step(state, batch, rng):
|
|
161 |
print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
|
162 |
|
163 |
# Use the correct method to call the UNet
|
164 |
-
model_output =
|
165 |
{'params': params["unet"]},
|
166 |
noisy_latents,
|
167 |
jnp.array(timesteps, dtype=jnp.int32),
|
@@ -183,9 +183,15 @@ def train_step(state, batch, rng):
|
|
183 |
learning_rate = 1e-5
|
184 |
optimizer = optax.adam(learning_rate)
|
185 |
float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
state = train_state.TrainState.create(
|
187 |
-
apply_fn=
|
188 |
-
params=
|
189 |
tx=optimizer,
|
190 |
)
|
191 |
|
@@ -214,6 +220,6 @@ for epoch in range(num_epochs):
|
|
214 |
# Save the fine-tuned model
|
215 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
216 |
os.makedirs(output_dir, exist_ok=True)
|
217 |
-
|
218 |
|
219 |
print(f"Model saved to {output_dir}")
|
|
|
161 |
print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
|
162 |
|
163 |
# Use the correct method to call the UNet
|
164 |
+
model_output = state.apply_fn(
|
165 |
{'params': params["unet"]},
|
166 |
noisy_latents,
|
167 |
jnp.array(timesteps, dtype=jnp.int32),
|
|
|
183 |
learning_rate = 1e-5
|
184 |
optimizer = optax.adam(learning_rate)
|
185 |
float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
186 |
+
|
187 |
+
# Create a new UNet with the adjusted parameters
|
188 |
+
adjusted_unet = FlaxUNet2DConditionModel(**unet.config)
|
189 |
+
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
190 |
+
adjusted_params['params'] = float32_params['unet']
|
191 |
+
|
192 |
state = train_state.TrainState.create(
|
193 |
+
apply_fn=adjusted_unet.apply,
|
194 |
+
params=adjusted_params,
|
195 |
tx=optimizer,
|
196 |
)
|
197 |
|
|
|
220 |
# Save the fine-tuned model
|
221 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
222 |
os.makedirs(output_dir, exist_ok=True)
|
223 |
+
adjusted_unet.save_pretrained(output_dir, params=state.params["params"])
|
224 |
|
225 |
print(f"Model saved to {output_dir}")
|