Update app.py
Browse files
app.py
CHANGED
@@ -129,10 +129,15 @@ def train_step(state, batch, rng):
|
|
129 |
dtype=jnp.float32
|
130 |
)
|
131 |
|
|
|
|
|
|
|
|
|
|
|
132 |
# Use the correct method to call the UNet
|
133 |
model_output = unet.apply(
|
134 |
{'params': params["unet"]},
|
135 |
-
|
136 |
jnp.array(timesteps, dtype=jnp.int32),
|
137 |
encoder_hidden_states,
|
138 |
train=True,
|
|
|
129 |
dtype=jnp.float32
|
130 |
)
|
131 |
|
132 |
+
# Ensure noisy_latents has the correct number of channels
|
133 |
+
if noisy_latents.shape[-1] != pipeline.unet.config.in_channels:
|
134 |
+
pad_width = [(0, 0)] * (noisy_latents.ndim - 1) + [(0, pipeline.unet.config.in_channels - noisy_latents.shape[-1])]
|
135 |
+
noisy_latents = jnp.pad(noisy_latents, pad_width, mode='constant')
|
136 |
+
|
137 |
# Use the correct method to call the UNet
|
138 |
model_output = unet.apply(
|
139 |
{'params': params["unet"]},
|
140 |
+
noisy_latents,
|
141 |
jnp.array(timesteps, dtype=jnp.int32),
|
142 |
encoder_hidden_states,
|
143 |
train=True,
|