Update app.py
Browse files
app.py
CHANGED
@@ -122,7 +122,8 @@ if len(sample_batch['pixel_values']) > 0:
|
|
122 |
def train_step(state, batch, rng):
|
123 |
def compute_loss(params, pixel_values, rng):
|
124 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
125 |
-
pixel_values
|
|
|
126 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
127 |
|
128 |
latents = pipeline.vae.apply(
|
|
|
122 |
def train_step(state, batch, rng):
|
123 |
def compute_loss(params, pixel_values, rng):
|
124 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
125 |
+
if pixel_values.ndim == 3:
|
126 |
+
pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension if needed
|
127 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
128 |
|
129 |
latents = pipeline.vae.apply(
|