Update app.py
Browse files
app.py
CHANGED
@@ -57,7 +57,7 @@ def preprocess_images(examples):
|
|
57 |
image = Image.open(image)
|
58 |
if not isinstance(image, Image.Image):
|
59 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
60 |
-
image = image.convert("RGB").resize((512, 512))
|
61 |
image = np.array(image).astype(np.float32) / 255.0
|
62 |
return image.transpose(2, 0, 1)
|
63 |
|
@@ -214,7 +214,7 @@ for epoch in range(num_epochs):
|
|
214 |
epoch_loss = 0
|
215 |
num_batches = 0
|
216 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
217 |
-
batch['pixel_values'] = jnp.array(batch['pixel_values'])
|
218 |
rng, step_rng = jax.random.split(rng)
|
219 |
state, loss = train_step(state, batch, step_rng)
|
220 |
epoch_loss += loss
|
|
|
57 |
image = Image.open(image)
|
58 |
if not isinstance(image, Image.Image):
|
59 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
60 |
+
image = image.convert("RGB").resize((512, 512))
|
61 |
image = np.array(image).astype(np.float32) / 255.0
|
62 |
return image.transpose(2, 0, 1)
|
63 |
|
|
|
214 |
epoch_loss = 0
|
215 |
num_batches = 0
|
216 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
217 |
+
batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
|
218 |
rng, step_rng = jax.random.split(rng)
|
219 |
state, loss = train_step(state, batch, step_rng)
|
220 |
epoch_loss += loss
|