Update app.py
Browse files
app.py
CHANGED
@@ -129,13 +129,14 @@ def train_step(state, batch, rng):
|
|
129 |
dtype=jnp.float32
|
130 |
)
|
131 |
|
|
|
132 |
model_output = unet.apply(
|
133 |
{'params': params["unet"]},
|
134 |
-
noisy_latents,
|
135 |
-
timesteps,
|
136 |
encoder_hidden_states,
|
137 |
train=True,
|
138 |
-
)
|
139 |
|
140 |
return jnp.mean((model_output - noise) ** 2)
|
141 |
|
@@ -160,28 +161,4 @@ state = train_state.TrainState.create(
|
|
160 |
# Training loop
|
161 |
num_epochs = 3
|
162 |
batch_size = 1
|
163 |
-
rng = jax.random.PRNGKey(
|
164 |
-
|
165 |
-
for epoch in range(num_epochs):
|
166 |
-
epoch_loss = 0
|
167 |
-
num_batches = 0
|
168 |
-
for batch in tqdm(processed_dataset.batch(batch_size)):
|
169 |
-
batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
|
170 |
-
rng, step_rng = jax.random.split(rng)
|
171 |
-
state, loss = train_step(state, batch, step_rng)
|
172 |
-
epoch_loss += loss
|
173 |
-
num_batches += 1
|
174 |
-
|
175 |
-
if num_batches % 10 == 0:
|
176 |
-
jax.clear_caches()
|
177 |
-
|
178 |
-
avg_loss = epoch_loss / num_batches
|
179 |
-
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
180 |
-
jax.clear_caches()
|
181 |
-
|
182 |
-
# Save the fine-tuned model
|
183 |
-
output_dir = "/tmp/montevideo_fine_tuned_model"
|
184 |
-
os.makedirs(output_dir, exist_ok=True)
|
185 |
-
unet.save_pretrained(output_dir, params=state.params["unet"])
|
186 |
-
|
187 |
-
print(f"Model saved to {output_dir}")
|
|
|
129 |
dtype=jnp.float32
|
130 |
)
|
131 |
|
132 |
+
# Use the correct method to call the UNet
|
133 |
model_output = unet.apply(
|
134 |
{'params': params["unet"]},
|
135 |
+
jnp.array(noisy_latents),
|
136 |
+
jnp.array(timesteps, dtype=jnp.int32),
|
137 |
encoder_hidden_states,
|
138 |
train=True,
|
139 |
+
).sample
|
140 |
|
141 |
return jnp.mean((model_output - noise) ** 2)
|
142 |
|
|
|
161 |
# Training loop
|
162 |
num_epochs = 3
|
163 |
batch_size = 1
|
164 |
+
rng = jax.random.PRNGKey(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|