# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Tuple, Union

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.constants import (
    DECODE_ENDPOINT_FLUX,
    DECODE_ENDPOINT_HUNYUAN_VIDEO,
    DECODE_ENDPOINT_SD_V1,
    DECODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
    remote_decode,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    slow,
    torch_all_close,
    torch_device,
)
from diffusers.video_processor import VideoProcessor


enable_full_determinism()


class RemoteAutoencoderKLMixin:
    """Shared image tests for remote AutoencoderKL decode endpoints."""

    shape: Tuple[int, ...] = None
    out_hw: Tuple[int, int] = None
    endpoint: str = None
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    processor_cls: Union[VaeImageProcessor, VideoProcessor] = None
    output_pil_slice: torch.Tensor = None
    output_pt_slice: torch.Tensor = None
    partial_postprocess_return_pt_slice: torch.Tensor = None
    return_pt_slice: torch.Tensor = None
    width: int = None
    height: int = None

    def get_dummy_inputs(self):
        inputs = {
            "endpoint": self.endpoint,
            "tensor": torch.randn(
                self.shape,
                device=torch_device,
                dtype=self.dtype,
                generator=torch.Generator(torch_device).manual_seed(13),
            ),
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
            "height": self.height,
            "width": self.width,
        }
        return inputs

    def test_no_scaling(self):
        inputs = self.get_dummy_inputs()
        if inputs["scaling_factor"] is not None:
            inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"]
            inputs["scaling_factor"] = None
        if inputs["shift_factor"] is not None:
            inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"]
            inputs["shift_factor"] = None
        processor = self.processor_cls()
        output = remote_decode(
            output_type="pt",
            # required for now, will be removed in next update
            do_scaling=False,
            processor=processor,
            **inputs,
        )
        self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
        self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
        self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
        # Increased tolerance for Flux Packed diff [1, 0, 1, 0, 0, 0, 0, 0, 0]
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
            f"{output_slice}",
        )

    def test_output_type_pt(self):
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pt", processor=processor, **inputs)
        self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
        self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
        self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
        )

    # output is visually the same, slice is flaky?
    def test_output_type_pil(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pil", **inputs)
        self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
        self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
        self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")

    def test_output_type_pil_image_format(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pil", image_format="png", **inputs)
        self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
        self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
        self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
        self.assertEqual(output.format, "png", f"Expected image format `png`, got {output.format}")
        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
        )

    def test_output_type_pt_partial_postprocess(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
        self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
        self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
        self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
        )

    def test_output_type_pt_return_type_pt(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", return_type="pt", **inputs)
        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
        self.assertEqual(
            output.shape[2], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[2]}"
        )
        self.assertEqual(
            output.shape[3], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[3]}"
        )
        output_slice = output[0, 0, -3:, -3:].flatten()
        self.assertTrue(
            torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
            f"{output_slice}",
        )

    def test_output_type_pt_partial_postprocess_return_type_pt(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", partial_postprocess=True, return_type="pt", **inputs)
        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
        self.assertEqual(
            output.shape[1], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[1]}"
        )
        self.assertEqual(
            output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[2]}"
        )
        output_slice = output[0, -3:, -3:, 0].flatten().cpu()
        self.assertTrue(
            torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2),
            f"{output_slice}",
        )

    def test_do_scaling_deprecation(self):
        inputs = self.get_dummy_inputs()
        inputs.pop("scaling_factor", None)
        inputs.pop("shift_factor", None)
        with self.assertWarns(FutureWarning) as warning:
            _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
            self.assertEqual(
                str(warning.warnings[0].message),
                "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
                str(warning.warnings[0].message),
            )

    def test_input_tensor_type_base64_deprecation(self):
        inputs = self.get_dummy_inputs()
        with self.assertWarns(FutureWarning) as warning:
            _ = remote_decode(output_type="pt", input_tensor_type="base64", partial_postprocess=True, **inputs)
            self.assertEqual(
                str(warning.warnings[0].message),
                "input_tensor_type='base64' is deprecated. Using `binary`.",
                str(warning.warnings[0].message),
            )

    def test_output_tensor_type_base64_deprecation(self):
        inputs = self.get_dummy_inputs()
        with self.assertWarns(FutureWarning) as warning:
            _ = remote_decode(output_type="pt", output_tensor_type="base64", partial_postprocess=True, **inputs)
            self.assertEqual(
                str(warning.warnings[0].message),
                "output_tensor_type='base64' is deprecated. Using `binary`.",
                str(warning.warnings[0].message),
            )


class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin):
    """Video variants of the decode tests; outputs are lists of frames rather than single images."""

    def test_no_scaling(self):
        inputs = self.get_dummy_inputs()
        if inputs["scaling_factor"] is not None:
            inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"]
            inputs["scaling_factor"] = None
        if inputs["shift_factor"] is not None:
            inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"]
            inputs["shift_factor"] = None
        processor = self.processor_cls()
        output = remote_decode(
            output_type="pt",
            # required for now, will be removed in next update
            do_scaling=False,
            processor=processor,
            **inputs,
        )
        self.assertTrue(
            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
        )
        self.assertEqual(
            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
        )
        self.assertEqual(
            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
        )
        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
            f"{output_slice}",
        )

    def test_output_type_pt(self):
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pt", processor=processor, **inputs)
        self.assertTrue(
            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
        )
        self.assertEqual(
            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
        )
        self.assertEqual(
            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
        )
        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
            f"{output_slice}",
        )

    # output is visually the same, slice is flaky?
    def test_output_type_pil(self):
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pil", processor=processor, **inputs)
        self.assertTrue(
            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
        )
        self.assertEqual(
            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
        )
        self.assertEqual(
            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
        )

    def test_output_type_pil_image_format(self):
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pil", processor=processor, image_format="png", **inputs)
        self.assertTrue(
            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
        )
        self.assertEqual(
            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
        )
        self.assertEqual(
            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
        )
        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
            f"{output_slice}",
        )

    def test_output_type_pt_partial_postprocess(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
        self.assertTrue(
            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
        )
        self.assertEqual(
            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
        )
        self.assertEqual(
            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
        )
        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
            f"{output_slice}",
        )

    def test_output_type_pt_return_type_pt(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", return_type="pt", **inputs)
        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
        self.assertEqual(
            output.shape[3], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[3]}"
        )
        self.assertEqual(
            output.shape[4], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[4]}"
        )
        output_slice = output[0, 0, 0, -3:, -3:].flatten()
        self.assertTrue(
            torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
            f"{output_slice}",
        )

    def test_output_type_mp4(self):
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
        self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")


class RemoteAutoencoderKLSDv1Tests(
    RemoteAutoencoderKLMixin,
    unittest.TestCase,
):
    shape = (
        1,
        4,
        64,
        64,
    )
    out_hw = (
        512,
        512,
    )
    endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None
    processor_cls = VaeImageProcessor
    output_pt_slice = torch.tensor([31, 15, 11, 55, 30, 21, 66, 42, 30], dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor([100, 130, 99, 133, 106, 112, 97, 100, 121], dtype=torch.uint8)
    return_pt_slice = torch.tensor([-0.2177, 0.0217, -0.2258, 0.0412, -0.1687, -0.1232, -0.2416, -0.2130, -0.0543])


class RemoteAutoencoderKLSDXLTests(
    RemoteAutoencoderKLMixin,
    unittest.TestCase,
):
    shape = (
        1,
        4,
        128,
        128,
    )
    out_hw = (
        1024,
        1024,
    )
    endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None
    processor_cls = VaeImageProcessor
    output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
    return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])


class RemoteAutoencoderKLFluxTests(
    RemoteAutoencoderKLMixin,
    unittest.TestCase,
):
    shape = (
        1,
        16,
        128,
        128,
    )
    out_hw = (
        1024,
        1024,
    )
    endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
    processor_cls = VaeImageProcessor
    output_pt_slice = torch.tensor([110, 72, 91, 62, 35, 52, 69, 55, 69], dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor(
        [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8
    )
    return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984])


class RemoteAutoencoderKLFluxPackedTests(
    RemoteAutoencoderKLMixin,
    unittest.TestCase,
):
    shape = (
        1,
        4096,
        64,
    )
    out_hw = (
        1024,
        1024,
    )
    height = 1024
    width = 1024
    endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
    processor_cls = VaeImageProcessor
    # slices are different due to randn on a different shape; we could pack the latent instead if we want the same
    output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor(
        [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
    )
    return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])


class RemoteAutoencoderKLHunyuanVideoTests(
    RemoteAutoencoderKLHunyuanVideoMixin,
    unittest.TestCase,
):
    shape = (
        1,
        16,
        3,
        40,
        64,
    )
    out_hw = (
        320,
        512,
    )
    endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
    dtype = torch.float16
    scaling_factor = 0.476986
    processor_cls = VideoProcessor
    output_pt_slice = torch.tensor([112, 92, 85, 112, 93, 85, 112, 94, 85], dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor(
        [149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
    )
    return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])


class RemoteAutoencoderKLSlowTestMixin:
    """Multi-resolution decode tests; concrete subclasses are marked @slow."""

    channels: int = 4
    endpoint: str = None
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    width: int = None
    height: int = None

    def get_dummy_inputs(self):
        inputs = {
            "endpoint": self.endpoint,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
            "height": self.height,
            "width": self.width,
        }
        return inputs

    def test_multi_res(self):
        inputs = self.get_dummy_inputs()
        for height in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
            for width in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
                inputs["tensor"] = torch.randn(
                    (1, self.channels, height // 8, width // 8),
                    device=torch_device,
                    dtype=self.dtype,
                    generator=torch.Generator(torch_device).manual_seed(13),
                )
                inputs["height"] = height
                inputs["width"] = width
                output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
                output.save(f"test_multi_res_{height}_{width}.png")


@slow
class RemoteAutoencoderKLSDv1SlowTests(
    RemoteAutoencoderKLSlowTestMixin,
    unittest.TestCase,
):
    endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None


@slow
class RemoteAutoencoderKLSDXLSlowTests(
    RemoteAutoencoderKLSlowTestMixin,
    unittest.TestCase,
):
    endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None


@slow
class RemoteAutoencoderKLFluxSlowTests(
    RemoteAutoencoderKLSlowTestMixin,
    unittest.TestCase,
):
    channels = 16
    endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159