uruguayai committed
Commit 3518b5f · verified · 1 Parent(s): acc7f4b

Update app.py

Files changed (1)
  1. app.py +21 -201
app.py CHANGED
@@ -1,31 +1,24 @@
 import jax
 import jax.numpy as jnp
+from flax.jax_utils import replicate
 from flax.training import train_state
 import optax
 from diffusers import FlaxStableDiffusionPipeline
+from diffusers.schedulers import FlaxPNDMScheduler
 from datasets import load_dataset
 from tqdm.auto import tqdm
 import os
 import pickle
 from PIL import Image
 import numpy as np
-import gc
 
-
-from diffusers.schedulers import PNDMScheduler
-
-class CustomPNDMScheduler(PNDMScheduler):
+# Custom Scheduler
+class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
     def add_noise(self, state, original_samples, noise, timesteps):
         # Explicitly cast timesteps to int32
         timesteps = timesteps.astype(jnp.int32)
         return super().add_noise(state, original_samples, noise, timesteps)
 
-
-# Force JAX to use CPU
-jax.config.update('jax_platform_name', 'cpu')
-
-print("Using CPU for computations")
-
 # Set up cache directories
 cache_dir = "/tmp/huggingface_cache"
 model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
@@ -47,7 +40,7 @@ def get_model(model_id, revision):
         pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
             model_id,
             revision=revision,
-            dtype=jnp.float32,  # Use float32 for CPU
+            dtype=jnp.float32,
         )
         with open(model_cache_file, 'wb') as f:
             pickle.dump((pipeline, params), f)
@@ -57,15 +50,12 @@ def get_model(model_id, revision):
 model_id = "CompVis/stable-diffusion-v1-4"
 pipeline, params = get_model(model_id, "flax")
 
-# Extract UNet from pipeline
-unet = pipeline.unet
-
-
-
-# After loading the pipeline
-custom_scheduler = CustomPNDMScheduler.from_config(pipeline.scheduler.config)
+# Use custom scheduler
+custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
 pipeline.scheduler = custom_scheduler
 
+# Extract UNet from pipeline
+unet = pipeline.unet
 
 # Load and preprocess your dataset
 def preprocess_images(examples):
@@ -87,191 +77,21 @@ dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
 print(f"Dataset name: {dataset_name}")
 print(f"Dataset cache file: {dataset_cache_file}")
 
-try:
-    if os.path.exists(dataset_cache_file):
-        print("Loading dataset from cache...")
-        with open(dataset_cache_file, 'rb') as f:
-            processed_dataset = pickle.load(f)
-    else:
-        print("Loading dataset from Hugging Face...")
-        dataset = load_dataset(dataset_name, split="train[:500]")  # Load only first 500 samples
-        print("Processing dataset...")
-        processed_dataset = dataset.map(preprocess_images, batched=True, remove_columns=dataset.column_names)
-        with open(dataset_cache_file, 'wb') as f:
-            pickle.dump(processed_dataset, f)
-
-    print(f"Processed dataset size: {len(processed_dataset)}")
-
-except Exception as e:
-    print(f"Error loading or processing dataset: {str(e)}")
-    raise ValueError("Unable to load or process the dataset.")
+if os.path.exists(dataset_cache_file):
+    print("Loading dataset from cache...")
+    with open(dataset_cache_file, 'rb') as f:
+        processed_dataset = pickle.load(f)
+else:
+    print("Processing dataset...")
+    dataset = load_dataset(dataset_name)
+    processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
+    with open(dataset_cache_file, 'wb') as f:
+        pickle.dump(processed_dataset, f)
 
-# Function to clear JIT cache
-def clear_jit_cache():
-    jax.clear_caches()
-    gc.collect()
+print(f"Processed dataset size: {len(processed_dataset)}")
 
 # Training function
 def train_step(state, batch, rng):
     def compute_loss(params, pixel_values, rng):
         print("pixel_values dtype:", pixel_values.dtype)
-        print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
-        print("rng dtype:", rng.dtype)
-
-        # Ensure pixel_values are float32
-        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
-
-        # Encode images to latent space
-        latents = pipeline.vae.apply(
-            {"params": params["vae"]},
-            pixel_values,
-            method=pipeline.vae.encode
-        ).latent_dist.sample(rng)
-        latents = latents * jnp.float32(0.18215)
-
-        # Generate random noise
-        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
-
-        # Sample random timesteps
-        timesteps = jax.random.randint(
-            rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
-        )
-
-        print("timesteps dtype:", timesteps.dtype)
-        print("latents dtype:", latents.dtype)
-        print("noise dtype:", noise.dtype)
-
-        # Add noise to latents
-        noisy_latents = pipeline.scheduler.add_noise(
-            pipeline.scheduler.create_state(),
-            original_samples=latents,
-            noise=noise,
-            timesteps=timesteps
-        )
-
-        # Generate random encoder hidden states (simulating text embeddings)
-        encoder_hidden_states = jax.random.normal(
-            rng,
-            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
-            dtype=jnp.float32
-        )
-
-        # Predict noise
-        model_output = state.apply_fn(
-            {'params': params["unet"]},
-            noisy_latents,
-            timesteps,
-            encoder_hidden_states=encoder_hidden_states,
-            train=True,
-        )
-
-        # Compute loss
-        return jnp.mean((model_output - noise) ** 2)
-
-    grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
-    rng, step_rng = jax.random.split(rng)
-
-    grads = grad_fn(state.params, batch["pixel_values"], step_rng)
-    loss = compute_loss(state.params, batch["pixel_values"], step_rng)
-    state = state.apply_gradients(grads=grads)
-    return state, loss
-
-# Initialize training state
-learning_rate = 1e-5
-optimizer = optax.adam(learning_rate)
-float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
-state = train_state.TrainState.create(
-    apply_fn=unet.__call__,
-    params=float32_params,
-    tx=optimizer,
-)
-
-# Modify the train_step function
-def train_step(state, batch, rng):
-    def compute_loss(params, pixel_values, rng):
-        # Ensure pixel_values are float32
-        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
-
-        # Encode images to latent space
-        latents = pipeline.vae.apply(
-            {"params": params["vae"]},
-            pixel_values,
-            method=pipeline.vae.encode
-        ).latent_dist.sample(rng)
-        latents = latents * jnp.float32(0.18215)
-
-        # Generate random noise
-        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
-
-        # Sample random timesteps
-        timesteps = jax.random.randint(
-            rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
-        )
-        timesteps = jnp.array(timesteps, dtype=jnp.float32)
-
-        # Add noise to latents
-        noisy_latents = pipeline.scheduler.add_noise(
-            pipeline.scheduler.create_state(),
-            original_samples=latents,
-            noise=noise,
-            timesteps=timesteps
-        )
-
-        # Generate random encoder hidden states (simulating text embeddings)
-        encoder_hidden_states = jax.random.normal(
-            rng,
-            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
-            dtype=jnp.float32
-        )
-
-        # Predict noise
-        model_output = state.apply_fn(
-            {'params': params["unet"]},
-            noisy_latents,
-            timesteps,
-            encoder_hidden_states=encoder_hidden_states,
-            train=True,
-        )
-
-        # Compute loss
-        return jnp.mean((model_output - noise) ** 2)
-
-    grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
-    rng, step_rng = jax.random.split(rng)
-
-    grads = grad_fn(state.params, batch["pixel_values"], step_rng)
-    loss = compute_loss(state.params, batch["pixel_values"], step_rng)
-    state = state.apply_gradients(grads=grads)
-    return state, loss
-
-
-
-# Training loop (remains the same)
-num_epochs = 3
-batch_size = 1
-rng = jax.random.PRNGKey(0)
-
-for epoch in range(num_epochs):
-    epoch_loss = 0
-    num_batches = 0
-    for batch in tqdm(processed_dataset.batch(batch_size)):
-        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
-        rng, step_rng = jax.random.split(rng)
-        state, loss = train_step(state, batch, step_rng)
-        epoch_loss += loss
-        num_batches += 1
-
-        if num_batches % 10 == 0:
-            clear_jit_cache()
-
-    avg_loss = epoch_loss / num_batches
-    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
-    clear_jit_cache()
-
-
-# Save the fine-tuned model
-output_dir = "/tmp/montevideo_fine_tuned_model"
-os.makedirs(output_dir, exist_ok=True)
-unet.save_pretrained(output_dir, params=state.params["unet"])
-
-print(f"Model saved to {output_dir}")
+        print("params dtypes:", jax.tree_map