Update app.py
app.py
CHANGED
@@ -94,4 +94,102 @@ print(f"Processed dataset size: {len(processed_dataset)}")
 def train_step(state, batch, rng):
     def compute_loss(params, pixel_values, rng):
         print("pixel_values dtype:", pixel_values.dtype)
-        print("params dtypes:", jax.tree_map
+        print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
+        print("rng dtype:", rng.dtype)
+
+        # Ensure pixel_values are float32
+        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
+
+        # Encode images to latent space
+        latents = pipeline.vae.apply(
+            {"params": params["vae"]},
+            pixel_values,
+            method=pipeline.vae.encode
+        ).latent_dist.sample(rng)
+        latents = latents * jnp.float32(0.18215)
+
+        # Generate random noise
+        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
+
+        # Sample random timesteps
+        timesteps = jax.random.randint(
+            rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
+        )
+
+        print("timesteps dtype:", timesteps.dtype)
+        print("latents dtype:", latents.dtype)
+        print("noise dtype:", noise.dtype)
+
+        # Add noise to latents
+        noisy_latents = pipeline.scheduler.add_noise(
+            pipeline.scheduler.create_state(),
+            original_samples=latents,
+            noise=noise,
+            timesteps=timesteps
+        )
+
+        # Generate random encoder hidden states (simulating text embeddings)
+        encoder_hidden_states = jax.random.normal(
+            rng,
+            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
+            dtype=jnp.float32
+        )
+
+        # Predict noise
+        model_output = state.apply_fn(
+            {'params': params["unet"]},
+            noisy_latents,
+            timesteps,
+            encoder_hidden_states=encoder_hidden_states,
+            train=True,
+        )
+
+        # Compute loss
+        return jnp.mean((model_output - noise) ** 2)
+
+    grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
+    rng, step_rng = jax.random.split(rng)
+
+    grads = grad_fn(state.params, batch["pixel_values"], step_rng)
+    loss = compute_loss(state.params, batch["pixel_values"], step_rng)
+    state = state.apply_gradients(grads=grads)
+    return state, loss
+
+# Initialize training state
+learning_rate = 1e-5
+optimizer = optax.adam(learning_rate)
+float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
+state = train_state.TrainState.create(
+    apply_fn=unet.__call__,
+    params=float32_params,
+    tx=optimizer,
+)
+
+# Training loop
+num_epochs = 3
+batch_size = 1
+rng = jax.random.PRNGKey(0)
+
+for epoch in range(num_epochs):
+    epoch_loss = 0
+    num_batches = 0
+    for batch in tqdm(processed_dataset.batch(batch_size)):
+        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
+        rng, step_rng = jax.random.split(rng)
+        state, loss = train_step(state, batch, step_rng)
+        epoch_loss += loss
+        num_batches += 1
+
+        if num_batches % 10 == 0:
+            jax.clear_caches()
+
+    avg_loss = epoch_loss / num_batches
+    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
+    jax.clear_caches()
+
+# Save the fine-tuned model
+output_dir = "/tmp/montevideo_fine_tuned_model"
+os.makedirs(output_dir, exist_ok=True)
+unet.save_pretrained(output_dir, params=state.params["unet"])
+
+print(f"Model saved to {output_dir}")
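
Note: as written, the step evaluates compute_loss twice per batch, once inside jax.grad and once to report the loss. A minimal single-pass sketch using jax.value_and_grad; every name here (compute_loss, state, batch, step_rng) comes from the hunk above, and allow_int=True mirrors its jax.grad call:

    # Sketch: value_and_grad returns the loss and its gradients from one
    # forward/backward pass instead of two separate evaluations.
    loss_and_grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
    loss, grads = loss_and_grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)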
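
Also worth noting: compute_loss reuses the single rng key for the VAE latent sample, the noise, the timesteps, and the simulated text embeddings, so those draws are correlated rather than independent. A sketch of the same block with split keys; all names are taken from the hunk, except the length-1 sequence axis on the hidden states, which is an assumption about the UNet's cross-attention input shape, not something the hunk confirms:

    # Sketch: derive independent keys so each random draw is uncorrelated.
    sample_rng, noise_rng, timestep_rng, text_rng = jax.random.split(rng, 4)

    latents = pipeline.vae.apply(
        {"params": params["vae"]},
        pixel_values,
        method=pipeline.vae.encode
    ).latent_dist.sample(sample_rng)
    latents = latents * jnp.float32(0.18215)

    noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
    timesteps = jax.random.randint(
        timestep_rng, (latents.shape[0],), 0,
        pipeline.scheduler.config.num_train_timesteps
    )

    # Assumption: cross-attention UNets typically expect hidden states of
    # shape (batch, seq_len, dim), hence the added length-1 sequence axis.
    encoder_hidden_states = jax.random.normal(
        text_rng,
        (latents.shape[0], 1, pipeline.text_encoder.config.hidden_size),
        dtype=jnp.float32
    )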