Update app.py
Browse files
app.py
CHANGED
@@ -83,7 +83,7 @@ def preprocess_images(examples):
|
|
83 |
return None
|
84 |
image = image.convert("RGB").resize((512, 512))
|
85 |
image = np.array(image).astype(np.float32) / 255.0
|
86 |
-
return image
|
87 |
|
88 |
processed = [process_image(img) for img in examples["image"]]
|
89 |
return {"pixel_values": [img for img in processed if img is not None]}
|
@@ -121,6 +121,7 @@ if len(sample_batch['pixel_values']) > 0:
|
|
121 |
def train_step(state, batch, rng):
|
122 |
def compute_loss(params, pixel_values, rng):
|
123 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
|
|
124 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
125 |
|
126 |
latents = pipeline.vae.apply(
|
@@ -176,7 +177,7 @@ def train_step(state, batch, rng):
|
|
176 |
# Initialize training state
|
177 |
learning_rate = 1e-5
|
178 |
optimizer = optax.adam(learning_rate)
|
179 |
-
float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
180 |
state = train_state.TrainState.create(
|
181 |
apply_fn=unet.apply,
|
182 |
params=float32_params,
|
|
|
83 |
return None
|
84 |
image = image.convert("RGB").resize((512, 512))
|
85 |
image = np.array(image).astype(np.float32) / 255.0
|
86 |
+
return image
|
87 |
|
88 |
processed = [process_image(img) for img in examples["image"]]
|
89 |
return {"pixel_values": [img for img in processed if img is not None]}
|
|
|
121 |
def train_step(state, batch, rng):
|
122 |
def compute_loss(params, pixel_values, rng):
|
123 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
124 |
+
pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension
|
125 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
126 |
|
127 |
latents = pipeline.vae.apply(
|
|
|
177 |
# Initialize training state
|
178 |
learning_rate = 1e-5
|
179 |
optimizer = optax.adam(learning_rate)
|
180 |
+
float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
181 |
state = train_state.TrainState.create(
|
182 |
apply_fn=unet.apply,
|
183 |
params=float32_params,
|