# NOTE(review): the six lines below are Hugging Face web-viewer residue captured
# by the scrape (uploader, commit, file-size chrome) — not Python. Commented out
# so the module parses; safe to delete entirely.
# multimodalart's picture
# Upload 2025 files
# 22a452a verified
# raw
# history blame
# 6.35 kB
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import PIL.Image
import torch
from diffusers.utils import load_image
from diffusers.utils.constants import (
DECODE_ENDPOINT_FLUX,
DECODE_ENDPOINT_SD_V1,
DECODE_ENDPOINT_SD_XL,
ENCODE_ENDPOINT_FLUX,
ENCODE_ENDPOINT_SD_V1,
ENCODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
remote_decode,
remote_encode,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
slow,
)
# Force deterministic torch/cuDNN behavior so test runs are reproducible.
enable_full_determinism()
# Shared test fixture: astronaut photo from the documentation-images dataset.
IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
class RemoteAutoencoderKLEncodeMixin:
    """Shared test logic for remote VAE encode endpoints.

    Concrete subclasses pin the endpoint/latent configuration via the class
    attributes below and inherit the actual test methods from this mixin.
    """

    # Configuration supplied by each concrete subclass.
    channels: int = None  # number of latent channels produced by the VAE
    endpoint: str = None  # remote encode endpoint URL
    decode_endpoint: str = None  # matching remote decode endpoint URL
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    image: PIL.Image.Image = None  # lazily-loaded shared test image

    def get_dummy_inputs(self):
        """Build the keyword arguments for ``remote_encode``, fetching the test image on first use."""
        if self.image is None:
            self.image = load_image(IMAGE)
        return {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }

    def test_image_input(self):
        """Encode an image remotely, check the latent shape, then round-trip decode it."""
        inputs = self.get_dummy_inputs()
        width, height = inputs["image"].size
        latent = remote_encode(**inputs)
        # The VAE downsamples spatially by 8x and emits `channels` latent planes.
        self.assertEqual(list(latent.shape), [1, self.channels, height // 8, width // 8])
        roundtrip = remote_decode(
            tensor=latent,
            endpoint=self.decode_endpoint,
            scaling_factor=self.scaling_factor,
            shift_factor=self.shift_factor,
            image_format="png",
        )
        # Decoding should restore the original spatial resolution exactly.
        self.assertEqual(roundtrip.height, height)
        self.assertEqual(roundtrip.width, width)
        # TODO: how to test pixel content? encode->decode is lossy. expected slice of encoded latent?
class RemoteAutoencoderKLSDv1Tests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    """Remote encode tests against the Stable Diffusion v1 VAE endpoint."""

    # SD v1 VAE: 4 latent channels, fp16 weights, scale-only latent normalization.
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None
class RemoteAutoencoderKLSDXLTests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    """Remote encode tests against the Stable Diffusion XL VAE endpoint."""

    # SDXL VAE: 4 latent channels, fp16 weights, scale-only latent normalization.
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None
class RemoteAutoencoderKLFluxTests(
    RemoteAutoencoderKLEncodeMixin,
    unittest.TestCase,
):
    """Remote encode tests against the Flux VAE endpoint."""

    # Flux VAE: 16 latent channels, bf16 weights, shift-and-scale latent normalization.
    channels = 16
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
class RemoteAutoencoderKLEncodeSlowTestMixin:
    """Slow, exhaustive multi-resolution sweep for remote VAE encode endpoints.

    Concrete subclasses pin the endpoint/latent configuration via the class
    attributes; this mixin encodes and round-trip decodes the test image at
    every (height, width) combination of ``RESOLUTIONS``.
    """

    # Configuration supplied by each concrete subclass.
    channels: int = 4  # number of latent channels (default: SD-family VAEs)
    endpoint: str = None  # remote encode endpoint URL
    decode_endpoint: str = None  # matching remote decode endpoint URL
    dtype: torch.dtype = None
    scaling_factor: float = None
    shift_factor: float = None
    image: PIL.Image.Image = None  # lazily-loaded shared test image

    # Fixed sweep of resolutions (all multiples of 8) tested in both dimensions.
    # A tuple gives a deterministic, documented iteration order; the original
    # set literals implied no order for what is a fixed sweep.
    RESOLUTIONS = (320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048)

    def get_dummy_inputs(self):
        """Build the keyword arguments for ``remote_encode``, fetching the test image on first use."""
        if self.image is None:
            self.image = load_image(IMAGE)
        inputs = {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }
        return inputs

    def test_multi_res(self):
        """Encode and round-trip decode the test image at every resolution pair."""
        inputs = self.get_dummy_inputs()
        # Keep a pristine reference: the original code resized the *previous*
        # iteration's output, compounding resampling degradation across the sweep.
        base_image = inputs["image"]
        for height in self.RESOLUTIONS:
            for width in self.RESOLUTIONS:
                # subTest lets the remaining combinations run (and be reported
                # individually) even if one resolution fails.
                with self.subTest(height=height, width=width):
                    inputs["image"] = base_image.resize((width, height))
                    output = remote_encode(**inputs)
                    # Latents are expected at 1/8 spatial resolution.
                    self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
                    decoded = remote_decode(
                        tensor=output,
                        endpoint=self.decode_endpoint,
                        scaling_factor=self.scaling_factor,
                        shift_factor=self.shift_factor,
                        image_format="png",
                    )
                    self.assertEqual(decoded.height, height)
                    self.assertEqual(decoded.width, width)
                    # Saved into the working directory for manual inspection.
                    decoded.save(f"test_multi_res_{height}_{width}.png")
@slow
class RemoteAutoencoderKLSDv1SlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    """Slow multi-resolution sweep against the Stable Diffusion v1 VAE endpoint."""

    # SD v1 VAE: fp16 weights, scale-only latent normalization (channels defaults to 4).
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None
@slow
class RemoteAutoencoderKLSDXLSlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    """Slow multi-resolution sweep against the Stable Diffusion XL VAE endpoint."""

    # SDXL VAE: fp16 weights, scale-only latent normalization (channels defaults to 4).
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None
@slow
class RemoteAutoencoderKLFluxSlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    """Slow multi-resolution sweep against the Flux VAE endpoint."""

    # Flux VAE: 16 latent channels, bf16 weights, shift-and-scale latent normalization.
    channels = 16
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159