Update app.py
Browse files
app.py
CHANGED
@@ -130,8 +130,9 @@ def train_step(state, batch, rng):
|
|
130 |
pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension if needed
|
131 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
132 |
|
|
|
133 |
latents = pipeline.vae.apply(
|
134 |
-
{"params": params["vae"]},
|
135 |
pixel_values,
|
136 |
method=pipeline.vae.encode
|
137 |
).latent_dist.sample(rng)
|
@@ -161,9 +162,9 @@ def train_step(state, batch, rng):
|
|
161 |
print(f"timesteps shape: {timesteps.shape}")
|
162 |
print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
|
163 |
|
164 |
-
# Use the
|
165 |
model_output = state.apply_fn(
|
166 |
-
{'params': params
|
167 |
noisy_latents,
|
168 |
jnp.array(timesteps, dtype=jnp.int32),
|
169 |
encoder_hidden_states,
|
|
|
130 |
pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension if needed
|
131 |
print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
|
132 |
|
133 |
+
# Use the pipeline's VAE directly
|
134 |
latents = pipeline.vae.apply(
|
135 |
+
{"params": pipeline.params["vae"]},
|
136 |
pixel_values,
|
137 |
method=pipeline.vae.encode
|
138 |
).latent_dist.sample(rng)
|
|
|
162 |
print(f"timesteps shape: {timesteps.shape}")
|
163 |
print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
|
164 |
|
165 |
+
# Use the state's apply_fn (which should be the adjusted UNet)
|
166 |
model_output = state.apply_fn(
|
167 |
+
{'params': params},
|
168 |
noisy_latents,
|
169 |
jnp.array(timesteps, dtype=jnp.int32),
|
170 |
encoder_hidden_states,
|