Update app.py
Browse files
app.py
CHANGED
@@ -149,14 +149,65 @@ def train_step(state, batch, rng):
|
|
149 |
learning_rate = 1e-5
|
150 |
optimizer = optax.adam(learning_rate)
|
151 |
state = train_state.TrainState.create(
|
152 |
-
apply_fn=unet.__call__,
|
153 |
-
params=
|
154 |
tx=optimizer,
|
155 |
)
|
156 |
|
157 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
num_epochs = 3
|
159 |
-
batch_size = 1
|
160 |
rng = jax.random.PRNGKey(0)
|
161 |
|
162 |
for epoch in range(num_epochs):
|
|
|
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)
state = train_state.TrainState.create(
    apply_fn=unet.__call__,  # Use the UNet module's __call__ directly
    params=params,           # Full parameter tree (contains "unet" and "vae" subtrees)
    tx=optimizer,
)


def train_step(state, batch, rng):
    """Run one optimization step of latent-diffusion training.

    Args:
        state: Flax ``TrainState`` holding the params, optimizer state and
            the UNet ``apply_fn``.
        batch: dict with key ``"pixel_values"`` — a batch of input images.
        rng: JAX PRNG key for this step.

    Returns:
        ``(new_state, loss)`` — the updated train state and the scalar
        MSE loss for this batch.
    """

    def compute_loss(params, pixel_values, rng):
        # Derive independent subkeys. The original code reused the single
        # incoming key for every random draw, which makes the latent sample,
        # the target noise, the timesteps, and the fake text embeddings
        # statistically correlated (noise and the embeddings were drawn from
        # the *same* key). Splitting gives each draw its own stream.
        sample_rng, noise_rng, timestep_rng, hidden_rng = jax.random.split(rng, 4)

        # Encode images into the VAE latent space.
        latents = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist.sample(sample_rng)
        # Standard Stable Diffusion latent scaling factor.
        latents = latents * 0.18215

        # Target noise the UNet must learn to predict.
        noise = jax.random.normal(noise_rng, latents.shape)

        # One random diffusion timestep per batch element.
        timesteps = jax.random.randint(
            timestep_rng,
            (latents.shape[0],),
            0,
            pipeline.scheduler.config.num_train_timesteps,
        )

        # Forward-diffuse: mix noise into the clean latents at `timesteps`.
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Random stand-in for text-encoder embeddings.
        # NOTE(review): UNet cross-attention normally expects a 3-D
        # (batch, seq_len, hidden) tensor — confirm this 2-D shape is
        # really what the model accepts.
        encoder_hidden_states = jax.random.normal(
            hidden_rng,
            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
        )

        # Predict the added noise with the UNet.
        model_output = state.apply_fn(
            {"params": params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )

        # MSE between predicted and true noise.
        return jnp.mean((model_output - noise) ** 2)

    grad_fn = jax.value_and_grad(compute_loss)
    rng, step_rng = jax.random.split(rng)
    loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)
    return state, loss


# Training configuration (the loop itself is unchanged).
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)
|
213 |
for epoch in range(num_epochs):
|