Argument and dtype fix
Browse files- pipeline.py +2 -4
pipeline.py
CHANGED
@@ -244,11 +244,9 @@ class SwDPipeline(StableDiffusion3Pipeline):
|
|
244 |
sigma = sigmas[i]
|
245 |
sigma_next = sigmas[i + 1]
|
246 |
x0_pred = (latents - sigma * noise_pred)
|
247 |
-
|
248 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
|
249 |
-
|
250 |
-
x0_pred = x0_pred
|
251 |
-
noise = torch.randn(x0_pred.shape, generator=generator).to('cuda').half()
|
252 |
latents = (1 - sigma_next) * x0_pred + sigma_next * noise
|
253 |
|
254 |
if latents.dtype != latents_dtype:
|
|
|
244 |
sigma = sigmas[i]
|
245 |
sigma_next = sigmas[i + 1]
|
246 |
x0_pred = (latents - sigma * noise_pred)
|
247 |
+
if scales and i + 1 < len(scales):
|
248 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
|
249 |
+
noise = torch.randn(x0_pred.shape, generator=generator, device=device, dtype=x0_pred.dtype)
|
|
|
|
|
250 |
latents = (1 - sigma_next) * x0_pred + sigma_next * noise
|
251 |
|
252 |
if latents.dtype != latents_dtype:
|