uruguayai commited on
Commit
acc7f4b
·
verified ·
1 Parent(s): cf50961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -5
app.py CHANGED
@@ -11,6 +11,16 @@ from PIL import Image
11
  import numpy as np
12
  import gc
13
 
 
 
 
 
 
 
 
 
 
 
14
  # Force JAX to use CPU
15
  jax.config.update('jax_platform_name', 'cpu')
16
 
@@ -50,6 +60,13 @@ pipeline, params = get_model(model_id, "flax")
50
  # Extract UNet from pipeline
51
  unet = pipeline.unet
52
 
 
 
 
 
 
 
 
53
  # Load and preprocess your dataset
54
  def preprocess_images(examples):
55
  def process_image(image):
@@ -97,6 +114,10 @@ def clear_jit_cache():
97
  # Training function
98
  def train_step(state, batch, rng):
99
  def compute_loss(params, pixel_values, rng):
 
 
 
 
100
  # Ensure pixel_values are float32
101
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
102
 
@@ -111,13 +132,14 @@ def train_step(state, batch, rng):
111
  # Generate random noise
112
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
113
 
114
- # Sample random timesteps (keep as integers)
115
  timesteps = jax.random.randint(
116
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
117
  )
118
 
119
- # Explicitly cast timesteps to int32
120
- timesteps = timesteps.astype(jnp.int32)
 
121
 
122
  # Add noise to latents
123
  noisy_latents = pipeline.scheduler.add_noise(
@@ -157,9 +179,10 @@ def train_step(state, batch, rng):
157
  # Initialize training state
158
  learning_rate = 1e-5
159
  optimizer = optax.adam(learning_rate)
 
160
  state = train_state.TrainState.create(
161
- apply_fn=unet.__call__, # Use __call__ directly
162
- params=params, # Pass all params
163
  tx=optimizer,
164
  )
165
 
 
11
  import numpy as np
12
  import gc
13
 
14
+
15
+ from diffusers.schedulers import PNDMScheduler
16
+
17
+ class CustomPNDMScheduler(PNDMScheduler):
18
+ def add_noise(self, state, original_samples, noise, timesteps):
19
+ # Explicitly cast timesteps to int32
20
+ timesteps = timesteps.astype(jnp.int32)
21
+ return super().add_noise(state, original_samples, noise, timesteps)
22
+
23
+
24
  # Force JAX to use CPU
25
  jax.config.update('jax_platform_name', 'cpu')
26
 
 
60
  # Extract UNet from pipeline
61
  unet = pipeline.unet
62
 
63
+
64
+
65
+ # After loading the pipeline
66
+ custom_scheduler = CustomPNDMScheduler.from_config(pipeline.scheduler.config)
67
+ pipeline.scheduler = custom_scheduler
68
+
69
+
70
  # Load and preprocess your dataset
71
  def preprocess_images(examples):
72
  def process_image(image):
 
114
  # Training function
115
  def train_step(state, batch, rng):
116
  def compute_loss(params, pixel_values, rng):
117
+ print("pixel_values dtype:", pixel_values.dtype)
118
+ print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
119
+ print("rng dtype:", rng.dtype)
120
+
121
  # Ensure pixel_values are float32
122
  pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
123
 
 
132
  # Generate random noise
133
  noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
134
 
135
+ # Sample random timesteps
136
  timesteps = jax.random.randint(
137
  rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
138
  )
139
 
140
+ print("timesteps dtype:", timesteps.dtype)
141
+ print("latents dtype:", latents.dtype)
142
+ print("noise dtype:", noise.dtype)
143
 
144
  # Add noise to latents
145
  noisy_latents = pipeline.scheduler.add_noise(
 
179
  # Initialize training state
180
  learning_rate = 1e-5
181
  optimizer = optax.adam(learning_rate)
182
+ float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
183
  state = train_state.TrainState.create(
184
+ apply_fn=unet.__call__,
185
+ params=float32_params,
186
  tx=optimizer,
187
  )
188