quickjkee commited on
Commit
179b120
·
verified ·
1 Parent(s): 8923a6a

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +1 -1
pipeline.py CHANGED
@@ -246,7 +246,7 @@ class SwDPipeline(StableDiffusion3Pipeline):
246
  x0_pred = (latents - sigma * noise_pred)
247
  if scales and i + 1 < len(scales):
248
  x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
249
- noise = torch.randn(x0_pred.shape, generator=generator, device=device, dtype=x0_pred.dtype)
250
  latents = (1 - sigma_next) * x0_pred + sigma_next * noise
251
 
252
  if latents.dtype != latents_dtype:
 
246
  x0_pred = (latents - sigma * noise_pred)
247
  if scales and i + 1 < len(scales):
248
  x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
249
+ noise = torch.randn(x0_pred.shape, generator=generator, dtype=x0_pred.dtype).to(x0_pred.device)
250
  latents = (1 - sigma_next) * x0_pred + sigma_next * noise
251
 
252
  if latents.dtype != latents_dtype: