Update app.py
app.py CHANGED
@@ -10,9 +10,10 @@ import os
 import pickle
 from PIL import Image
 import numpy as np
-
-
+import gc

+# Set default dtype to float16
+jax.config.update("jax_default_dtype", "float16")

 # Set up cache directories
 cache_dir = "/tmp/huggingface_cache"
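
This hunk moves the script toward half precision. As a minimal standalone check of what float16 arrays look like in JAX (independent of the config flag above, since explicit dtype arguments are the portable route), something like:

import jax.numpy as jnp

x = jnp.ones((2, 2), dtype=jnp.float16)   # explicit float16 allocation
y = x * jnp.float16(0.5)                  # stays float16
print(x.dtype, y.dtype)                   # float16 float16
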
@@ -35,7 +36,7 @@ def get_model(model_id, revision):
     pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
         model_id,
         revision=revision,
-        dtype=jnp.
+        dtype=jnp.float16,
     )
     with open(model_cache_file, 'wb') as f:
         pickle.dump((pipeline, params), f)
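
The surrounding get_model helper pickles the loaded pipeline so later runs skip the download. A minimal sketch of that load-once, cache-to-disk pattern, with a hypothetical load_fn standing in for FlaxStableDiffusionPipeline.from_pretrained:

import os
import pickle

def cached_load(cache_file, load_fn):
    # Reuse the pickled object if it exists; otherwise load it and cache it.
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)
    obj = load_fn()
    with open(cache_file, "wb") as f:
        pickle.dump(obj, f)
    return obj
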
@@ -58,14 +59,12 @@ def preprocess_images(examples):
         # Resize and convert to RGB
         image = image.convert("RGB").resize((512, 512))
         # Convert to numpy array and normalize
-        image = np.array(image).astype(np.
+        image = np.array(image).astype(np.float16) / 255.0
         # Ensure the image has the shape (3, height, width)
         return image.transpose(2, 0, 1) # Change to channel-first format

     return {"pixel_values": [process_image(img) for img in examples["image"]]}

-
-
 # Load dataset from Hugging Face
 dataset_name = "uruguayai/montevideo"
 dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
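
A self-contained sketch of the same resize, normalize, channel-first transform on a single PIL image; the size and dtype mirror the values used above:

import numpy as np
from PIL import Image

def to_model_input(img: Image.Image, size=512, dtype=np.float16):
    img = img.convert("RGB").resize((size, size))
    arr = np.asarray(img).astype(dtype) / 255.0   # scale pixels to [0, 1]
    return arr.transpose(2, 0, 1)                 # (3, H, W), channel-first
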
@@ -114,12 +113,16 @@ except Exception as e:

     raise ValueError("Unable to locate or load the dataset. Please check the dataset path and permissions.")

+# Function to clear JIT cache
+def clear_jit_cache():
+    jax.clear_caches()
+    gc.collect()

-# Training function
-def train_step(state, batch, rng):
-    def compute_loss(params):
-        # Convert batch to JAX array
-        pixel_values = jnp.array(
+# Training function with gradient accumulation
+def train_step(state, batch, rng, grad_accumulation_steps=8):
+    def compute_loss(params, batch_slice, rng):
+        # Convert batch slice to JAX array
+        pixel_values = jnp.array(batch_slice["pixel_values"], dtype=jnp.float16)
         batch_size = pixel_values.shape[0]

         # Encode images to latent space
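
jax.clear_caches() (available in recent JAX releases) drops JAX's internal compilation caches; the helper above pairs it with Python's garbage collector. A small usage sketch, with the caveat that how much memory this actually frees depends on the JAX version and backend:

import gc
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    return x * 2

double(jnp.ones(4))   # compiles and caches an executable
jax.clear_caches()    # drop compilation caches
gc.collect()          # collect unreferenced Python objects
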
@@ -128,11 +131,11 @@ def train_step(state, batch, rng):
             pixel_values,
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
-        latents = latents * 0.18215 # scaling factor
+        latents = latents * jnp.float16(0.18215) # scaling factor

         # Generate random noise
         noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)
-        noise = jax.random.normal(noise_rng, latents.shape)
+        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float16)

         # Sample random timesteps
         timesteps = jax.random.randint(
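
A minimal sketch of the PRNG handling in this hunk: split one key into independent subkeys, then draw half-precision noise and integer timesteps. The latent shape and the 1000-step range are illustrative assumptions:

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
noise_rng, timestep_rng, latents_rng = jax.random.split(rng, 3)

latent_shape = (2, 4, 64, 64)   # assumed (batch, channels, height, width)
noise = jax.random.normal(noise_rng, latent_shape, dtype=jnp.float16)
timesteps = jax.random.randint(timestep_rng, (latent_shape[0],), 0, 1000)
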
@@ -151,13 +154,17 @@ def train_step(state, batch, rng):
         )

         # Generate random latents for text encoder
-        encoder_hidden_states = jax.random.normal(
+        encoder_hidden_states = jax.random.normal(
+            latents_rng,
+            (batch_size, pipeline.text_encoder.config.hidden_size),
+            dtype=jnp.float16
+        )

         # Predict noise
         model_output = state.apply_fn.apply(
             {'params': params["unet"]},
-            jnp.array(noisy_latents),
-            jnp.array(timesteps),
+            jnp.array(noisy_latents, dtype=jnp.float16),
+            jnp.array(timesteps, dtype=jnp.float16),
             encoder_hidden_states=encoder_hidden_states,
             train=True,
         )
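
The jnp.array(..., dtype=jnp.float16) wrappers above cast every input to a common dtype before the UNet call. A small sketch of that casting step in isolation; the helper name and shapes are hypothetical:

import jax.numpy as jnp

def cast_inputs(*arrays, dtype=jnp.float16):
    # Cast each input to the requested dtype (no-op when it already matches).
    return tuple(jnp.asarray(a, dtype=dtype) for a in arrays)

latents16, timesteps16 = cast_inputs(jnp.zeros((2, 4, 64, 64)), jnp.arange(2))
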
@@ -166,12 +173,37 @@ def train_step(state, batch, rng):
         loss = jnp.mean((model_output - noise) ** 2)
         return loss

-
-
-
+    grad_fn = jax.value_and_grad(compute_loss)
+
+    # Split the batch into smaller chunks
+    batch_size = len(batch['pixel_values'])
+    chunk_size = batch_size // grad_accumulation_steps
+
+    # Initialize accumulated gradients
+    acc_grads = jax.tree_map(jnp.zeros_like, state.params)
+    acc_loss = jnp.float16(0.0)
+
+    for i in range(grad_accumulation_steps):
+        start_idx = i * chunk_size
+        end_idx = start_idx + chunk_size if i < grad_accumulation_steps - 1 else batch_size
+
+        batch_slice = {
+            'pixel_values': batch['pixel_values'][start_idx:end_idx]
+        }
+
+        rng, step_rng = jax.random.split(rng)
+        loss, grads = grad_fn(state.params, batch_slice, step_rng)
+
+        # Accumulate gradients and loss
+        acc_grads = jax.tree_map(lambda acc, g: acc + g / grad_accumulation_steps, acc_grads, grads)
+        acc_loss += loss / grad_accumulation_steps
+
+    # Update state with accumulated gradients
+    state = state.apply_gradients(grads=acc_grads)
+    return state, acc_loss

 # Initialize training state
-learning_rate = 1e-5
+learning_rate = jnp.float16(1e-5)
 optimizer = optax.adam(learning_rate)
 state = train_state.TrainState.create(
     apply_fn=unet,
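
A self-contained sketch of the gradient-accumulation pattern introduced in this hunk, reduced to a toy linear model so it runs on its own; the parameter pytree, data, and accumulation count are illustrative assumptions, not the app's real model:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x):
    pred = x @ params["w"] + params["b"]
    return jnp.mean(pred ** 2)

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
opt = optax.adam(1e-3)
opt_state = opt.init(params)

grad_fn = jax.value_and_grad(loss_fn)
accum_steps = 4
data = jnp.arange(32.0).reshape(8, 4)   # 8 examples split into 4 micro-batches

acc_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
acc_loss = 0.0
for chunk in jnp.split(data, accum_steps):
    loss, grads = grad_fn(params, chunk)
    acc_grads = jax.tree_util.tree_map(lambda a, g: a + g / accum_steps, acc_grads, grads)
    acc_loss += loss / accum_steps

updates, opt_state = opt.update(acc_grads, opt_state, params)
params = optax.apply_updates(params, updates)

Averaging each micro-batch's gradients before a single optimizer step keeps the effective batch size while per-step memory stays small.
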
@@ -181,7 +213,7 @@ state = train_state.TrainState.create(

 # Training loop
 num_epochs = 10
-batch_size =
+batch_size = 2 # Reduced batch size
 rng = jax.random.PRNGKey(0)

 for epoch in range(num_epochs):
@@ -189,19 +221,25 @@ for epoch in range(num_epochs):
     num_batches = 0
     for batch in tqdm(processed_dataset.batch(batch_size)):
         # Convert the list of pixel values to a numpy array for each batch
-        batch['pixel_values'] = np.array(batch['pixel_values'])
+        batch['pixel_values'] = np.array(batch['pixel_values'], dtype=np.float16)
         rng, step_rng = jax.random.split(rng)
         state, loss = train_step(state, batch, step_rng)
         epoch_loss += loss
         num_batches += 1
+
+        # Clear JIT cache every 10 batches
+        if num_batches % 10 == 0:
+            clear_jit_cache()
+
     avg_loss = epoch_loss / num_batches
     print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
-
-
+
+    # Clear JIT cache after each epoch
+    clear_jit_cache()

 # Save the fine-tuned model
 output_dir = "/tmp/montevideo_fine_tuned_model"
 os.makedirs(output_dir, exist_ok=True)
-unet.save_pretrained(output_dir, params=state.params)
+unet.save_pretrained(output_dir, params=state.params["unet"])

 print(f"Model saved to {output_dir}")
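
A hedged sketch of loading the saved UNet weights back, assuming the output directory written above and that diffusers' Flax from_pretrained returns a (model, params) pair:

from diffusers import FlaxUNet2DConditionModel

output_dir = "/tmp/montevideo_fine_tuned_model"
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(output_dir)
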