uruguayai committed
Commit 6f034e3 · verified · 1 parent: 157fd62

Update app.py

Files changed (1):
  1. app.py +83 -1
app.py CHANGED
@@ -130,4 +130,86 @@ def train_step(state, batch, rng):
             pixel_values,
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
-        latents = latents *
+        latents = latents * jnp.float32(0.18215)  # Stable Diffusion VAE scaling factor
+        print(f"latents shape: {latents.shape}")
+
+        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
+
+        timesteps = jax.random.randint(
+            rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
+        )
+
+        noisy_latents = pipeline.scheduler.add_noise(
+            pipeline.scheduler.create_state(),
+            original_samples=latents,
+            noise=noise,
+            timesteps=timesteps
+        )
+
+        encoder_hidden_states = jax.random.normal(
+            rng,
+            (latents.shape[0], 1, pipeline.text_encoder.config.hidden_size),  # add a sequence axis: cross-attention expects (batch, seq, dim)
+            dtype=jnp.float32
+        )
+
+        print(f"noisy_latents shape: {noisy_latents.shape}")
+        print(f"timesteps shape: {timesteps.shape}")
+        print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
+
+        # Use the correct method to call the UNet
+        model_output = unet.apply(
+            {'params': params["unet"]},
+            noisy_latents,
+            jnp.array(timesteps, dtype=jnp.int32),
+            encoder_hidden_states,
+            train=True,
+        ).sample
+
+        return jnp.mean((model_output - noise) ** 2)  # epsilon-prediction MSE loss
+
+    grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
+    rng, step_rng = jax.random.split(rng)
+
+    grads = grad_fn(state.params, batch["pixel_values"], step_rng)
+    loss = compute_loss(state.params, batch["pixel_values"], step_rng)  # second forward pass; see note below
+    state = state.apply_gradients(grads=grads)
+    return state, loss
+
+# Initialize training state
+learning_rate = 1e-5
+optimizer = optax.adam(learning_rate)
+float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
+state = train_state.TrainState.create(
+    apply_fn=unet.apply,
+    params=float32_params,
+    tx=optimizer,
+)
+
+# Training loop
+num_epochs = 3
+batch_size = 1
+rng = jax.random.PRNGKey(0)
+
+for epoch in range(num_epochs):
+    epoch_loss = 0
+    num_batches = 0
+    for batch in tqdm(processed_dataset.batch(batch_size)):
+        batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
+        rng, step_rng = jax.random.split(rng)
+        state, loss = train_step(state, batch, step_rng)
+        epoch_loss += loss
+        num_batches += 1
+
+        if num_batches % 10 == 0:
+            jax.clear_caches()
+
+    avg_loss = epoch_loss / num_batches
+    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
+    jax.clear_caches()
+
+# Save the fine-tuned model
+output_dir = "/tmp/montevideo_fine_tuned_model"
+os.makedirs(output_dir, exist_ok=True)
+unet.save_pretrained(output_dir, params=state.params["unet"])
+
+print(f"Model saved to {output_dir}")
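
Note: as committed, train_step runs the forward pass twice — once inside jax.grad for the gradients and once more to report the loss. A minimal sketch of the same step using jax.value_and_grad, which returns both in a single evaluation (the names mirror the diff above; the jax.jit wrapper is an optional assumption, not part of this commit):

    def train_step(state, batch, rng):
        # ... compute_loss defined exactly as in the diff above ...
        # value_and_grad evaluates the loss and its gradients together,
        # avoiding the second forward pass.
        loss_and_grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
        rng, step_rng = jax.random.split(rng)
        loss, grads = loss_and_grad_fn(state.params, batch["pixel_values"], step_rng)
        state = state.apply_gradients(grads=grads)
        return state, loss

    # Optionally compile the step once and reuse it across batches:
    # train_step = jax.jit(train_step)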
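
A quick way to sanity-check the saved artifact is to reload it. A hypothetical snippet, assuming the UNet is a diffusers FlaxUNet2DConditionModel (the class whose save_pretrained the diff calls):

    from diffusers import FlaxUNet2DConditionModel

    # from_pretrained on a Flax diffusers model returns (model, params)
    reloaded_unet, reloaded_params = FlaxUNet2DConditionModel.from_pretrained(
        "/tmp/montevideo_fine_tuned_model"
    )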