Update app.py

File changed: app.py

@@ -10,6 +10,9 @@ import os
 import pickle
 from PIL import Image
 import numpy as np
+
+
+
 
 # Set up cache directories
 cache_dir = "/tmp/huggingface_cache"
@@ -120,9 +123,13 @@ def train_step(state, batch, rng):
     batch_size = pixel_values.shape[0]
 
     # Encode images to latent space
-    latents = pipeline.vae.
-
-
+    latents = pipeline.vae.apply(
+        {"params": params["vae"]},
+        pixel_values,
+        method=pipeline.vae.encode
+    ).latent_dist.sample(rng)
+    latents = latents * 0.18215  # scaling factor
+
     # Generate random noise
     noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
     noise = jax.random.normal(noise_rng, latents.shape)
@@ -148,7 +155,7 @@ def train_step(state, batch, rng):
 
     # Predict noise
     model_output = state.apply_fn.apply(
-        {'params': params},
+        {'params': params["unet"]},
        jnp.array(noisy_latents),
        jnp.array(timesteps),
        encoder_hidden_states=encoder_hidden_states,
@@ -168,7 +175,7 @@ learning_rate = 1e-5
 optimizer = optax.adam(learning_rate)
 state = train_state.TrainState.create(
     apply_fn=unet,
-    params=params["unet"], #
+    params={"unet": params["unet"], "vae": params["vae"]}, # Include both UNet and VAE params
     tx=optimizer,
 )
 
@@ -191,6 +198,7 @@ for epoch in range(num_epochs):
     print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
 
 
+
 # Save the fine-tuned model
 output_dir = "/tmp/montevideo_fine_tuned_model"
 os.makedirs(output_dir, exist_ok=True)