Convert tile to torch.cuda.FloatTensor if it's not already of that type
Browse files
opensora/models/ae/videobase/causal_vae/modeling_causalvae.py
CHANGED
@@ -610,6 +610,11 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
610 |
i : i + self.tile_latent_min_size,
|
611 |
j : j + self.tile_latent_min_size,
|
612 |
]
|
|
|
|
|
|
|
|
|
|
|
613 |
tile = self.post_quant_conv(tile)
|
614 |
decoded = self.decoder(tile)
|
615 |
row.append(decoded)
|
|
|
610 |
i : i + self.tile_latent_min_size,
|
611 |
j : j + self.tile_latent_min_size,
|
612 |
]
|
613 |
+
|
614 |
+
# Convert tile to torch.cuda.FloatTensor if it's not already of that type
|
615 |
+
if tile.dtype != torch.float32:
|
616 |
+
tile = tile.float()
|
617 |
+
|
618 |
tile = self.post_quant_conv(tile)
|
619 |
decoded = self.decoder(tile)
|
620 |
row.append(decoded)
|