Update app.py
app.py
CHANGED
```diff
@@ -124,21 +124,21 @@ if len(sample_batch['pixel_values']) > 0:
 
 # Training function
 def train_step(state, batch, rng):
-    def compute_loss(
+    def compute_loss(unet_params, pixel_values, rng):
         pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
         if pixel_values.ndim == 3:
             pixel_values = jnp.expand_dims(pixel_values, axis=0)  # Add batch dimension if needed
         print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
 
-        # Use the
+        # Use the VAE from params
         latents = pipeline.vae.apply(
-            {"params":
+            {"params": params["vae"]},
             pixel_values,
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
         latents = latents * jnp.float32(0.18215)
         print(f"latents shape: {latents.shape}")
-
+
         noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
 
         timesteps = jax.random.randint(
@@ -164,7 +164,7 @@ def train_step(state, batch, rng):
 
         # Use the state's apply_fn (which should be the adjusted UNet)
         model_output = state.apply_fn(
-            {'params':
+            {'params': unet_params},
             noisy_latents,
             jnp.array(timesteps, dtype=jnp.int32),
             encoder_hidden_states,
@@ -172,7 +172,7 @@ def train_step(state, batch, rng):
         ).sample
 
         return jnp.mean((model_output - noise) ** 2)
-
+
     grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
     rng, step_rng = jax.random.split(rng)
 
```
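For readability, the new side of the three hunks assembles into roughly the following shape (debug `print`s omitted). This is a hedged reconstruction, not the file's verbatim contents: lines 145–163 of `app.py` are elided from the diff, the `timesteps` arguments are an assumption, and `pipeline` and `params` are free variables defined elsewhere in the module.

```python
import jax
import jax.numpy as jnp

def train_step(state, batch, rng):
    # `pipeline` (a Flax Stable Diffusion pipeline) and `params` (its parameter
    # dict) are free variables from module scope; note the new code reads the
    # VAE weights from params["vae"], not from a compute_loss argument.
    def compute_loss(unet_params, pixel_values, rng):
        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
        if pixel_values.ndim == 3:
            pixel_values = jnp.expand_dims(pixel_values, axis=0)  # add batch dim

        # Encode images with the frozen VAE, then apply the standard Stable
        # Diffusion latent scaling factor 0.18215.
        latents = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist.sample(rng)
        latents = latents * jnp.float32(0.18215)

        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
        timesteps = jax.random.randint(
            rng, (latents.shape[0],), 0, 1000  # assumption: args elided in diff
        )

        # Lines 145-163 are elided in the diff; they evidently produce the
        # noisy_latents and encoder_hidden_states used below.

        # unet_params is the FIRST argument, so jax.grad(..., argnums=0)
        # differentiates the loss w.r.t. the UNet weights only.
        model_output = state.apply_fn(
            {"params": unet_params},
            noisy_latents,
            jnp.array(timesteps, dtype=jnp.int32),
            encoder_hidden_states,
        ).sample

        return jnp.mean((model_output - noise) ** 2)
```

Worth noticing: `params["vae"]` is read from the enclosing scope rather than passed in, so the VAE weights never appear in the differentiated argument.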
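The last hunk shows why the signature change matters: `jax.grad(compute_loss, argnums=0, allow_int=True)` differentiates with respect to the first positional argument, so the gradient now flows to `unet_params` while the VAE stays frozen (`allow_int=True` additionally tolerates integer-dtype leaves in that pytree, for which JAX returns `float0` placeholders). The diff ends at line 178, so how the gradient is consumed is not shown; below is a minimal continuation sketch under the assumption that `state` is a `flax.training.train_state.TrainState` and `batch` carries `pixel_values`.

```python
    # Continuation of train_step; everything past line 178 is an assumption,
    # not code from this commit.
    grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
    rng, step_rng = jax.random.split(rng)

    # Hypothetical: differentiate w.r.t. the UNet params held in the TrainState
    # and apply the optimizer update.
    grads = grad_fn(state.params, batch["pixel_values"], step_rng)
    new_state = state.apply_gradients(grads=grads)
    return new_state, rng
```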