uruguayai commited on
Commit
8e214b7
·
verified ·
1 Parent(s): 3518b5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -1
app.py CHANGED
@@ -94,4 +94,102 @@ print(f"Processed dataset size: {len(processed_dataset)}")
94
  def train_step(state, batch, rng):
95
  def compute_loss(params, pixel_values, rng):
96
  print("pixel_values dtype:", pixel_values.dtype)
97
- print("params dtypes:", jax.tree_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def train_step(state, batch, rng):
95
  def compute_loss(params, pixel_values, rng):
96
  print("pixel_values dtype:", pixel_values.dtype)
97
+ print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
98
+ print("rng dtype:", rng.dtype)
99
+
100
+ # Ensure pixel_values are float32
101
+ pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
102
+
103
+ # Encode images to latent space
104
+ latents = pipeline.vae.apply(
105
+ {"params": params["vae"]},
106
+ pixel_values,
107
+ method=pipeline.vae.encode
108
+ ).latent_dist.sample(rng)
109
+ latents = latents * jnp.float32(0.18215)
110
+
111
+ # Generate random noise
112
+ noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
113
+
114
+ # Sample random timesteps
115
+ timesteps = jax.random.randint(
116
+ rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
117
+ )
118
+
119
+ print("timesteps dtype:", timesteps.dtype)
120
+ print("latents dtype:", latents.dtype)
121
+ print("noise dtype:", noise.dtype)
122
+
123
+ # Add noise to latents
124
+ noisy_latents = pipeline.scheduler.add_noise(
125
+ pipeline.scheduler.create_state(),
126
+ original_samples=latents,
127
+ noise=noise,
128
+ timesteps=timesteps
129
+ )
130
+
131
+ # Generate random encoder hidden states (simulating text embeddings)
132
+ encoder_hidden_states = jax.random.normal(
133
+ rng,
134
+ (latents.shape[0], pipeline.text_encoder.config.hidden_size),
135
+ dtype=jnp.float32
136
+ )
137
+
138
+ # Predict noise
139
+ model_output = state.apply_fn(
140
+ {'params': params["unet"]},
141
+ noisy_latents,
142
+ timesteps,
143
+ encoder_hidden_states=encoder_hidden_states,
144
+ train=True,
145
+ )
146
+
147
+ # Compute loss
148
+ return jnp.mean((model_output - noise) ** 2)
149
+
150
+ grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
151
+ rng, step_rng = jax.random.split(rng)
152
+
153
+ grads = grad_fn(state.params, batch["pixel_values"], step_rng)
154
+ loss = compute_loss(state.params, batch["pixel_values"], step_rng)
155
+ state = state.apply_gradients(grads=grads)
156
+ return state, loss
157
+
158
+ # Initialize training state
159
+ learning_rate = 1e-5
160
+ optimizer = optax.adam(learning_rate)
161
+ float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
162
+ state = train_state.TrainState.create(
163
+ apply_fn=unet.__call__,
164
+ params=float32_params,
165
+ tx=optimizer,
166
+ )
167
+
168
+ # Training loop
169
+ num_epochs = 3
170
+ batch_size = 1
171
+ rng = jax.random.PRNGKey(0)
172
+
173
+ for epoch in range(num_epochs):
174
+ epoch_loss = 0
175
+ num_batches = 0
176
+ for batch in tqdm(processed_dataset.batch(batch_size)):
177
+ batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
178
+ rng, step_rng = jax.random.split(rng)
179
+ state, loss = train_step(state, batch, step_rng)
180
+ epoch_loss += loss
181
+ num_batches += 1
182
+
183
+ if num_batches % 10 == 0:
184
+ jax.clear_caches()
185
+
186
+ avg_loss = epoch_loss / num_batches
187
+ print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
188
+ jax.clear_caches()
189
+
190
+ # Save the fine-tuned model
191
+ output_dir = "/tmp/montevideo_fine_tuned_model"
192
+ os.makedirs(output_dir, exist_ok=True)
193
+ unet.save_pretrained(output_dir, params=state.params["unet"])
194
+
195
+ print(f"Model saved to {output_dir}")