uruguayai commited on
Commit
76dfe67
·
verified ·
1 Parent(s): a43470d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -129,10 +129,15 @@ def train_step(state, batch, rng):
129
  dtype=jnp.float32
130
  )
131
 
 
 
 
 
 
132
  # Use the correct method to call the UNet
133
  model_output = unet.apply(
134
  {'params': params["unet"]},
135
- jnp.array(noisy_latents),
136
  jnp.array(timesteps, dtype=jnp.int32),
137
  encoder_hidden_states,
138
  train=True,
 
129
  dtype=jnp.float32
130
  )
131
 
132
+ # Ensure noisy_latents has the correct number of channels
133
+ if noisy_latents.shape[-1] != pipeline.unet.config.in_channels:
134
+ pad_width = [(0, 0)] * (noisy_latents.ndim - 1) + [(0, pipeline.unet.config.in_channels - noisy_latents.shape[-1])]
135
+ noisy_latents = jnp.pad(noisy_latents, pad_width, mode='constant')
136
+
137
  # Use the correct method to call the UNet
138
  model_output = unet.apply(
139
  {'params': params["unet"]},
140
+ noisy_latents,
141
  jnp.array(timesteps, dtype=jnp.int32),
142
  encoder_hidden_states,
143
  train=True,