uruguayai commited on
Commit
b2ad618
·
verified ·
1 Parent(s): 1bbb97c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -124,21 +124,21 @@ if len(sample_batch['pixel_values']) > 0:
124
 
125
  # Training function
126
  def train_step(state, batch, rng):
127
- def compute_loss(params, pixel_values, rng):
128
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
129
  if pixel_values.ndim == 3:
130
  pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension if needed
131
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
132
 
133
- # Use the pipeline's VAE directly
134
  latents = pipeline.vae.apply(
135
- {"params": pipeline.params["vae"]},
136
  pixel_values,
137
  method=pipeline.vae.encode
138
  ).latent_dist.sample(rng)
139
  latents = latents * jnp.float32(0.18215)
140
  print(f"latents shape: {latents.shape}")
141
-
142
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
143
 
144
  timesteps = jax.random.randint(
@@ -164,7 +164,7 @@ def train_step(state, batch, rng):
164
 
165
  # Use the state's apply_fn (which should be the adjusted UNet)
166
  model_output = state.apply_fn(
167
- {'params': params},
168
  noisy_latents,
169
  jnp.array(timesteps, dtype=jnp.int32),
170
  encoder_hidden_states,
@@ -172,7 +172,7 @@ def train_step(state, batch, rng):
172
  ).sample
173
 
174
  return jnp.mean((model_output - noise) ** 2)
175
-
176
  grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
177
  rng, step_rng = jax.random.split(rng)
178
 
 
124
 
125
  # Training function
126
  def train_step(state, batch, rng):
127
+ def compute_loss(unet_params, pixel_values, rng):
128
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
129
  if pixel_values.ndim == 3:
130
  pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension if needed
131
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
132
 
133
+ # Use the VAE from params
134
  latents = pipeline.vae.apply(
135
+ {"params": params["vae"]},
136
  pixel_values,
137
  method=pipeline.vae.encode
138
  ).latent_dist.sample(rng)
139
  latents = latents * jnp.float32(0.18215)
140
  print(f"latents shape: {latents.shape}")
141
+
142
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
143
 
144
  timesteps = jax.random.randint(
 
164
 
165
  # Use the state's apply_fn (which should be the adjusted UNet)
166
  model_output = state.apply_fn(
167
+ {'params': unet_params},
168
  noisy_latents,
169
  jnp.array(timesteps, dtype=jnp.int32),
170
  encoder_hidden_states,
 
172
  ).sample
173
 
174
  return jnp.mean((model_output - noise) ** 2)
175
+
176
  grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
177
  rng, step_rng = jax.random.split(rng)
178