SpiridonSunRotator commited on
Commit
30bc5c0
·
verified ·
1 Parent(s): fd23b8c

Argument and dtype fix

Browse files
Files changed (1) hide show
  1. pipeline.py +2 -4
pipeline.py CHANGED
@@ -244,11 +244,9 @@ class SwDPipeline(StableDiffusion3Pipeline):
244
  sigma = sigmas[i]
245
  sigma_next = sigmas[i + 1]
246
  x0_pred = (latents - sigma * noise_pred)
247
- try:
248
  x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
249
- except IndexError:
250
- x0_pred = x0_pred
251
- noise = torch.randn(x0_pred.shape, generator=generator).to('cuda').half()
252
  latents = (1 - sigma_next) * x0_pred + sigma_next * noise
253
 
254
  if latents.dtype != latents_dtype:
 
244
  sigma = sigmas[i]
245
  sigma_next = sigmas[i + 1]
246
  x0_pred = (latents - sigma * noise_pred)
247
+ if scales and i + 1 < len(scales):
248
  x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
249
+ noise = torch.randn(x0_pred.shape, generator=generator, device=device, dtype=x0_pred.dtype)
 
 
250
  latents = (1 - sigma_next) * x0_pred + sigma_next * noise
251
 
252
  if latents.dtype != latents_dtype: