Update app.py
Browse files
app.py
CHANGED
@@ -55,13 +55,14 @@ def preprocess_images(examples):
|
|
55 |
# Resize and convert to RGB
|
56 |
image = image.convert("RGB").resize((512, 512))
|
57 |
# Convert to numpy array and normalize
|
58 |
-
image = np.array(image).astype(np.float32) /
|
59 |
# Ensure the image has the shape (3, height, width)
|
60 |
return image.transpose(2, 0, 1) # Change to channel-first format
|
61 |
|
62 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
63 |
|
64 |
|
|
|
65 |
# Load dataset from Hugging Face
|
66 |
dataset_name = "uruguayai/montevideo"
|
67 |
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
|
@@ -118,9 +119,13 @@ def train_step(state, batch, rng):
|
|
118 |
pixel_values = jnp.array(batch["pixel_values"])
|
119 |
batch_size = pixel_values.shape[0]
|
120 |
|
|
|
|
|
|
|
|
|
121 |
# Generate random noise
|
122 |
noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
|
123 |
-
noise = jax.random.normal(noise_rng,
|
124 |
|
125 |
# Sample random timesteps
|
126 |
timesteps = jax.random.randint(
|
@@ -130,23 +135,23 @@ def train_step(state, batch, rng):
|
|
130 |
# Create scheduler state
|
131 |
scheduler_state = pipeline.scheduler.create_state()
|
132 |
|
133 |
-
# Add noise to
|
134 |
-
|
135 |
scheduler_state,
|
136 |
-
original_samples=
|
137 |
noise=noise,
|
138 |
timesteps=timesteps
|
139 |
)
|
140 |
|
141 |
# Generate random latents for text encoder
|
142 |
-
|
143 |
|
144 |
# Predict noise
|
145 |
model_output = state.apply_fn.apply(
|
146 |
{'params': params},
|
147 |
-
jnp.array(
|
148 |
jnp.array(timesteps),
|
149 |
-
encoder_hidden_states=
|
150 |
train=True,
|
151 |
)
|
152 |
|
@@ -172,22 +177,6 @@ num_epochs = 10
|
|
172 |
batch_size = 4
|
173 |
rng = jax.random.PRNGKey(0)
|
174 |
|
175 |
-
# Debug print
|
176 |
-
print("Processed dataset info:")
|
177 |
-
print(processed_dataset)
|
178 |
-
print("First batch:")
|
179 |
-
first_batch = next(iter(processed_dataset.batch(batch_size)))
|
180 |
-
print(f"Batch keys: {first_batch.keys()}")
|
181 |
-
print(f"Type of pixel_values: {type(first_batch['pixel_values'])}")
|
182 |
-
if isinstance(first_batch['pixel_values'], list):
|
183 |
-
print(f"Length of pixel_values list: {len(first_batch['pixel_values'])}")
|
184 |
-
if len(first_batch['pixel_values']) > 0:
|
185 |
-
print(f"Shape of first item in pixel_values: {np.array(first_batch['pixel_values'][0]).shape}")
|
186 |
-
|
187 |
-
# Convert the list of pixel values to a numpy array
|
188 |
-
first_batch['pixel_values'] = np.array(first_batch['pixel_values'])
|
189 |
-
print(f"Pixel values shape after conversion: {first_batch['pixel_values'].shape}")
|
190 |
-
|
191 |
for epoch in range(num_epochs):
|
192 |
epoch_loss = 0
|
193 |
num_batches = 0
|
@@ -202,7 +191,6 @@ for epoch in range(num_epochs):
|
|
202 |
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
203 |
|
204 |
|
205 |
-
|
206 |
# Save the fine-tuned model
|
207 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
208 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
55 |
# Resize and convert to RGB
|
56 |
image = image.convert("RGB").resize((512, 512))
|
57 |
# Convert to numpy array and normalize
|
58 |
+
image = np.array(image).astype(np.float32) / 255.0
|
59 |
# Ensure the image has the shape (3, height, width)
|
60 |
return image.transpose(2, 0, 1) # Change to channel-first format
|
61 |
|
62 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
63 |
|
64 |
|
65 |
+
|
66 |
# Load dataset from Hugging Face
|
67 |
dataset_name = "uruguayai/montevideo"
|
68 |
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
|
|
|
119 |
pixel_values = jnp.array(batch["pixel_values"])
|
120 |
batch_size = pixel_values.shape[0]
|
121 |
|
122 |
+
# Encode images to latent space
|
123 |
+
latents = pipeline.vae.encode(pixel_values).latent_dist.sample(rng)
|
124 |
+
latents = latents * pipeline.vae.config.scaling_factor
|
125 |
+
|
126 |
# Generate random noise
|
127 |
noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
|
128 |
+
noise = jax.random.normal(noise_rng, latents.shape)
|
129 |
|
130 |
# Sample random timesteps
|
131 |
timesteps = jax.random.randint(
|
|
|
135 |
# Create scheduler state
|
136 |
scheduler_state = pipeline.scheduler.create_state()
|
137 |
|
138 |
+
# Add noise to latents using the scheduler
|
139 |
+
noisy_latents = pipeline.scheduler.add_noise(
|
140 |
scheduler_state,
|
141 |
+
original_samples=latents,
|
142 |
noise=noise,
|
143 |
timesteps=timesteps
|
144 |
)
|
145 |
|
146 |
# Generate random latents for text encoder
|
147 |
+
encoder_hidden_states = jax.random.normal(latents_rng, (batch_size, pipeline.text_encoder.config.hidden_size))
|
148 |
|
149 |
# Predict noise
|
150 |
model_output = state.apply_fn.apply(
|
151 |
{'params': params},
|
152 |
+
jnp.array(noisy_latents),
|
153 |
jnp.array(timesteps),
|
154 |
+
encoder_hidden_states=encoder_hidden_states,
|
155 |
train=True,
|
156 |
)
|
157 |
|
|
|
177 |
batch_size = 4
|
178 |
rng = jax.random.PRNGKey(0)
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
for epoch in range(num_epochs):
|
181 |
epoch_loss = 0
|
182 |
num_batches = 0
|
|
|
191 |
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
|
192 |
|
193 |
|
|
|
194 |
# Save the fine-tuned model
|
195 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
196 |
os.makedirs(output_dir, exist_ok=True)
|