Update app.py
Browse files
app.py
CHANGED
@@ -238,26 +238,42 @@ num_epochs = 3
|
|
238 |
batch_size = 1
|
239 |
rng = jax.random.PRNGKey(0)
|
240 |
|
|
|
|
|
|
|
|
|
|
|
241 |
for epoch in range(num_epochs):
|
242 |
epoch_loss = 0
|
243 |
num_batches = 0
|
|
|
244 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
-
avg_loss = epoch_loss / num_batches
|
255 |
-
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
256 |
jax.clear_caches()
|
257 |
|
258 |
# Save the fine-tuned model
|
259 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
260 |
os.makedirs(output_dir, exist_ok=True)
|
261 |
-
adjusted_unet.save_pretrained(output_dir, params=state.params)
|
262 |
|
263 |
print(f"Model saved to {output_dir}")
|
|
|
238 |
batch_size = 1
|
239 |
rng = jax.random.PRNGKey(0)
|
240 |
|
241 |
+
# Training loop
|
242 |
+
num_epochs = 3
|
243 |
+
batch_size = 1
|
244 |
+
rng = jax.random.PRNGKey(0)
|
245 |
+
|
246 |
for epoch in range(num_epochs):
|
247 |
epoch_loss = 0
|
248 |
num_batches = 0
|
249 |
+
num_errors = 0
|
250 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
251 |
+
try:
|
252 |
+
batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
|
253 |
+
rng, step_rng = jax.random.split(rng)
|
254 |
+
state, loss = train_step(state, batch, step_rng)
|
255 |
+
epoch_loss += loss
|
256 |
+
num_batches += 1
|
257 |
+
|
258 |
+
if num_batches % 10 == 0:
|
259 |
+
jax.clear_caches()
|
260 |
+
print(f"Processed {num_batches} batches. Current loss: {loss}")
|
261 |
+
except Exception as e:
|
262 |
+
num_errors += 1
|
263 |
+
print(f"Error processing batch: {e}")
|
264 |
+
continue
|
265 |
+
|
266 |
+
if num_batches > 0:
|
267 |
+
avg_loss = epoch_loss / num_batches
|
268 |
+
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}, Errors: {num_errors}")
|
269 |
+
else:
|
270 |
+
print(f"Epoch {epoch+1}/{num_epochs}, No valid batches processed, Errors: {num_errors}")
|
271 |
|
|
|
|
|
272 |
jax.clear_caches()
|
273 |
|
274 |
# Save the fine-tuned model
|
275 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
276 |
os.makedirs(output_dir, exist_ok=True)
|
277 |
+
adjusted_unet.save_pretrained(output_dir, params=state.params["params"])
|
278 |
|
279 |
print(f"Model saved to {output_dir}")
|