Update app.py
Browse files
app.py
CHANGED
@@ -15,7 +15,6 @@ import numpy as np
|
|
15 |
# Custom Scheduler
|
16 |
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
|
17 |
def add_noise(self, state, original_samples, noise, timesteps):
|
18 |
-
# Explicitly cast timesteps to int32
|
19 |
timesteps = timesteps.astype(jnp.int32)
|
20 |
return super().add_noise(state, original_samples, noise, timesteps)
|
21 |
|
@@ -97,10 +96,8 @@ def train_step(state, batch, rng):
|
|
97 |
print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
|
98 |
print("rng dtype:", rng.dtype)
|
99 |
|
100 |
-
# Ensure pixel_values are float32
|
101 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
102 |
|
103 |
-
# Encode images to latent space
|
104 |
latents = pipeline.vae.apply(
|
105 |
{"params": params["vae"]},
|
106 |
pixel_values,
|
@@ -108,10 +105,8 @@ def train_step(state, batch, rng):
|
|
108 |
).latent_dist.sample(rng)
|
109 |
latents = latents * jnp.float32(0.18215)
|
110 |
|
111 |
-
# Generate random noise
|
112 |
noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
|
113 |
|
114 |
-
# Sample random timesteps
|
115 |
timesteps = jax.random.randint(
|
116 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
117 |
)
|
@@ -119,8 +114,8 @@ def train_step(state, batch, rng):
|
|
119 |
print("timesteps dtype:", timesteps.dtype)
|
120 |
print("latents dtype:", latents.dtype)
|
121 |
print("noise dtype:", noise.dtype)
|
|
|
122 |
|
123 |
-
# Add noise to latents
|
124 |
noisy_latents = pipeline.scheduler.add_noise(
|
125 |
pipeline.scheduler.create_state(),
|
126 |
original_samples=latents,
|
@@ -128,14 +123,12 @@ def train_step(state, batch, rng):
|
|
128 |
timesteps=timesteps
|
129 |
)
|
130 |
|
131 |
-
# Generate random encoder hidden states (simulating text embeddings)
|
132 |
encoder_hidden_states = jax.random.normal(
|
133 |
rng,
|
134 |
(latents.shape[0], pipeline.text_encoder.config.hidden_size),
|
135 |
dtype=jnp.float32
|
136 |
)
|
137 |
|
138 |
-
# Predict noise
|
139 |
model_output = unet.apply(
|
140 |
{'params': params["unet"]},
|
141 |
noisy_latents,
|
@@ -144,7 +137,6 @@ def train_step(state, batch, rng):
|
|
144 |
train=True,
|
145 |
)
|
146 |
|
147 |
-
# Compute loss
|
148 |
return jnp.mean((model_output - noise) ** 2)
|
149 |
|
150 |
grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
|
|
|
15 |
# Custom Scheduler
|
16 |
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
|
17 |
def add_noise(self, state, original_samples, noise, timesteps):
|
|
|
18 |
timesteps = timesteps.astype(jnp.int32)
|
19 |
return super().add_noise(state, original_samples, noise, timesteps)
|
20 |
|
|
|
96 |
print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
|
97 |
print("rng dtype:", rng.dtype)
|
98 |
|
|
|
99 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
100 |
|
|
|
101 |
latents = pipeline.vae.apply(
|
102 |
{"params": params["vae"]},
|
103 |
pixel_values,
|
|
|
105 |
).latent_dist.sample(rng)
|
106 |
latents = latents * jnp.float32(0.18215)
|
107 |
|
|
|
108 |
noise = jax.random.normal(rng, latents.shape, dtype=jnp.float32)
|
109 |
|
|
|
110 |
timesteps = jax.random.randint(
|
111 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
112 |
)
|
|
|
114 |
print("timesteps dtype:", timesteps.dtype)
|
115 |
print("latents dtype:", latents.dtype)
|
116 |
print("noise dtype:", noise.dtype)
|
117 |
+
print("latents shape:", latents.shape)
|
118 |
|
|
|
119 |
noisy_latents = pipeline.scheduler.add_noise(
|
120 |
pipeline.scheduler.create_state(),
|
121 |
original_samples=latents,
|
|
|
123 |
timesteps=timesteps
|
124 |
)
|
125 |
|
|
|
126 |
encoder_hidden_states = jax.random.normal(
|
127 |
rng,
|
128 |
(latents.shape[0], pipeline.text_encoder.config.hidden_size),
|
129 |
dtype=jnp.float32
|
130 |
)
|
131 |
|
|
|
132 |
model_output = unet.apply(
|
133 |
{'params': params["unet"]},
|
134 |
noisy_latents,
|
|
|
137 |
train=True,
|
138 |
)
|
139 |
|
|
|
140 |
return jnp.mean((model_output - noise) ** 2)
|
141 |
|
142 |
grad_fn = jax.grad(compute_loss, argnums=0, allow_int=True)
|