Update app.py
Browse files
app.py
CHANGED
@@ -113,7 +113,7 @@ def train_step(state, batch, rng):
|
|
113 |
batch_size = pixel_values.shape[0]
|
114 |
|
115 |
# Generate random noise
|
116 |
-
noise_rng, timestep_rng = jax.random.split(rng)
|
117 |
noise = jax.random.normal(noise_rng, pixel_values.shape)
|
118 |
|
119 |
# Sample random timesteps
|
@@ -132,11 +132,15 @@ def train_step(state, batch, rng):
|
|
132 |
timesteps=timesteps
|
133 |
)
|
134 |
|
|
|
|
|
|
|
135 |
# Predict noise
|
136 |
model_output = state.apply_fn.apply(
|
137 |
{'params': params},
|
138 |
jnp.array(noisy_images),
|
139 |
jnp.array(timesteps),
|
|
|
140 |
train=True,
|
141 |
)
|
142 |
|
@@ -173,7 +177,6 @@ for epoch in range(num_epochs):
|
|
173 |
avg_loss = epoch_loss / num_batches
|
174 |
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
175 |
|
176 |
-
|
177 |
# Save the fine-tuned model
|
178 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
179 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
113 |
batch_size = pixel_values.shape[0]
|
114 |
|
115 |
# Generate random noise
|
116 |
+
noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
|
117 |
noise = jax.random.normal(noise_rng, pixel_values.shape)
|
118 |
|
119 |
# Sample random timesteps
|
|
|
132 |
timesteps=timesteps
|
133 |
)
|
134 |
|
135 |
+
# Generate random placeholder conditioning (sized by the text encoder's hidden_size), passed below as encoder_hidden_states
|
136 |
+
latents = jax.random.normal(latents_rng, (batch_size, pipeline.text_encoder.config.hidden_size))
|
137 |
+
|
138 |
# Predict noise
|
139 |
model_output = state.apply_fn.apply(
|
140 |
{'params': params},
|
141 |
jnp.array(noisy_images),
|
142 |
jnp.array(timesteps),
|
143 |
+
encoder_hidden_states=latents,
|
144 |
train=True,
|
145 |
)
|
146 |
|
|
|
177 |
avg_loss = epoch_loss / num_batches
|
178 |
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
179 |
|
|
|
180 |
# Save the fine-tuned model
|
181 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
182 |
os.makedirs(output_dir, exist_ok=True)
|