Update app.py
Browse files
app.py
CHANGED
@@ -1,31 +1,24 @@
|
|
1 |
import jax
|
2 |
import jax.numpy as jnp
|
|
|
3 |
from flax.training import train_state
|
4 |
import optax
|
5 |
from diffusers import FlaxStableDiffusionPipeline
|
|
|
6 |
from datasets import load_dataset
|
7 |
from tqdm.auto import tqdm
|
8 |
import os
|
9 |
import pickle
|
10 |
from PIL import Image
|
11 |
import numpy as np
|
12 |
-
import gc
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
class CustomPNDMScheduler(PNDMScheduler):
    """PNDMScheduler subclass whose ``add_noise`` forces int32 timesteps."""

    def add_noise(self, state, original_samples, noise, timesteps):
        """Cast ``timesteps`` to int32 and delegate to the parent ``add_noise``.

        ``state``, ``original_samples`` and ``noise`` are forwarded unchanged.
        """
        int_timesteps = timesteps.astype(jnp.int32)
        return super().add_noise(state, original_samples, noise, int_timesteps)
|
22 |
|
23 |
-
|
24 |
-
# Select the CPU platform for JAX and say so on stdout.
jax.config.update("jax_platform_name", "cpu")
print("Using CPU for computations")

# Cache layout: one root directory with a sub-directory for the model.
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
|
@@ -47,7 +40,7 @@ def get_model(model_id, revision):
|
|
47 |
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
48 |
model_id,
|
49 |
revision=revision,
|
50 |
-
dtype=jnp.float32,
|
51 |
)
|
52 |
with open(model_cache_file, 'wb') as f:
|
53 |
pickle.dump((pipeline, params), f)
|
@@ -57,15 +50,12 @@ def get_model(model_id, revision):
|
|
57 |
# Fetch the pipeline and its parameters for the "flax" revision via get_model.
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

# Rebuild the scheduler as the custom subclass from the existing scheduler's
# config and install it on the pipeline.
custom_scheduler = CustomPNDMScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = custom_scheduler
|
68 |
|
|
|
|
|
69 |
|
70 |
# Load and preprocess your dataset
|
71 |
def preprocess_images(examples):
|
@@ -87,191 +77,21 @@ dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
|
|
87 |
print(f"Dataset name: {dataset_name}")
|
88 |
print(f"Dataset cache file: {dataset_cache_file}")
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
with open(dataset_cache_file, 'wb') as f:
|
101 |
-
pickle.dump(processed_dataset, f)
|
102 |
-
|
103 |
-
print(f"Processed dataset size: {len(processed_dataset)}")
|
104 |
-
|
105 |
-
except Exception as e:
|
106 |
-
print(f"Error loading or processing dataset: {str(e)}")
|
107 |
-
raise ValueError("Unable to load or process the dataset.")
|
108 |
|
109 |
-
|
110 |
-
def clear_jit_cache():
    """Clear JAX's caches, then run a garbage-collection pass."""
    jax.clear_caches()
    gc.collect()
|
113 |
|
114 |
# Training function
|
115 |
def train_step(state, batch, rng):
    """Run one optimisation step and return ``(new_state, loss)``.

    Args:
        state: Train state exposing ``params`` (with ``"vae"`` and ``"unet"``
            sub-trees), ``apply_fn`` and ``apply_gradients``. ``apply_fn`` is
            invoked as ``apply_fn({"params": ...}, latents, timesteps, ...)``.
        batch: Mapping with a ``"pixel_values"`` entry.
        rng: JAX PRNG key for this step.

    Returns:
        The updated state and the mean-squared-error loss of the step.
    """

    def compute_loss(params, pixel_values, rng):
        """Return the MSE between the predicted noise and the sampled noise."""
        print("pixel_values dtype:", pixel_values.dtype)
        # jax.tree_map has been removed from JAX; jax.tree_util.tree_map is
        # the supported spelling.
        print("params dtypes:", jax.tree_util.tree_map(lambda leaf: leaf.dtype, params))
        print("rng dtype:", rng.dtype)

        # One independent key per random draw. Reusing a single key for all
        # four draws would make them statistically dependent.
        sample_rng, noise_rng, timestep_rng, text_rng = jax.random.split(rng, 4)

        # Ensure pixel_values are float32
        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)

        # Encode images to latent space
        latent_dist = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist
        # NOTE(review): 0.18215 is presumably the VAE latent scaling factor;
        # confirm against the model config.
        latents = latent_dist.sample(sample_rng) * jnp.float32(0.18215)

        # Generate random noise
        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)

        # Sample random timesteps, one per batch element
        timesteps = jax.random.randint(
            timestep_rng,
            (latents.shape[0],),
            0,
            pipeline.scheduler.config.num_train_timesteps,
        )

        print("timesteps dtype:", timesteps.dtype)
        print("latents dtype:", latents.dtype)
        print("noise dtype:", noise.dtype)

        # Add noise to latents
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Random encoder hidden states stand in for text embeddings.
        # NOTE(review): the shape is (batch, hidden_size); confirm the UNet
        # accepts a conditioning tensor without a sequence axis.
        encoder_hidden_states = jax.random.normal(
            text_rng,
            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
            dtype=jnp.float32,
        )

        # Predict noise
        model_output = state.apply_fn(
            {'params': params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )

        # Compute loss
        return jnp.mean((model_output - noise) ** 2)

    # A single pass yields both the loss and its gradient, instead of running
    # the forward computation a second time just to obtain the loss.
    loss_and_grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
    _, step_rng = jax.random.split(rng)

    loss, grads = loss_and_grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)
    return state, loss
|
178 |
-
|
179 |
-
# Initialize training state
|
180 |
-
# Optimiser for fine-tuning.
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)

# Cast every parameter leaf to float32 except int32 leaves, which are kept.
# jax.tree_map has been removed from JAX; jax.tree_util.tree_map is the
# supported spelling.
float32_params = jax.tree_util.tree_map(
    lambda leaf: leaf if leaf.dtype == jnp.int32 else leaf.astype(jnp.float32),
    params,
)

# train_step invokes state.apply_fn({'params': ...}, latents, timesteps, ...).
# That is the calling convention of Module.apply (variables first), not of the
# module's bare __call__, so the apply method is what must be stored here.
state = train_state.TrainState.create(
    apply_fn=unet.apply,
    params=float32_params,
    tx=optimizer,
)
|
188 |
-
|
189 |
-
# Modify the train_step function
|
190 |
-
def train_step(state, batch, rng):
    """Run one optimisation step and return ``(new_state, loss)``.

    Args:
        state: Train state exposing ``params`` (with ``"vae"`` and ``"unet"``
            sub-trees), ``apply_fn`` and ``apply_gradients``. ``apply_fn`` is
            invoked as ``apply_fn({"params": ...}, latents, timesteps, ...)``.
        batch: Mapping with a ``"pixel_values"`` entry.
        rng: JAX PRNG key for this step.

    Returns:
        The updated state and the mean-squared-error loss of the step.
    """

    def compute_loss(params, pixel_values, rng):
        """Return the MSE between the predicted noise and the sampled noise."""
        # One independent key per random draw. Reusing a single key for all
        # four draws would make them statistically dependent.
        sample_rng, noise_rng, timestep_rng, text_rng = jax.random.split(rng, 4)

        # Ensure pixel_values are float32
        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)

        # Encode images to latent space
        latent_dist = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist
        # NOTE(review): 0.18215 is presumably the VAE latent scaling factor;
        # confirm against the model config.
        latents = latent_dist.sample(sample_rng) * jnp.float32(0.18215)

        # Generate random noise
        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)

        # Sample random timesteps, one per batch element, then convert them
        # to float32 as the original code does.
        timesteps = jax.random.randint(
            timestep_rng,
            (latents.shape[0],),
            0,
            pipeline.scheduler.config.num_train_timesteps,
        )
        timesteps = jnp.array(timesteps, dtype=jnp.float32)

        # Add noise to latents
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Random encoder hidden states stand in for text embeddings.
        # NOTE(review): the shape is (batch, hidden_size); confirm the UNet
        # accepts a conditioning tensor without a sequence axis.
        encoder_hidden_states = jax.random.normal(
            text_rng,
            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
            dtype=jnp.float32,
        )

        # Predict noise
        model_output = state.apply_fn(
            {'params': params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )

        # Compute loss
        return jnp.mean((model_output - noise) ** 2)

    # A single pass yields both the loss and its gradient, instead of running
    # the forward computation a second time just to obtain the loss.
    loss_and_grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
    _, step_rng = jax.random.split(rng)

    loss, grads = loss_and_grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)
    return state, loss
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
# Training loop (remains the same)
|
250 |
-
# Training configuration.
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
        # Advance the loop key and hand a fresh sub-key to this step.
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng)
        epoch_loss += loss
        num_batches += 1

        # Periodically drop cached compilations and collect garbage.
        if num_batches % 10 == 0:
            clear_jit_cache()

    # Without this guard an empty dataset surfaces as an opaque
    # ZeroDivisionError in the average below.
    if num_batches == 0:
        raise ValueError("The dataset produced no batches; nothing to train on.")

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
    clear_jit_cache()

# Save the fine-tuned UNet parameters.
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
unet.save_pretrained(output_dir, params=state.params["unet"])

print(f"Model saved to {output_dir}")
|
|
|
1 |
import jax
|
2 |
import jax.numpy as jnp
|
3 |
+
from flax.jax_utils import replicate
|
4 |
from flax.training import train_state
|
5 |
import optax
|
6 |
from diffusers import FlaxStableDiffusionPipeline
|
7 |
+
from diffusers.schedulers import FlaxPNDMScheduler
|
8 |
from datasets import load_dataset
|
9 |
from tqdm.auto import tqdm
|
10 |
import os
|
11 |
import pickle
|
12 |
from PIL import Image
|
13 |
import numpy as np
|
|
|
14 |
|
15 |
+
# Custom Scheduler
|
16 |
+
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
    """FlaxPNDMScheduler subclass whose ``add_noise`` forces int32 timesteps."""

    def add_noise(self, state, original_samples, noise, timesteps):
        """Cast ``timesteps`` to int32 and delegate to the parent ``add_noise``.

        ``state``, ``original_samples`` and ``noise`` are forwarded unchanged.
        """
        int_timesteps = timesteps.astype(jnp.int32)
        return super().add_noise(state, original_samples, noise, int_timesteps)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
# Cache layout: one root directory with a sub-directory for the model.
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
|
|
|
40 |
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
41 |
model_id,
|
42 |
revision=revision,
|
43 |
+
dtype=jnp.float32,
|
44 |
)
|
45 |
with open(model_cache_file, 'wb') as f:
|
46 |
pickle.dump((pipeline, params), f)
|
|
|
50 |
# Fetch the pipeline and its parameters for the "flax" revision via get_model.
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

# Rebuild the scheduler as the custom subclass from the existing scheduler's
# config and install it on the pipeline.
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = custom_scheduler

# Keep a module-level handle on the pipeline's UNet.
unet = pipeline.unet
|
59 |
|
60 |
# Load and preprocess your dataset
|
61 |
def preprocess_images(examples):
|
|
|
77 |
print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")

# Try the on-disk cache first; fall back to rebuilding when it is missing or
# cannot be read.
processed_dataset = None
if os.path.exists(dataset_cache_file):
    print("Loading dataset from cache...")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file,
    # and this cache lives under /tmp. Only run this where that directory is
    # trusted.
    try:
        with open(dataset_cache_file, 'rb') as f:
            processed_dataset = pickle.load(f)
    except (pickle.UnpicklingError, EOFError) as err:
        # A truncated or corrupt cache (e.g. from an interrupted earlier run)
        # should trigger a rebuild rather than abort the script.
        print(f"Dataset cache is unreadable ({err}); rebuilding it.")

if processed_dataset is None:
    print("Processing dataset...")
    dataset = load_dataset(dataset_name)
    train_split = dataset["train"]
    # Replace all original columns with the output of preprocess_images.
    processed_dataset = train_split.map(
        preprocess_images,
        batched=True,
        remove_columns=train_split.column_names,
    )
    # Make sure the cache directory exists before writing into it.
    cache_parent = os.path.dirname(dataset_cache_file)
    if cache_parent:
        os.makedirs(cache_parent, exist_ok=True)
    with open(dataset_cache_file, 'wb') as f:
        pickle.dump(processed_dataset, f)

print(f"Processed dataset size: {len(processed_dataset)}")
|
|
|
|
|
|
|
92 |
|
93 |
# Training function
|
94 |
def train_step(state, batch, rng):
|
95 |
def compute_loss(params, pixel_values, rng):
|
96 |
print("pixel_values dtype:", pixel_values.dtype)
|
97 |
+
print("params dtypes:", jax.tree_map
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|