Update app.py
Browse files
app.py
CHANGED
@@ -83,6 +83,7 @@ def preprocess_images(examples):
|
|
83 |
return None
|
84 |
image = image.convert("RGB").resize((512, 512))
|
85 |
image = np.array(image).astype(np.float32) / 255.0
|
|
|
86 |
return image
|
87 |
|
88 |
processed = [process_image(img) for img in examples["image"]]
|
@@ -121,7 +122,6 @@ if len(sample_batch['pixel_values']) > 0:
|
|
121 |
def train_step(state, batch, rng):
|
122 |
def compute_loss(params, pixel_values, rng):
|
123 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
124 |
-
pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension
|
125 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
126 |
|
127 |
latents = pipeline.vae.apply(
|
|
|
83 |
return None
|
84 |
image = image.convert("RGB").resize((512, 512))
|
85 |
image = np.array(image).astype(np.float32) / 255.0
|
86 |
+
image = image.transpose(2, 0, 1) # Change to channel-first format
|
87 |
return image
|
88 |
|
89 |
processed = [process_image(img) for img in examples["image"]]
|
|
|
122 |
def train_step(state, batch, rng):
|
123 |
def compute_loss(params, pixel_values, rng):
|
124 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
|
|
125 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
126 |
|
127 |
latents = pipeline.vae.apply(
|