uruguayai committed on
Commit
5fadcb1
·
verified ·
1 Parent(s): 06b9137

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -25
app.py CHANGED
@@ -55,13 +55,14 @@ def preprocess_images(examples):
55
  # Resize and convert to RGB
56
  image = image.convert("RGB").resize((512, 512))
57
  # Convert to numpy array and normalize
58
- image = np.array(image).astype(np.float32) / 127.5 - 1.0
59
  # Ensure the image has the shape (3, height, width)
60
  return image.transpose(2, 0, 1) # Change to channel-first format
61
 
62
  return {"pixel_values": [process_image(img) for img in examples["image"]]}
63
 
64
 
 
65
  # Load dataset from Hugging Face
66
  dataset_name = "uruguayai/montevideo"
67
  dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
@@ -118,9 +119,13 @@ def train_step(state, batch, rng):
118
  pixel_values = jnp.array(batch["pixel_values"])
119
  batch_size = pixel_values.shape[0]
120
 
 
 
 
 
121
  # Generate random noise
122
  noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
123
- noise = jax.random.normal(noise_rng, pixel_values.shape)
124
 
125
  # Sample random timesteps
126
  timesteps = jax.random.randint(
@@ -130,23 +135,23 @@ def train_step(state, batch, rng):
130
  # Create scheduler state
131
  scheduler_state = pipeline.scheduler.create_state()
132
 
133
- # Add noise to images using the scheduler
134
- noisy_images = pipeline.scheduler.add_noise(
135
  scheduler_state,
136
- original_samples=pixel_values,
137
  noise=noise,
138
  timesteps=timesteps
139
  )
140
 
141
  # Generate random latents for text encoder
142
- latents = jax.random.normal(latents_rng, (batch_size, pipeline.text_encoder.config.hidden_size))
143
 
144
  # Predict noise
145
  model_output = state.apply_fn.apply(
146
  {'params': params},
147
- jnp.array(noisy_images),
148
  jnp.array(timesteps),
149
- encoder_hidden_states=latents,
150
  train=True,
151
  )
152
 
@@ -172,22 +177,6 @@ num_epochs = 10
172
  batch_size = 4
173
  rng = jax.random.PRNGKey(0)
174
 
175
- # Debug print
176
- print("Processed dataset info:")
177
- print(processed_dataset)
178
- print("First batch:")
179
- first_batch = next(iter(processed_dataset.batch(batch_size)))
180
- print(f"Batch keys: {first_batch.keys()}")
181
- print(f"Type of pixel_values: {type(first_batch['pixel_values'])}")
182
- if isinstance(first_batch['pixel_values'], list):
183
- print(f"Length of pixel_values list: {len(first_batch['pixel_values'])}")
184
- if len(first_batch['pixel_values']) > 0:
185
- print(f"Shape of first item in pixel_values: {np.array(first_batch['pixel_values'][0]).shape}")
186
-
187
- # Convert the list of pixel values to a numpy array
188
- first_batch['pixel_values'] = np.array(first_batch['pixel_values'])
189
- print(f"Pixel values shape after conversion: {first_batch['pixel_values'].shape}")
190
-
191
  for epoch in range(num_epochs):
192
  epoch_loss = 0
193
  num_batches = 0
@@ -202,7 +191,6 @@ for epoch in range(num_epochs):
202
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
203
 
204
 
205
-
206
  # Save the fine-tuned model
207
  output_dir = "/tmp/montevideo_fine_tuned_model"
208
  os.makedirs(output_dir, exist_ok=True)
 
55
  # Resize and convert to RGB
56
  image = image.convert("RGB").resize((512, 512))
57
  # Convert to numpy array and normalize
58
+ image = np.array(image).astype(np.float32) / 255.0
59
  # Ensure the image has the shape (3, height, width)
60
  return image.transpose(2, 0, 1) # Change to channel-first format
61
 
62
  return {"pixel_values": [process_image(img) for img in examples["image"]]}
63
 
64
 
65
+
66
  # Load dataset from Hugging Face
67
  dataset_name = "uruguayai/montevideo"
68
  dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
 
119
  pixel_values = jnp.array(batch["pixel_values"])
120
  batch_size = pixel_values.shape[0]
121
 
122
+ # Encode images to latent space
123
+ latents = pipeline.vae.encode(pixel_values).latent_dist.sample(rng)
124
+ latents = latents * pipeline.vae.config.scaling_factor
125
+
126
  # Generate random noise
127
  noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
128
+ noise = jax.random.normal(noise_rng, latents.shape)
129
 
130
  # Sample random timesteps
131
  timesteps = jax.random.randint(
 
135
  # Create scheduler state
136
  scheduler_state = pipeline.scheduler.create_state()
137
 
138
+ # Add noise to latents using the scheduler
139
+ noisy_latents = pipeline.scheduler.add_noise(
140
  scheduler_state,
141
+ original_samples=latents,
142
  noise=noise,
143
  timesteps=timesteps
144
  )
145
 
146
  # Generate random placeholder encoder hidden states (stand-in for text encoder output)
147
+ encoder_hidden_states = jax.random.normal(latents_rng, (batch_size, pipeline.text_encoder.config.hidden_size))
148
 
149
  # Predict noise
150
  model_output = state.apply_fn.apply(
151
  {'params': params},
152
+ jnp.array(noisy_latents),
153
  jnp.array(timesteps),
154
+ encoder_hidden_states=encoder_hidden_states,
155
  train=True,
156
  )
157
 
 
177
  batch_size = 4
178
  rng = jax.random.PRNGKey(0)
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  for epoch in range(num_epochs):
181
  epoch_loss = 0
182
  num_batches = 0
 
191
  print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
192
 
193
 
 
194
  # Save the fine-tuned model
195
  output_dir = "/tmp/montevideo_fine_tuned_model"
196
  os.makedirs(output_dir, exist_ok=True)