Update app.py
Browse files
app.py
CHANGED
@@ -56,7 +56,7 @@ def preprocess_images(examples):
|
|
56 |
image = image.convert("RGB").resize((512, 512))
|
57 |
# Convert to numpy array and normalize
|
58 |
image = np.array(image).astype(np.float32) / 127.5 - 1.0
|
59 |
-
# Ensure the image has the shape (height, width
|
60 |
return image.transpose(2, 0, 1) # Change to channel-first format
|
61 |
|
62 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
@@ -178,12 +178,22 @@ print(processed_dataset)
|
|
178 |
print("First batch:")
|
179 |
first_batch = next(iter(processed_dataset.batch(batch_size)))
|
180 |
print(f"Batch keys: {first_batch.keys()}")
|
181 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
for epoch in range(num_epochs):
|
184 |
epoch_loss = 0
|
185 |
num_batches = 0
|
186 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
|
|
|
|
187 |
rng, step_rng = jax.random.split(rng)
|
188 |
state, loss = train_step(state, batch, step_rng)
|
189 |
epoch_loss += loss
|
@@ -192,6 +202,7 @@ for epoch in range(num_epochs):
|
|
192 |
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
193 |
|
194 |
|
|
|
195 |
# Save the fine-tuned model
|
196 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
197 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
56 |
image = image.convert("RGB").resize((512, 512))
|
57 |
# Convert to numpy array and normalize
|
58 |
image = np.array(image).astype(np.float32) / 127.5 - 1.0
|
59 |
+
# Ensure the image has the shape (3, height, width)
|
60 |
return image.transpose(2, 0, 1) # Change to channel-first format
|
61 |
|
62 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
|
|
178 |
print("First batch:")
|
179 |
first_batch = next(iter(processed_dataset.batch(batch_size)))
|
180 |
print(f"Batch keys: {first_batch.keys()}")
|
181 |
+
print(f"Type of pixel_values: {type(first_batch['pixel_values'])}")
|
182 |
+
if isinstance(first_batch['pixel_values'], list):
|
183 |
+
print(f"Length of pixel_values list: {len(first_batch['pixel_values'])}")
|
184 |
+
if len(first_batch['pixel_values']) > 0:
|
185 |
+
print(f"Shape of first item in pixel_values: {np.array(first_batch['pixel_values'][0]).shape}")
|
186 |
+
|
187 |
+
# Convert the list of pixel values to a numpy array
|
188 |
+
first_batch['pixel_values'] = np.array(first_batch['pixel_values'])
|
189 |
+
print(f"Pixel values shape after conversion: {first_batch['pixel_values'].shape}")
|
190 |
|
191 |
for epoch in range(num_epochs):
|
192 |
epoch_loss = 0
|
193 |
num_batches = 0
|
194 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
195 |
+
# Convert the list of pixel values to a numpy array for each batch
|
196 |
+
batch['pixel_values'] = np.array(batch['pixel_values'])
|
197 |
rng, step_rng = jax.random.split(rng)
|
198 |
state, loss = train_step(state, batch, step_rng)
|
199 |
epoch_loss += loss
|
|
|
202 |
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
203 |
|
204 |
|
205 |
+
|
206 |
# Save the fine-tuned model
|
207 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
208 |
os.makedirs(output_dir, exist_ok=True)
|