Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,16 @@ from PIL import Image
|
|
11 |
import numpy as np
|
12 |
import gc
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# Force JAX to use CPU
|
15 |
jax.config.update('jax_platform_name', 'cpu')
|
16 |
|
@@ -50,6 +60,13 @@ pipeline, params = get_model(model_id, "flax")
|
|
50 |
# Extract UNet from pipeline
|
51 |
unet = pipeline.unet
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
# Load and preprocess your dataset
|
54 |
def preprocess_images(examples):
|
55 |
def process_image(image):
|
@@ -97,6 +114,10 @@ def clear_jit_cache():
|
|
97 |
# Training function
|
98 |
def train_step(state, batch, rng):
|
99 |
def compute_loss(params, pixel_values, rng):
|
|
|
|
|
|
|
|
|
100 |
# Ensure pixel_values are float32
|
101 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
102 |
|
@@ -111,13 +132,14 @@ def train_step(state, batch, rng):
|
|
111 |
# Generate random noise
|
112 |
noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
|
113 |
|
114 |
-
# Sample random timesteps
|
115 |
timesteps = jax.random.randint(
|
116 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
117 |
)
|
118 |
|
119 |
-
|
120 |
-
|
|
|
121 |
|
122 |
# Add noise to latents
|
123 |
noisy_latents = pipeline.scheduler.add_noise(
|
@@ -157,9 +179,10 @@ def train_step(state, batch, rng):
|
|
157 |
# Initialize training state
|
158 |
learning_rate = 1e-5
|
159 |
optimizer = optax.adam(learning_rate)
|
|
|
160 |
state = train_state.TrainState.create(
|
161 |
-
apply_fn=unet.__call__,
|
162 |
-
params=
|
163 |
tx=optimizer,
|
164 |
)
|
165 |
|
|
|
11 |
import numpy as np
|
12 |
import gc
|
13 |
|
14 |
+
|
15 |
+
from diffusers.schedulers import PNDMScheduler
|
16 |
+
|
17 |
+
class CustomPNDMScheduler(PNDMScheduler):
|
18 |
+
def add_noise(self, state, original_samples, noise, timesteps):
|
19 |
+
# Explicitly cast timesteps to int32
|
20 |
+
timesteps = timesteps.astype(jnp.int32)
|
21 |
+
return super().add_noise(state, original_samples, noise, timesteps)
|
22 |
+
|
23 |
+
|
24 |
# Force JAX to use CPU
|
25 |
jax.config.update('jax_platform_name', 'cpu')
|
26 |
|
|
|
60 |
# Extract UNet from pipeline
|
61 |
unet = pipeline.unet
|
62 |
|
63 |
+
|
64 |
+
|
65 |
+
# After loading the pipeline
|
66 |
+
custom_scheduler = CustomPNDMScheduler.from_config(pipeline.scheduler.config)
|
67 |
+
pipeline.scheduler = custom_scheduler
|
68 |
+
|
69 |
+
|
70 |
# Load and preprocess your dataset
|
71 |
def preprocess_images(examples):
|
72 |
def process_image(image):
|
|
|
114 |
# Training function
|
115 |
def train_step(state, batch, rng):
|
116 |
def compute_loss(params, pixel_values, rng):
|
117 |
+
print("pixel_values dtype:", pixel_values.dtype)
|
118 |
+
print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
|
119 |
+
print("rng dtype:", rng.dtype)
|
120 |
+
|
121 |
# Ensure pixel_values are float32
|
122 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
123 |
|
|
|
132 |
# Generate random noise
|
133 |
noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
|
134 |
|
135 |
+
# Sample random timesteps
|
136 |
timesteps = jax.random.randint(
|
137 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
138 |
)
|
139 |
|
140 |
+
print("timesteps dtype:", timesteps.dtype)
|
141 |
+
print("latents dtype:", latents.dtype)
|
142 |
+
print("noise dtype:", noise.dtype)
|
143 |
|
144 |
# Add noise to latents
|
145 |
noisy_latents = pipeline.scheduler.add_noise(
|
|
|
179 |
# Initialize training state
|
180 |
learning_rate = 1e-5
|
181 |
optimizer = optax.adam(learning_rate)
|
182 |
+
float32_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x, params)
|
183 |
state = train_state.TrainState.create(
|
184 |
+
apply_fn=unet.__call__,
|
185 |
+
params=float32_params,
|
186 |
tx=optimizer,
|
187 |
)
|
188 |
|