# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import PIL.Image
import torch

from diffusers.utils import load_image
from diffusers.utils.constants import (
    DECODE_ENDPOINT_FLUX,
    DECODE_ENDPOINT_SD_V1,
    DECODE_ENDPOINT_SD_XL,
    ENCODE_ENDPOINT_FLUX,
    ENCODE_ENDPOINT_SD_V1,
    ENCODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
    remote_decode,
    remote_encode,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    slow,
)


enable_full_determinism()


IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"


class RemoteAutoencoderKLEncodeMixin:
    channels: int = None
    endpoint: str = None
    decode_endpoint: str = None
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    image: PIL.Image.Image = None

    def get_dummy_inputs(self):
        if self.image is None:
            self.image = load_image(IMAGE)
        inputs = {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }
        return inputs

    def test_image_input(self):
        inputs = self.get_dummy_inputs()
        height, width = inputs["image"].height, inputs["image"].width
        output = remote_encode(**inputs)
        self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
        decoded = remote_decode(
            tensor=output,
            endpoint=self.decode_endpoint,
            scaling_factor=self.scaling_factor,
            shift_factor=self.shift_factor,
            image_format="png",
        )
        self.assertEqual(decoded.height, height)
        self.assertEqual(decoded.width, width)
        # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
        # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
        # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
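        # A possible follow-up for the TODO above (a sketch, not part of the original test):
        # since encode->decode is lossy, compare the round-tripped image against the input
        # with a loose tolerance instead of an exact slice. The 0.05 tolerance is an assumption.
        # import numpy as np
        #
        # original = torch.from_numpy(np.array(inputs["image"])).float() / 255.0
        # roundtrip = torch.from_numpy(np.array(decoded)).float() / 255.0
        # self.assertLess((original - roundtrip).abs().mean().item(), 0.05)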


class RemoteAutoencoderKLSDv1Tests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None


class RemoteAutoencoderKLSDXLTests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None


class RemoteAutoencoderKLFluxTests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    channels = 16
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159


class RemoteAutoencoderKLEncodeSlowTestMixin:
    channels: int = 4
    endpoint: str = None
    decode_endpoint: str = None
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    image: PIL.Image.Image = None

    def get_dummy_inputs(self):
        if self.image is None:
            self.image = load_image(IMAGE)
        inputs = {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }
        return inputs

    def test_multi_res(self):
        inputs = self.get_dummy_inputs()
        for height in {
            320,
            512,
            640,
            704,
            896,
            1024,
            1208,
            1384,
            1536,
            1608,
            1864,
            2048,
        }:
            for width in {
                320,
                512,
                640,
                704,
                896,
                1024,
                1208,
                1384,
                1536,
                1608,
                1864,
                2048,
            }:
                inputs["image"] = inputs["image"].resize(
                    (
                        width,
                        height,
                    )
                )
                output = remote_encode(**inputs)
                self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
                decoded = remote_decode(
                    tensor=output,
                    endpoint=self.decode_endpoint,
                    scaling_factor=self.scaling_factor,
                    shift_factor=self.shift_factor,
                    image_format="png",
                )
                self.assertEqual(decoded.height, height)
                self.assertEqual(decoded.width, width)
                decoded.save(f"test_multi_res_{height}_{width}.png")


@slow
class RemoteAutoencoderKLSDv1SlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None


@slow
class RemoteAutoencoderKLSDXLSlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None


@slow
class RemoteAutoencoderKLFluxSlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    channels = 16
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159