uruguayai committed on
Commit
2cee4c3
·
verified ·
1 Parent(s): 8835824

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -4
app.py CHANGED
@@ -149,14 +149,65 @@ def train_step(state, batch, rng):
149
  learning_rate = 1e-5
150
  optimizer = optax.adam(learning_rate)
151
  state = train_state.TrainState.create(
152
- apply_fn=unet.__call__,
153
- params={"unet": params["unet"], "vae": params["vae"]},
154
  tx=optimizer,
155
  )
156
 
157
- # Training loop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  num_epochs = 3
159
- batch_size = 1 # Reduced batch size due to memory constraints
160
  rng = jax.random.PRNGKey(0)
161
 
162
  for epoch in range(num_epochs):
 
149
  learning_rate = 1e-5
150
  optimizer = optax.adam(learning_rate)
151
  state = train_state.TrainState.create(
152
+ apply_fn=unet.__call__, # Use __call__ directly
153
+ params=params, # Pass all params
154
  tx=optimizer,
155
  )
156
 
157
+ # Modify the train_step function
158
def train_step(state, batch, rng):
    """Run one gradient step of the noise-prediction (MSE) objective.

    Args:
        state: Train state. ``state.params`` must contain ``"vae"`` and
            ``"unet"`` entries, and ``state.apply_fn`` is called with a
            ``{'params': ...}`` variables dict as its first argument.
        batch: Mapping with a ``"pixel_values"`` entry holding the images.
        rng: JAX PRNG key for this step.

    Returns:
        A ``(state, loss)`` tuple: the state after ``apply_gradients`` and
        the scalar loss for this batch.
    """

    def compute_loss(params, pixel_values, rng):
        # A JAX key must be consumed only once; reusing the same key for
        # several draws produces correlated samples. Derive one independent
        # sub-key per random quantity used below.
        sample_rng, noise_rng, timestep_rng, hidden_rng = jax.random.split(rng, 4)

        # Encode images to latent space
        latent_dist = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist
        latents = latent_dist.sample(sample_rng)
        latents = latents * 0.18215

        # Generate random noise (the regression target)
        noise = jax.random.normal(noise_rng, latents.shape)

        # Sample one random timestep per example in the batch
        timesteps = jax.random.randint(
            timestep_rng,
            (latents.shape[0],),
            0,
            pipeline.scheduler.config.num_train_timesteps,
        )

        # Add noise to latents
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Generate random encoder hidden states (simulating text embeddings)
        # NOTE(review): this tensor is (batch, hidden_size) with no sequence
        # axis; conditional UNets usually expect (batch, seq_len, hidden_size)
        # -- confirm against the UNet in use.
        encoder_hidden_states = jax.random.normal(
            hidden_rng,
            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
        )

        # Predict noise
        # NOTE(review): if apply_fn returns an output object rather than an
        # array, the prediction may need to be read from `.sample` -- verify.
        model_output = state.apply_fn(
            {"params": params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )

        # Compute loss: mean squared error between predicted and true noise
        return jnp.mean((model_output - noise) ** 2)

    grad_fn = jax.value_and_grad(compute_loss)
    # Use a fresh sub-key for this step; the remaining key is not returned.
    rng, step_rng = jax.random.split(rng)
    loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)
    return state, loss
207
+
208
+ # Training loop (remains the same)
209
  num_epochs = 3
210
+ batch_size = 1
211
  rng = jax.random.PRNGKey(0)
212
 
213
  for epoch in range(num_epochs):