uruguayai committed
Commit 1bbb97c · verified · 1 parent: 8a49030

Update app.py

Files changed (1): app.py (+4 −3)
app.py CHANGED
@@ -130,8 +130,9 @@ def train_step(state, batch, rng):
     pixel_values = jnp.expand_dims(pixel_values, axis=0)  # Add batch dimension if needed
     print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
 
+    # Use the pipeline's VAE directly
     latents = pipeline.vae.apply(
-        {"params": params["vae"]},
+        {"params": pipeline.params["vae"]},
         pixel_values,
         method=pipeline.vae.encode
     ).latent_dist.sample(rng)
@@ -161,9 +162,9 @@ def train_step(state, batch, rng):
     print(f"timesteps shape: {timesteps.shape}")
     print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
 
-    # Use the correct method to call the UNet
+    # Use the state's apply_fn (which should be the adjusted UNet)
     model_output = state.apply_fn(
-        {'params': params["unet"]},
+        {'params': params},
         noisy_latents,
         jnp.array(timesteps, dtype=jnp.int32),
         encoder_hidden_states,
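
In short, the commit moves the frozen VAE weights out of the trainable pytree: the VAE encode call now reads from the pipeline's own parameters (pipeline.params["vae"]), and state.apply_fn, which is the UNet's apply function, receives the whole trainable params pytree rather than a params["unet"] sub-dict. Below is a minimal sketch of how the surrounding setup could make both hunks consistent; it assumes a diffusers FlaxStableDiffusionPipeline, that the loaded parameters are attached to the pipeline object, and an illustrative model id and learning rate, none of which are taken from app.py.

import optax
from flax.training import train_state
from diffusers import FlaxStableDiffusionPipeline

# Load the pipeline together with its parameter pytree.
# Model id and revision are assumptions for illustration.
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="flax"
)

# The diff's pipeline.params["vae"] access implies the loaded parameters
# were attached to the pipeline object at some point, e.g.:
pipeline.params = pipeline_params

# Build the train state around the UNet alone. Because apply_fn is
# pipeline.unet.apply, train_step can call
# state.apply_fn({'params': params}, ...) with the full trainable pytree,
# while the frozen VAE weights stay in pipeline.params and never enter
# the optimizer.
state = train_state.TrainState.create(
    apply_fn=pipeline.unet.apply,
    params=pipeline_params["unet"],
    tx=optax.adamw(learning_rate=1e-5),  # learning rate is an assumption
)

Keeping the VAE out of the optimizer state both saves memory and makes the params argument of train_step unambiguous: it is exactly the pytree that state.apply_fn expects.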