uruguayai commited on
Commit
9aba976
·
verified ·
1 Parent(s): dacaf33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -161,4 +161,28 @@ state = train_state.TrainState.create(
161
  # Training loop
162
  num_epochs = 3
163
  batch_size = 1
164
- rng = jax.random.PRNGKey(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  # Training loop
162
  num_epochs = 3
163
  batch_size = 1
164
+ rng = jax.random.PRNGKey(0)
165
+
166
+ for epoch in range(num_epochs):
167
+ epoch_loss = 0
168
+ num_batches = 0
169
+ for batch in tqdm(processed_dataset.batch(batch_size)):
170
+ batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
171
+ rng, step_rng = jax.random.split(rng)
172
+ state, loss = train_step(state, batch, step_rng)
173
+ epoch_loss += loss
174
+ num_batches += 1
175
+
176
+ if num_batches % 10 == 0:
177
+ jax.clear_caches()
178
+
179
+ avg_loss = epoch_loss / num_batches
180
+ print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
181
+ jax.clear_caches()
182
+
183
+ # Save the fine-tuned model
184
+ output_dir = "/tmp/montevideo_fine_tuned_model"
185
+ os.makedirs(output_dir, exist_ok=True)
186
+ unet.save_pretrained(output_dir, params=state.params["unet"])
187
+
188
+ print(f"Model saved to {output_dir}")