uruguayai committed on
Commit cc5a61c · verified · 1 Parent(s): 0d8b9ef

Update app.py

Files changed (1)
  1. app.py +25 -94
app.py CHANGED
@@ -53,12 +53,9 @@ def preprocess_images(examples):
             image = Image.open(image)
         if not isinstance(image, Image.Image):
             raise ValueError(f"Unexpected image type: {type(image)}")
-        # Resize and convert to RGB
-        image = image.convert("RGB").resize((512, 512))
-        # Convert to numpy array and normalize
+        image = image.convert("RGB").resize((256, 256))  # Reduced image size
         image = np.array(image).astype(np.float16) / 255.0
-        # Ensure the image has the shape (3, height, width)
-        return image.transpose(2, 0, 1)  # Change to channel-first format
+        return image.transpose(2, 0, 1)
 
     return {"pixel_values": [process_image(img) for img in examples["image"]]}
 
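Note: the rewritten process_image now yields 256x256, channel-first float16 arrays in [0, 1]. A minimal sanity check of that contract (the PIL image below is a hypothetical stand-in, not a dataset sample):

    import numpy as np
    from PIL import Image

    img = Image.new("RGB", (1024, 768))  # stand-in for a dataset image
    arr = np.array(img.convert("RGB").resize((256, 256))).astype(np.float16) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC -> CHW, matching the diff
    assert arr.shape == (3, 256, 256)
    assert float(arr.min()) >= 0.0 and float(arr.max()) <= 1.0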
@@ -76,15 +73,9 @@ try:
             processed_dataset = pickle.load(f)
     else:
         print("Loading dataset from Hugging Face...")
-        dataset = load_dataset(dataset_name)
-        print("Dataset structure:", dataset)
-        print("Available splits:", dataset.keys())
-
-        if "train" not in dataset:
-            raise ValueError("The dataset does not contain a 'train' split.")
-
+        dataset = load_dataset(dataset_name, split="train[:1000]")  # Load only first 1000 samples
         print("Processing dataset...")
-        processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
+        processed_dataset = dataset.map(preprocess_images, batched=True, remove_columns=dataset.column_names)
         with open(dataset_cache_file, 'wb') as f:
             pickle.dump(processed_dataset, f)
 
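Note on the split="train[:1000]" change: passing a split string makes load_dataset return a single Dataset rather than a DatasetDict, which is why the ["train"] indexing and the split check could be dropped. A short sketch of the same slicing API (the dataset name below is a placeholder; app.py defines the real dataset_name):

    from datasets import load_dataset

    dataset_name = "user/dataset"  # placeholder value for illustration
    # The "train[:1000]" slice syntax yields a Dataset directly, so
    # .map() and .column_names work without ["train"].
    dataset = load_dataset(dataset_name, split="train[:1000]")
    print(len(dataset))          # at most 1000 examples
    print(dataset.column_names)  # columns to remove after mapping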
@@ -92,23 +83,7 @@ try:
 
 except Exception as e:
     print(f"Error loading or processing dataset: {str(e)}")
-    print("Attempting to find dataset...")
-
-    # List contents of current directory and parent directories
-    print("Current directory contents:")
-    print(os.listdir('.'))
-    print("Parent directory contents:")
-    print(os.listdir('..'))
-    print("Root directory contents:")
-    print(os.listdir('/'))
-
-    # Try to find any directory that might contain the dataset
-    for root, dirs, files in os.walk('/'):
-        if 'montevideo' in dirs:
-            print(f"Found 'montevideo' directory at: {os.path.join(root, 'montevideo')}")
-            print(f"Contents: {os.listdir(os.path.join(root, 'montevideo'))}")
-
-    raise ValueError("Unable to locate or load the dataset. Please check the dataset path and permissions.")
+    raise ValueError("Unable to load or process the dataset.")
 
 # Function to clear JIT cache
 def clear_jit_cache():
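Note on the simplified error path: since the handler now re-raises a fresh ValueError, one optional refinement (not part of this commit) is exception chaining, which keeps the original traceback attached:

    try:
        processed_dataset = ...  # dataset loading and mapping as in app.py
    except Exception as e:
        print(f"Error loading or processing dataset: {str(e)}")
        # `from e` attaches the underlying cause to the new error.
        raise ValueError("Unable to load or process the dataset.") from e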
@@ -116,122 +91,78 @@ def clear_jit_cache():
     gc.collect()
 
 # Training function with gradient accumulation
-def train_step(state, batch, rng, grad_accumulation_steps=8):
-    def compute_loss(params, batch_slice, rng):
-        # Convert batch slice to JAX array
-        pixel_values = jnp.array(batch_slice["pixel_values"], dtype=jnp.float16)
-        batch_size = pixel_values.shape[0]
-
-        # Encode images to latent space
+@jax.jit
+def train_step(state, batch, rng):
+    def compute_loss(params, pixel_values, rng):
         latents = pipeline.vae.apply(
             {"params": params["vae"]},
             pixel_values,
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
-        latents = latents * jnp.float16(0.18215)  # scaling factor
+        latents = latents * jnp.float16(0.18215)
 
-        # Generate random noise
-        noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
-        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float16)
+        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float16)
 
-        # Sample random timesteps
         timesteps = jax.random.randint(
-            timestep_rng, (batch_size,), 0, pipeline.scheduler.config.num_train_timesteps
+            rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
         )
-
-        # Create scheduler state
-        scheduler_state = pipeline.scheduler.create_state()
-
-        # Add noise to latents using the scheduler
         noisy_latents = pipeline.scheduler.add_noise(
-            scheduler_state,
+            pipeline.scheduler.create_state(),
             original_samples=latents,
             noise=noise,
            timesteps=timesteps
         )
 
-        # Generate random latents for text encoder
         encoder_hidden_states = jax.random.normal(
-            latents_rng,
-            (batch_size, pipeline.text_encoder.config.hidden_size),
+            rng,
+            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
             dtype=jnp.float16
         )
 
-        # Predict noise
         model_output = state.apply_fn.apply(
             {'params': params["unet"]},
-            jnp.array(noisy_latents, dtype=jnp.float16),
-            jnp.array(timesteps, dtype=jnp.float16),
+            noisy_latents,
+            timesteps,
             encoder_hidden_states=encoder_hidden_states,
             train=True,
         )
 
-        # Compute loss
-        loss = jnp.mean((model_output - noise) ** 2)
-        return loss
+        return jnp.mean((model_output - noise) ** 2)
 
     grad_fn = jax.value_and_grad(compute_loss)
-
-    # Split the batch into smaller chunks
-    batch_size = len(batch['pixel_values'])
-    chunk_size = batch_size // grad_accumulation_steps
-
-    # Initialize accumulated gradients
-    acc_grads = jax.tree_map(jnp.zeros_like, state.params)
-    acc_loss = jnp.float16(0.0)
-
-    for i in range(grad_accumulation_steps):
-        start_idx = i * chunk_size
-        end_idx = start_idx + chunk_size if i < grad_accumulation_steps - 1 else batch_size
-
-        batch_slice = {
-            'pixel_values': batch['pixel_values'][start_idx:end_idx]
-        }
-
-        rng, step_rng = jax.random.split(rng)
-        loss, grads = grad_fn(state.params, batch_slice, step_rng)
-
-        # Accumulate gradients and loss
-        acc_grads = jax.tree_map(lambda acc, g: acc + g / grad_accumulation_steps, acc_grads, grads)
-        acc_loss += loss / grad_accumulation_steps
-
-    # Update state with accumulated gradients
-    state = state.apply_gradients(grads=acc_grads)
-    return state, acc_loss
+    rng, step_rng = jax.random.split(rng)
+    loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
+    state = state.apply_gradients(grads=grads)
+    return state, loss
 
 # Initialize training state
 learning_rate = jnp.float16(1e-5)
 optimizer = optax.adam(learning_rate)
 state = train_state.TrainState.create(
     apply_fn=unet,
-    params={"unet": params["unet"], "vae": params["vae"]},  # Include both UNet and VAE params
+    params={"unet": params["unet"], "vae": params["vae"]},
     tx=optimizer,
 )
 
 # Training loop
-num_epochs = 10
-batch_size = 2  # Reduced batch size
+num_epochs = 5  # Reduced number of epochs
+batch_size = 4
 rng = jax.random.PRNGKey(0)
 
 for epoch in range(num_epochs):
     epoch_loss = 0
     num_batches = 0
     for batch in tqdm(processed_dataset.batch(batch_size)):
-        # Convert the list of pixel values to a numpy array for each batch
-        batch['pixel_values'] = np.array(batch['pixel_values'], dtype=np.float16)
+        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float16)
         rng, step_rng = jax.random.split(rng)
         state, loss = train_step(state, batch, step_rng)
         epoch_loss += loss
         num_batches += 1
 
-        # Clear JIT cache every 10 batches
         if num_batches % 10 == 0:
             clear_jit_cache()
 
     avg_loss = epoch_loss / num_batches
     print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
-
-    # Clear JIT cache after each epoch
     clear_jit_cache()
 
 # Save the fine-tuned model
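Two behavioural notes on the rewritten training step. First, dropping the Python-level accumulation loop is what makes the @jax.jit decorator straightforward: the traced function is now a single gradient step rather than an unrolled accumulation loop. Second, compute_loss passes the same rng key to sample, normal, and randint, so those draws are statistically correlated; the usual JAX pattern is one split per consumer, roughly (shapes below are illustrative):

    import jax
    import jax.numpy as jnp

    rng = jax.random.PRNGKey(0)
    # One split per random consumer gives independent streams; reusing a
    # single key feeds the same randomness into every call.
    sample_rng, noise_rng, timestep_rng, enc_rng = jax.random.split(rng, 4)
    noise = jax.random.normal(noise_rng, (4, 4, 32, 32), dtype=jnp.float16)
    timesteps = jax.random.randint(timestep_rng, (4,), 0, 1000)

A related caveat: a jitted train_step recompiles for every new input shape, so a final batch smaller than batch_size triggers a second compile. A shape-stable batching sketch in plain Python (assuming processed_dataset supports len() and slicing, as Hugging Face datasets do):

    # Drop the trailing partial batch so every step sees identical shapes.
    num_full = (len(processed_dataset) // batch_size) * batch_size
    for start in range(0, num_full, batch_size):
        batch = processed_dataset[start:start + batch_size]  # dict of lists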