Update pipeline.py
Browse files- pipeline.py +1 -1
pipeline.py
CHANGED
@@ -246,7 +246,7 @@ class SwDPipeline(StableDiffusion3Pipeline):
|
|
246 |
x0_pred = (latents - sigma * noise_pred)
|
247 |
if scales and i + 1 < len(scales):
|
248 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
|
249 |
-
noise = torch.randn(x0_pred.shape, generator=generator,
|
250 |
latents = (1 - sigma_next) * x0_pred + sigma_next * noise
|
251 |
|
252 |
if latents.dtype != latents_dtype:
|
|
|
246 |
x0_pred = (latents - sigma * noise_pred)
|
247 |
if scales and i + 1 < len(scales):
|
248 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
|
249 |
+
noise = torch.randn(x0_pred.shape, generator=generator, dtype=x0_pred.dtype).to(x0_pred.device)
|
250 |
latents = (1 - sigma_next) * x0_pred + sigma_next * noise
|
251 |
|
252 |
if latents.dtype != latents_dtype:
|