uruguayai commited on
Commit
5f8640f
·
verified ·
1 Parent(s): ad3bf4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -83,7 +83,7 @@ def preprocess_images(examples):
83
  return None
84
  image = image.convert("RGB").resize((512, 512))
85
  image = np.array(image).astype(np.float32) / 255.0
86
- return image.transpose(2, 0, 1)
87
 
88
  processed = [process_image(img) for img in examples["image"]]
89
  return {"pixel_values": [img for img in processed if img is not None]}
@@ -121,6 +121,7 @@ if len(sample_batch['pixel_values']) > 0:
121
  def train_step(state, batch, rng):
122
  def compute_loss(params, pixel_values, rng):
123
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
 
124
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
125
 
126
  latents = pipeline.vae.apply(
@@ -176,7 +177,7 @@ def train_step(state, batch, rng):
176
  # Initialize training state
177
  learning_rate = 1e-5
178
  optimizer = optax.adam(learning_rate)
179
- float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
180
  state = train_state.TrainState.create(
181
  apply_fn=unet.apply,
182
  params=float32_params,
 
83
  return None
84
  image = image.convert("RGB").resize((512, 512))
85
  image = np.array(image).astype(np.float32) / 255.0
86
+ return image
87
 
88
  processed = [process_image(img) for img in examples["image"]]
89
  return {"pixel_values": [img for img in processed if img is not None]}
 
121
  def train_step(state, batch, rng):
122
  def compute_loss(params, pixel_values, rng):
123
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
124
+ pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension
125
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
126
 
127
  latents = pipeline.vae.apply(
 
177
  # Initialize training state
178
  learning_rate = 1e-5
179
  optimizer = optax.adam(learning_rate)
180
+ float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
181
  state = train_state.TrainState.create(
182
  apply_fn=unet.apply,
183
  params=float32_params,