uruguayai commited on
Commit
4434e29
·
verified ·
1 Parent(s): 2cee4c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -57,7 +57,7 @@ def preprocess_images(examples):
57
  image = Image.open(image)
58
  if not isinstance(image, Image.Image):
59
  raise ValueError(f"Unexpected image type: {type(image)}")
60
- image = image.convert("RGB").resize((512, 512)) # Keep original size
61
  image = np.array(image).astype(np.float32) / 255.0
62
  return image.transpose(2, 0, 1)
63
 
@@ -214,7 +214,7 @@ for epoch in range(num_epochs):
214
  epoch_loss = 0
215
  num_batches = 0
216
  for batch in tqdm(processed_dataset.batch(batch_size)):
217
- batch['pixel_values'] = jnp.array(batch['pixel_values'])
218
  rng, step_rng = jax.random.split(rng)
219
  state, loss = train_step(state, batch, step_rng)
220
  epoch_loss += loss
 
57
  image = Image.open(image)
58
  if not isinstance(image, Image.Image):
59
  raise ValueError(f"Unexpected image type: {type(image)}")
60
+ image = image.convert("RGB").resize((512, 512))
61
  image = np.array(image).astype(np.float32) / 255.0
62
  return image.transpose(2, 0, 1)
63
 
 
214
  epoch_loss = 0
215
  num_batches = 0
216
  for batch in tqdm(processed_dataset.batch(batch_size)):
217
+ batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
218
  rng, step_rng = jax.random.split(rng)
219
  state, loss = train_step(state, batch, step_rng)
220
  epoch_loss += loss