uruguayai committed on
Commit
157fd62
·
verified ·
1 Parent(s): 60180ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -83
app.py CHANGED
@@ -122,6 +122,7 @@ if len(sample_batch['pixel_values']) > 0:
122
  def train_step(state, batch, rng):
123
  def compute_loss(params, pixel_values, rng):
124
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
 
125
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
126
 
127
  latents = pipeline.vae.apply(
@@ -129,86 +130,4 @@ def train_step(state, batch, rng):
129
  pixel_values,
130
  method=pipeline.vae.encode
131
  ).latent_dist.sample(rng)
132
- latents = latents * jnp.float32(0.18215)
133
- print(f"latents shape: {latents.shape}")
134
-
135
- noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
136
-
137
- timesteps = jax.random.randint(
138
- rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
139
- )
140
-
141
- noisy_latents = pipeline.scheduler.add_noise(
142
- pipeline.scheduler.create_state(),
143
- original_samples=latents,
144
- noise=noise,
145
- timesteps=timesteps
146
- )
147
-
148
- encoder_hidden_states = jax.random.normal(
149
- rng,
150
- (latents.shape[0], pipeline.text_encoder.config.hidden_size),
151
- dtype=jnp.float32
152
- )
153
-
154
- print(f"noisy_latents shape: {noisy_latents.shape}")
155
- print(f"timesteps shape: {timesteps.shape}")
156
- print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
157
-
158
- # Use the correct method to call the UNet
159
- model_output = unet.apply(
160
- {'params': params["unet"]},
161
- noisy_latents,
162
- jnp.array(timesteps, dtype=jnp.int32),
163
- encoder_hidden_states,
164
- train=True,
165
- ).sample
166
-
167
- return jnp.mean((model_output - noise) ** 2)
168
-
169
- grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
170
- rng, step_rng = jax.random.split(rng)
171
-
172
- grads = grad_fn(state.params, batch["pixel_values"], step_rng)
173
- loss = compute_loss(state.params, batch["pixel_values"], step_rng)
174
- state = state.apply_gradients(grads=grads)
175
- return state, loss
176
-
177
- # Initialize training state
178
- learning_rate = 1e-5
179
- optimizer = optax.adam(learning_rate)
180
- float32_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
181
- state = train_state.TrainState.create(
182
- apply_fn=unet.apply,
183
- params=float32_params,
184
- tx=optimizer,
185
- )
186
-
187
- # Training loop
188
- num_epochs = 3
189
- batch_size = 1
190
- rng = jax.random.PRNGKey(0)
191
-
192
- for epoch in range(num_epochs):
193
- epoch_loss = 0
194
- num_batches = 0
195
- for batch in tqdm(processed_dataset.batch(batch_size)):
196
- batch['pixel_values'] = jnp.array(batch['pixel_values'][0], dtype=jnp.float32)
197
- rng, step_rng = jax.random.split(rng)
198
- state, loss = train_step(state, batch, step_rng)
199
- epoch_loss += loss
200
- num_batches += 1
201
-
202
- if num_batches % 10 == 0:
203
- jax.clear_caches()
204
-
205
- avg_loss = epoch_loss / num_batches
206
- print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
207
- jax.clear_caches()
208
-
209
- # Save the fine-tuned model
210
- output_dir = "/tmp/montevideo_fine_tuned_model"
211
- os.makedirs(output_dir, exist_ok=True)
212
- unet.save_pretrained(output_dir, params=state.params["unet"])
213
-
214
- print(f"Model saved to {output_dir}")
 
122
  def train_step(state, batch, rng):
123
  def compute_loss(params, pixel_values, rng):
124
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
125
+ pixel_values = jnp.expand_dims(pixel_values, axis=0) # Add batch dimension
126
  print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
127
 
128
  latents = pipeline.vae.apply(
 
130
  pixel_values,
131
  method=pipeline.vae.encode
132
  ).latent_dist.sample(rng)
133
+ latents = latents *