uruguayai commited on
Commit
dacaf33
·
verified ·
1 Parent(s): 4a48f70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -28
app.py CHANGED
@@ -129,13 +129,14 @@ def train_step(state, batch, rng):
129
  dtype=jnp.float32
130
  )
131
 
 
132
  model_output = unet.apply(
133
  {'params': params["unet"]},
134
- noisy_latents,
135
- timesteps,
136
  encoder_hidden_states,
137
  train=True,
138
- )
139
 
140
  return jnp.mean((model_output - noise) ** 2)
141
 
@@ -160,28 +161,4 @@ state = train_state.TrainState.create(
160
  # Training loop
161
  num_epochs = 3
162
  batch_size = 1
163
- rng = jax.random.PRNGKey(0)
164
-
165
- for epoch in range(num_epochs):
166
- epoch_loss = 0
167
- num_batches = 0
168
- for batch in tqdm(processed_dataset.batch(batch_size)):
169
- batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
170
- rng, step_rng = jax.random.split(rng)
171
- state, loss = train_step(state, batch, step_rng)
172
- epoch_loss += loss
173
- num_batches += 1
174
-
175
- if num_batches % 10 == 0:
176
- jax.clear_caches()
177
-
178
- avg_loss = epoch_loss / num_batches
179
- print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
180
- jax.clear_caches()
181
-
182
- # Save the fine-tuned model
183
- output_dir = "/tmp/montevideo_fine_tuned_model"
184
- os.makedirs(output_dir, exist_ok=True)
185
- unet.save_pretrained(output_dir, params=state.params["unet"])
186
-
187
- print(f"Model saved to {output_dir}")
 
129
  dtype=jnp.float32
130
  )
131
 
132
+ # Use the correct method to call the UNet
133
  model_output = unet.apply(
134
  {'params': params["unet"]},
135
+ jnp.array(noisy_latents),
136
+ jnp.array(timesteps, dtype=jnp.int32),
137
  encoder_hidden_states,
138
  train=True,
139
+ ).sample
140
 
141
  return jnp.mean((model_output - noise) ** 2)
142
 
 
161
  # Training loop
162
  num_epochs = 3
163
  batch_size = 1
164
+ rng = jax.random.PRNGKey(