uruguayai commited on
Commit
06b9137
·
verified ·
1 Parent(s): 762155a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -56,7 +56,7 @@ def preprocess_images(examples):
56
  image = image.convert("RGB").resize((512, 512))
57
  # Convert to numpy array and normalize
58
  image = np.array(image).astype(np.float32) / 127.5 - 1.0
59
- # Ensure the image has the shape (height, width, 3)
60
  return image.transpose(2, 0, 1) # Change to channel-first format
61
 
62
  return {"pixel_values": [process_image(img) for img in examples["image"]]}
@@ -178,12 +178,22 @@ print(processed_dataset)
178
  print("First batch:")
179
  first_batch = next(iter(processed_dataset.batch(batch_size)))
180
  print(f"Batch keys: {first_batch.keys()}")
181
- print(f"Pixel values shape: {first_batch['pixel_values'].shape}")
 
 
 
 
 
 
 
 
182
 
183
  for epoch in range(num_epochs):
184
  epoch_loss = 0
185
  num_batches = 0
186
  for batch in tqdm(processed_dataset.batch(batch_size)):
 
 
187
  rng, step_rng = jax.random.split(rng)
188
  state, loss = train_step(state, batch, step_rng)
189
  epoch_loss += loss
@@ -192,6 +202,7 @@ for epoch in range(num_epochs):
192
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
193
 
194
 
 
195
  # Save the fine-tuned model
196
  output_dir = "/tmp/montevideo_fine_tuned_model"
197
  os.makedirs(output_dir, exist_ok=True)
 
56
  image = image.convert("RGB").resize((512, 512))
57
  # Convert to numpy array and normalize
58
  image = np.array(image).astype(np.float32) / 127.5 - 1.0
59
+ # Ensure the image has the shape (3, height, width)
60
  return image.transpose(2, 0, 1) # Change to channel-first format
61
 
62
  return {"pixel_values": [process_image(img) for img in examples["image"]]}
 
178
  print("First batch:")
179
  first_batch = next(iter(processed_dataset.batch(batch_size)))
180
  print(f"Batch keys: {first_batch.keys()}")
181
+ print(f"Type of pixel_values: {type(first_batch['pixel_values'])}")
182
+ if isinstance(first_batch['pixel_values'], list):
183
+ print(f"Length of pixel_values list: {len(first_batch['pixel_values'])}")
184
+ if len(first_batch['pixel_values']) > 0:
185
+ print(f"Shape of first item in pixel_values: {np.array(first_batch['pixel_values'][0]).shape}")
186
+
187
+ # Convert the list of pixel values to a numpy array
188
+ first_batch['pixel_values'] = np.array(first_batch['pixel_values'])
189
+ print(f"Pixel values shape after conversion: {first_batch['pixel_values'].shape}")
190
 
191
  for epoch in range(num_epochs):
192
  epoch_loss = 0
193
  num_batches = 0
194
  for batch in tqdm(processed_dataset.batch(batch_size)):
195
+ # Convert the list of pixel values to a numpy array for each batch
196
+ batch['pixel_values'] = np.array(batch['pixel_values'])
197
  rng, step_rng = jax.random.split(rng)
198
  state, loss = train_step(state, batch, step_rng)
199
  epoch_loss += loss
 
202
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
203
 
204
 
205
+
206
  # Save the fine-tuned model
207
  output_dir = "/tmp/montevideo_fine_tuned_model"
208
  os.makedirs(output_dir, exist_ok=True)