uruguayai commited on
Commit
e0e747f
·
verified ·
1 Parent(s): ed67914

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -161,7 +161,7 @@ def train_step(state, batch, rng):
161
  print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
162
 
163
  # Use the correct method to call the UNet
164
- model_output = unet.apply(
165
  {'params': params["unet"]},
166
  noisy_latents,
167
  jnp.array(timesteps, dtype=jnp.int32),
@@ -183,9 +183,15 @@ def train_step(state, batch, rng):
183
  learning_rate = 1e-5
184
  optimizer = optax.adam(learning_rate)
185
  float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
 
 
 
 
 
 
186
  state = train_state.TrainState.create(
187
- apply_fn=unet.apply,
188
- params=float32_params,
189
  tx=optimizer,
190
  )
191
 
@@ -214,6 +220,6 @@ for epoch in range(num_epochs):
214
  # Save the fine-tuned model
215
  output_dir = "/tmp/montevideo_fine_tuned_model"
216
  os.makedirs(output_dir, exist_ok=True)
217
- unet.save_pretrained(output_dir, params=state.params["unet"])
218
 
219
  print(f"Model saved to {output_dir}")
 
161
  print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
162
 
163
  # Use the correct method to call the UNet
164
+ model_output = state.apply_fn(
165
  {'params': params["unet"]},
166
  noisy_latents,
167
  jnp.array(timesteps, dtype=jnp.int32),
 
183
  learning_rate = 1e-5
184
  optimizer = optax.adam(learning_rate)
185
  float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
186
+
187
+ # Create a new UNet with the adjusted parameters
188
+ adjusted_unet = FlaxUNet2DConditionModel(**unet.config)
189
+ adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
190
+ adjusted_params['params'] = float32_params['unet']
191
+
192
  state = train_state.TrainState.create(
193
+ apply_fn=adjusted_unet.apply,
194
+ params=adjusted_params,
195
  tx=optimizer,
196
  )
197
 
 
220
  # Save the fine-tuned model
221
  output_dir = "/tmp/montevideo_fine_tuned_model"
222
  os.makedirs(output_dir, exist_ok=True)
223
+ adjusted_unet.save_pretrained(output_dir, params=state.params["params"])
224
 
225
  print(f"Model saved to {output_dir}")