uruguayai committed on
Commit
399bb13
·
verified ·
1 Parent(s): 6d5f395

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -238,26 +238,42 @@ num_epochs = 3
238
  batch_size = 1
239
  rng = jax.random.PRNGKey(0)
240
 
 
 
 
 
 
241
  for epoch in range(num_epochs):
242
  epoch_loss = 0
243
  num_batches = 0
 
244
  for batch in tqdm(processed_dataset.batch(batch_size)):
245
- batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
246
- rng, step_rng = jax.random.split(rng)
247
- state, loss = train_step(state, batch, step_rng)
248
- epoch_loss += loss
249
- num_batches += 1
250
-
251
- if num_batches % 10 == 0:
252
- jax.clear_caches()
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- avg_loss = epoch_loss / num_batches
255
- print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
256
  jax.clear_caches()
257
 
258
  # Save the fine-tuned model
259
  output_dir = "/tmp/montevideo_fine_tuned_model"
260
  os.makedirs(output_dir, exist_ok=True)
261
- adjusted_unet.save_pretrained(output_dir, params=state.params)
262
 
263
  print(f"Model saved to {output_dir}")
 
238
  batch_size = 1
239
  rng = jax.random.PRNGKey(0)
240
 
241
+ # Training loop
242
+ num_epochs = 3
243
+ batch_size = 1
244
+ rng = jax.random.PRNGKey(0)
245
+
246
  for epoch in range(num_epochs):
247
  epoch_loss = 0
248
  num_batches = 0
249
+ num_errors = 0
250
  for batch in tqdm(processed_dataset.batch(batch_size)):
251
+ try:
252
+ batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
253
+ rng, step_rng = jax.random.split(rng)
254
+ state, loss = train_step(state, batch, step_rng)
255
+ epoch_loss += loss
256
+ num_batches += 1
257
+
258
+ if num_batches % 10 == 0:
259
+ jax.clear_caches()
260
+ print(f"Processed {num_batches} batches. Current loss: {loss}")
261
+ except Exception as e:
262
+ num_errors += 1
263
+ print(f"Error processing batch: {e}")
264
+ continue
265
+
266
+ if num_batches > 0:
267
+ avg_loss = epoch_loss / num_batches
268
+ print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}, Errors: {num_errors}")
269
+ else:
270
+ print(f"Epoch {epoch+1}/{num_epochs}, No valid batches processed, Errors: {num_errors}")
271
 
 
 
272
  jax.clear_caches()
273
 
274
  # Save the fine-tuned model
275
  output_dir = "/tmp/montevideo_fine_tuned_model"
276
  os.makedirs(output_dir, exist_ok=True)
277
+ adjusted_unet.save_pretrained(output_dir, params=state.params["params"])
278
 
279
  print(f"Model saved to {output_dir}")