uruguayai committed
Commit bd20ad9 · verified · 1 Parent(s): 35bc545

Update app.py

Files changed (1):
    app.py  +63 -25
app.py CHANGED
@@ -10,9 +10,10 @@ import os
 import pickle
 from PIL import Image
 import numpy as np
-
-
+import gc
 
+# Set default dtype to float16
+jax.config.update("jax_default_dtype", "float16")
 
 # Set up cache directories
 cache_dir = "/tmp/huggingface_cache"
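
Review note: "jax_default_dtype" is not a standard JAX config flag, so jax.config.update("jax_default_dtype", "float16") will most likely raise an error at startup (JAX exposes flags such as jax_enable_x64, but no global default-dtype switch that I know of). A more reliable way to get the model into half precision is to cast the parameter pytree explicitly; a minimal sketch, assuming params is the pytree returned by from_pretrained:

    import jax
    import jax.numpy as jnp

    def to_half(params):
        # Cast every floating-point leaf to float16; leave integer leaves untouched.
        return jax.tree_util.tree_map(
            lambda x: x.astype(jnp.float16) if jnp.issubdtype(x.dtype, jnp.floating) else x,
            params,
        )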
@@ -35,7 +36,7 @@ def get_model(model_id, revision):
     pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
         model_id,
         revision=revision,
-        dtype=jnp.float32,
+        dtype=jnp.float16,
     )
     with open(model_cache_file, 'wb') as f:
         pickle.dump((pipeline, params), f)
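
Side note: passing dtype=jnp.float16 halves parameter memory, but on TPUs jnp.bfloat16 is usually the safer half-precision choice because of its wider exponent range. A sketch of the same load under that assumption (both the model id and the "bf16" revision here are illustrative and may not exist for every checkpoint):

    from diffusers import FlaxStableDiffusionPipeline
    import jax.numpy as jnp

    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # hypothetical model id, for illustration
        revision="bf16",                  # assumption: a bf16 branch exists for this checkpoint
        dtype=jnp.bfloat16,
    )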
@@ -58,14 +59,12 @@ def preprocess_images(examples):
         # Resize and convert to RGB
         image = image.convert("RGB").resize((512, 512))
         # Convert to numpy array and normalize
-        image = np.array(image).astype(np.float32) / 255.0
+        image = np.array(image).astype(np.float16) / 255.0
         # Ensure the image has the shape (3, height, width)
         return image.transpose(2, 0, 1)  # Change to channel-first format
 
     return {"pixel_values": [process_image(img) for img in examples["image"]]}
 
-
-
 # Load dataset from Hugging Face
 dataset_name = "uruguayai/montevideo"
 dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
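
Review note: this preprocessing normalizes images to [0, 1], but the Stable Diffusion VAE is conventionally trained on inputs in [-1, 1]. If that convention applies to this checkpoint (worth verifying), the fix is one extra line inside process_image:

    image = np.array(image).astype(np.float16) / 255.0  # [0, 1]
    image = image * 2.0 - 1.0                           # rescale to [-1, 1]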
@@ -114,12 +113,16 @@ except Exception as e:
 
     raise ValueError("Unable to locate or load the dataset. Please check the dataset path and permissions.")
 
+# Function to clear JIT cache
+def clear_jit_cache():
+    jax.clear_caches()
+    gc.collect()
 
-# Training function
-def train_step(state, batch, rng):
-    def compute_loss(params):
-        # Convert batch to JAX array
-        pixel_values = jnp.array(batch["pixel_values"])
+# Training function with gradient accumulation
+def train_step(state, batch, rng, grad_accumulation_steps=8):
+    def compute_loss(params, batch_slice, rng):
+        # Convert batch slice to JAX array
+        pixel_values = jnp.array(batch_slice["pixel_values"], dtype=jnp.float16)
         batch_size = pixel_values.shape[0]
 
         # Encode images to latent space
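
Side note: jax.clear_caches() only exists in relatively recent JAX releases (roughly 0.4.9 onward); on older versions this helper raises AttributeError. A guarded variant, as a sketch:

    import gc
    import jax

    def clear_jit_cache():
        # Drop compiled-function caches where supported, then force a GC pass.
        if hasattr(jax, "clear_caches"):
            jax.clear_caches()
        gc.collect()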
@@ -128,11 +131,11 @@ def train_step(state, batch, rng):
             pixel_values,
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
-        latents = latents * 0.18215  # scaling factor
+        latents = latents * jnp.float16(0.18215)  # scaling factor
 
         # Generate random noise
         noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
-        noise = jax.random.normal(noise_rng, latents.shape)
+        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float16)
 
         # Sample random timesteps
         timesteps = jax.random.randint(
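
Side note: 0.18215 is the standard latent scaling factor for SD v1 checkpoints. Rather than hard-coding it, it can usually be read from the VAE config (assuming this diffusers checkpoint exposes the field):

    # Fall back to the SD v1 constant if the config predates the field.
    scaling_factor = getattr(pipeline.vae.config, "scaling_factor", 0.18215)
    latents = latents * jnp.float16(scaling_factor)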
@@ -151,13 +154,17 @@ def train_step(state, batch, rng):
         )
 
         # Generate random latents for text encoder
-        encoder_hidden_states = jax.random.normal(latents_rng, (batch_size, pipeline.text_encoder.config.hidden_size))
+        encoder_hidden_states = jax.random.normal(
+            latents_rng,
+            (batch_size, pipeline.text_encoder.config.hidden_size),
+            dtype=jnp.float16
+        )
 
         # Predict noise
         model_output = state.apply_fn.apply(
             {'params': params["unet"]},
-            jnp.array(noisy_latents),
-            jnp.array(timesteps),
+            jnp.array(noisy_latents, dtype=jnp.float16),
+            jnp.array(timesteps, dtype=jnp.float16),
             encoder_hidden_states=encoder_hidden_states,
             train=True,
         )
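
Review note: the conditioning here is random noise of shape (batch_size, hidden_size), not real text embeddings, so the UNet is effectively trained against meaningless context; most cross-attention code also expects a sequence axis, i.e. (batch, seq_len, hidden_size). For comparison, the diffusers Flax training examples derive the hidden states from captions roughly like this (a sketch, assuming the batch carried tokenized input_ids, which this dataset does not currently provide):

    # Assumption: batch_slice["input_ids"] holds caption tokens from pipeline.tokenizer.
    encoder_hidden_states = pipeline.text_encoder(
        batch_slice["input_ids"],
        params=params["text_encoder"],
        train=False,
    )[0]  # shape (batch, seq_len, hidden_size)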
@@ -166,12 +173,37 @@ def train_step(state, batch, rng):
         loss = jnp.mean((model_output - noise) ** 2)
         return loss
 
-    loss, grads = jax.value_and_grad(compute_loss)(state.params)
-    state = state.apply_gradients(grads=grads)
-    return state, loss
+    grad_fn = jax.value_and_grad(compute_loss)
+
+    # Split the batch into smaller chunks
+    batch_size = len(batch['pixel_values'])
+    chunk_size = batch_size // grad_accumulation_steps
+
+    # Initialize accumulated gradients
+    acc_grads = jax.tree_map(jnp.zeros_like, state.params)
+    acc_loss = jnp.float16(0.0)
+
+    for i in range(grad_accumulation_steps):
+        start_idx = i * chunk_size
+        end_idx = start_idx + chunk_size if i < grad_accumulation_steps - 1 else batch_size
+
+        batch_slice = {
+            'pixel_values': batch['pixel_values'][start_idx:end_idx]
+        }
+
+        rng, step_rng = jax.random.split(rng)
+        loss, grads = grad_fn(state.params, batch_slice, step_rng)
+
+        # Accumulate gradients and loss
+        acc_grads = jax.tree_map(lambda acc, g: acc + g / grad_accumulation_steps, acc_grads, grads)
+        acc_loss += loss / grad_accumulation_steps
+
+    # Update state with accumulated gradients
+    state = state.apply_gradients(grads=acc_grads)
+    return state, acc_loss
 
 # Initialize training state
-learning_rate = 1e-5
+learning_rate = jnp.float16(1e-5)
 optimizer = optax.adam(learning_rate)
 state = train_state.TrainState.create(
     apply_fn=unet,
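
Review note: with batch_size = 2 (set just below) and grad_accumulation_steps=8, chunk_size = 2 // 8 = 0, so the first seven slices are empty and only the final slice carries any data; the manual loop only behaves as intended when the batch size is a multiple of the accumulation steps. Optax also ships a built-in wrapper for this pattern that sidesteps the slicing entirely; a sketch (keeping the Adam step in float32, the usual mixed-precision practice, rather than the float16 learning rate used here):

    import optax

    # Accumulate gradients across 8 train_step calls, applying one real update per 8.
    optimizer = optax.MultiSteps(optax.adam(1e-5), every_k_schedule=8)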
@@ -181,7 +213,7 @@ state = train_state.TrainState.create(
 
 # Training loop
 num_epochs = 10
-batch_size = 4
+batch_size = 2  # Reduced batch size
 rng = jax.random.PRNGKey(0)
 
 for epoch in range(num_epochs):
@@ -189,19 +221,25 @@ for epoch in range(num_epochs):
     num_batches = 0
     for batch in tqdm(processed_dataset.batch(batch_size)):
         # Convert the list of pixel values to a numpy array for each batch
-        batch['pixel_values'] = np.array(batch['pixel_values'])
+        batch['pixel_values'] = np.array(batch['pixel_values'], dtype=np.float16)
         rng, step_rng = jax.random.split(rng)
         state, loss = train_step(state, batch, step_rng)
         epoch_loss += loss
         num_batches += 1
+
+        # Clear JIT cache every 10 batches
+        if num_batches % 10 == 0:
+            clear_jit_cache()
+
     avg_loss = epoch_loss / num_batches
     print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
-
-
+
+    # Clear JIT cache after each epoch
+    clear_jit_cache()
 
 # Save the fine-tuned model
 output_dir = "/tmp/montevideo_fine_tuned_model"
 os.makedirs(output_dir, exist_ok=True)
-unet.save_pretrained(output_dir, params=state.params)
+unet.save_pretrained(output_dir, params=state.params["unet"])
 
 print(f"Model saved to {output_dir}")
 