uruguayai commited on
Commit
60180ea
·
verified ·
1 Parent(s): 5f8640f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -83,6 +83,7 @@ def preprocess_images(examples):
83
  return None
84
  image = image.convert("RGB").resize((512, 512))
85
  image = np.array(image).astype(np.float32) / 255.0
 
86
  return image
87
 
88
  processed = [process_image(img) for img in examples["image"]]
@@ -121,7 +122,6 @@ if len(sample_batch['pixel_values']) > 0:
121
  def train_step(state, batch, rng):
122
  def compute_loss(params, pixel_values, rng):
123
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
124
- pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension
125
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
126
 
127
  latents = pipeline.vae.apply(
 
83
  return None
84
  image = image.convert("RGB").resize((512, 512))
85
  image = np.array(image).astype(np.float32) / 255.0
86
+ image = image.transpose(2, 0, 1) # Change to channel-first format
87
  return image
88
 
89
  processed = [process_image(img) for img in examples["image"]]
 
122
  def train_step(state, batch, rng):
123
  def compute_loss(params, pixel_values, rng):
124
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
 
125
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
126
 
127
  latents = pipeline.vae.apply(