Update app.py
Browse files
app.py
CHANGED
@@ -155,23 +155,27 @@ state = train_state.TrainState.create(
|
|
155 |
)
|
156 |
|
157 |
# Modify the train_step function
|
158 |
-
|
159 |
def compute_loss(params, pixel_values, rng):
|
|
|
|
|
|
|
160 |
# Encode images to latent space
|
161 |
latents = pipeline.vae.apply(
|
162 |
{"params": params["vae"]},
|
163 |
pixel_values,
|
164 |
method=pipeline.vae.encode
|
165 |
).latent_dist.sample(rng)
|
166 |
-
latents = latents * 0.18215
|
167 |
|
168 |
# Generate random noise
|
169 |
-
noise = jax.random.normal(rng, latents.shape)
|
170 |
|
171 |
# Sample random timesteps
|
172 |
timesteps = jax.random.randint(
|
173 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
174 |
)
|
|
|
175 |
|
176 |
# Add noise to latents
|
177 |
noisy_latents = pipeline.scheduler.add_noise(
|
@@ -184,7 +188,8 @@ def train_step(state, batch, rng):
|
|
184 |
# Generate random encoder hidden states (simulating text embeddings)
|
185 |
encoder_hidden_states = jax.random.normal(
|
186 |
rng,
|
187 |
-
(latents.shape[0], pipeline.text_encoder.config.hidden_size)
|
|
|
188 |
)
|
189 |
|
190 |
# Predict noise
|
@@ -199,12 +204,16 @@ def train_step(state, batch, rng):
|
|
199 |
# Compute loss
|
200 |
return jnp.mean((model_output - noise) ** 2)
|
201 |
|
202 |
-
grad_fn = jax.
|
203 |
rng, step_rng = jax.random.split(rng)
|
204 |
-
|
|
|
|
|
205 |
state = state.apply_gradients(grads=grads)
|
206 |
return state, loss
|
207 |
|
|
|
|
|
208 |
# Training loop (remains the same)
|
209 |
num_epochs = 3
|
210 |
batch_size = 1
|
|
|
155 |
)
|
156 |
|
157 |
# Modify the train_step function
|
158 |
+
def train_step(state, batch, rng):
|
159 |
def compute_loss(params, pixel_values, rng):
|
160 |
+
# Ensure pixel_values are float32
|
161 |
+
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
162 |
+
|
163 |
# Encode images to latent space
|
164 |
latents = pipeline.vae.apply(
|
165 |
{"params": params["vae"]},
|
166 |
pixel_values,
|
167 |
method=pipeline.vae.encode
|
168 |
).latent_dist.sample(rng)
|
169 |
+
latents = latents * jnp.float32(0.18215)
|
170 |
|
171 |
# Generate random noise
|
172 |
+
noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
|
173 |
|
174 |
# Sample random timesteps
|
175 |
timesteps = jax.random.randint(
|
176 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
177 |
)
|
178 |
+
timesteps = jnp.array(timesteps, dtype=jnp.float32)
|
179 |
|
180 |
# Add noise to latents
|
181 |
noisy_latents = pipeline.scheduler.add_noise(
|
|
|
188 |
# Generate random encoder hidden states (simulating text embeddings)
|
189 |
encoder_hidden_states = jax.random.normal(
|
190 |
rng,
|
191 |
+
(latents.shape[0], pipeline.text_encoder.config.hidden_size),
|
192 |
+
dtype=jnp.float32
|
193 |
)
|
194 |
|
195 |
# Predict noise
|
|
|
204 |
# Compute loss
|
205 |
return jnp.mean((model_output - noise) ** 2)
|
206 |
|
207 |
+
grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
|
208 |
rng, step_rng = jax.random.split(rng)
|
209 |
+
|
210 |
+
grads = grad_fn(state.params, batch["pixel_values"], step_rng)
|
211 |
+
loss = compute_loss(state.params, batch["pixel_values"], step_rng)
|
212 |
state = state.apply_gradients(grads=grads)
|
213 |
return state, loss
|
214 |
|
215 |
+
|
216 |
+
|
217 |
# Training loop (remains the same)
|
218 |
num_epochs = 3
|
219 |
batch_size = 1
|