Update app.py
app.py CHANGED
@@ -97,18 +97,21 @@ def clear_jit_cache():
 # Training function
 def train_step(state, batch, rng):
     def compute_loss(params, pixel_values, rng):
+        # Ensure pixel_values are float32
+        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
+
         # Encode images to latent space
         latents = pipeline.vae.apply(
             {"params": params["vae"]},
             pixel_values,
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
-        latents = latents * 0.18215
+        latents = latents * jnp.float32(0.18215)
 
         # Generate random noise
-        noise = jax.random.normal(rng, latents.shape)
+        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
 
-        # Sample random timesteps
+        # Sample random timesteps (keep as integers)
         timesteps = jax.random.randint(
             rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
         )
@@ -124,11 +127,12 @@ def train_step(state, batch, rng):
         # Generate random encoder hidden states (simulating text embeddings)
         encoder_hidden_states = jax.random.normal(
             rng,
-            (latents.shape[0], pipeline.text_encoder.config.hidden_size)
+            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
+            dtype=jnp.float32
         )
 
         # Predict noise
-        model_output = state.apply_fn
+        model_output = state.apply_fn(
             {'params': params["unet"]},
             noisy_latents,
             timesteps,
@@ -139,9 +143,11 @@ def train_step(state, batch, rng):
         # Compute loss
         return jnp.mean((model_output - noise) ** 2)
 
-    grad_fn = jax.value_and_grad(compute_loss)
+    grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
     rng, step_rng = jax.random.split(rng)
-    loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
+
+    grads = grad_fn(state.params, batch["pixel_values"], step_rng)
+    loss = compute_loss(state.params, batch["pixel_values"], step_rng)
     state = state.apply_gradients(grads=grads)
     return state, loss
 
@@ -236,7 +242,6 @@ for epoch in range(num_epochs):
     print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
     clear_jit_cache()
 
-
 
 # Save the fine-tuned model
 output_dir = "/tmp/montevideo_fine_tuned_model"
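Note on the 0.18215 constant: this is the Stable Diffusion v1 VAE latent scaling factor. A small hardening sketch, assuming a diffusers version whose VAE config exposes scaling_factor (the getattr fallback covers versions that do not):

    # Prefer the value recorded in the VAE config, falling back to the SD v1 constant.
    scaling = getattr(pipeline.vae.config, "scaling_factor", 0.18215)
    latents = latents * jnp.float32(scaling)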
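Side observation, unchanged by this commit: compute_loss reuses the same rng for the VAE sample, the noise, the timesteps, and the encoder hidden states, so those draws are not independent. The usual JAX pattern splits the key once per draw; a self-contained sketch with hypothetical names:

    import jax

    def independent_draws(rng, shape):
        # One subkey per random operation instead of one shared key.
        noise_rng, time_rng = jax.random.split(rng)
        noise = jax.random.normal(noise_rng, shape)
        timesteps = jax.random.randint(time_rng, (shape[0],), 0, 1000)
        return noise, timesteps

Inside compute_loss this would look like vae_rng, noise_rng, time_rng, text_rng = jax.random.split(rng, 4), with each subkey passed to one sampling call.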
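Possible follow-up: the new train_step evaluates compute_loss twice, once inside grad_fn and once to recover the loss value. jax.value_and_grad, which also accepts allow_int, returns both from a single forward pass; a sketch of the same step under that assumption:

    grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
    rng, step_rng = jax.random.split(rng)
    loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)
    return state, loss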