Update app.py
Browse files
app.py
CHANGED
@@ -161,4 +161,28 @@ state = train_state.TrainState.create(
|
|
161 |
# Training loop
|
162 |
num_epochs = 3
|
163 |
batch_size = 1
|
164 |
-
rng = jax.random.PRNGKey(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
# Training loop
|
162 |
num_epochs = 3
|
163 |
batch_size = 1
|
164 |
+
rng = jax.random.PRNGKey(0)
|
165 |
+
|
166 |
+
for epoch in range(num_epochs):
|
167 |
+
epoch_loss = 0
|
168 |
+
num_batches = 0
|
169 |
+
for batch in tqdm(processed_dataset.batch(batch_size)):
|
170 |
+
batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
|
171 |
+
rng, step_rng = jax.random.split(rng)
|
172 |
+
state, loss = train_step(state, batch, step_rng)
|
173 |
+
epoch_loss += loss
|
174 |
+
num_batches += 1
|
175 |
+
|
176 |
+
if num_batches % 10 == 0:
|
177 |
+
jax.clear_caches()
|
178 |
+
|
179 |
+
avg_loss = epoch_loss / num_batches
|
180 |
+
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
181 |
+
jax.clear_caches()
|
182 |
+
|
183 |
+
# Save the fine-tuned model
|
184 |
+
output_dir = "/tmp/montevideo_fine_tuned_model"
|
185 |
+
os.makedirs(output_dir, exist_ok=True)
|
186 |
+
unet.save_pretrained(output_dir, params=state.params["unet"])
|
187 |
+
|
188 |
+
print(f"Model saved to {output_dir}")
|