import gc
import tempfile
import unittest
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_reset_peak_memory_stats,
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_big_accelerator,
require_torch_cuda_compatibility,
torch_device,
)

if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_torch_available():
    import torch

    from ..utils import LoRALayer, get_memory_consumption_stat


enable_full_determinism()
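

# These tests exercise the optimum-quanto backend of diffusers' quantization support: models
# are loaded with `QuantoConfig` (float8 / int8 / int4 / int2 weights) and checked for layer
# replacement, peak-memory usage, fp32 / conversion exclusions, serialization, torch.compile,
# CPU offloading, and LoRA-style training on top of the quantized transformer.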
@nightly
@require_big_accelerator
@require_accelerate
class QuantoBaseTesterMixin:
model_id = None
pipeline_model_id = None
model_cls = None
torch_dtype = torch.bfloat16
    # minimum expected ratio of unquantized to quantized peak memory usage (see `test_quanto_memory_usage`)
expected_memory_reduction = 0.0
keep_in_fp32_module = ""
modules_to_not_convert = ""
_test_torch_compile = False
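    # Reset accelerator peak-memory stats and empty caches around each test so that
    # memory measurements do not leak between tests.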
def setUp(self):
backend_reset_peak_memory_stats(torch_device)
backend_empty_cache(torch_device)
gc.collect()
def tearDown(self):
backend_reset_peak_memory_stats(torch_device)
backend_empty_cache(torch_device)
gc.collect()
def get_dummy_init_kwargs(self):
return {"weights_dtype": "float8"}
def get_dummy_model_init_kwargs(self):
return {
"pretrained_model_name_or_path": self.model_id,
"torch_dtype": self.torch_dtype,
"quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
}
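    # Every `torch.nn.Linear` in the quantized model should have been swapped for a quanto
    # `QLinear` (QLinear subclasses nn.Linear, so the isinstance check still matches it).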
def test_quanto_layers(self):
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
assert isinstance(module, QLinear)
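    # The ratio of unquantized to quantized peak memory during a forward pass must reach
    # at least `expected_memory_reduction`.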
def test_quanto_memory_usage(self):
inputs = self.get_dummy_inputs()
inputs = {
k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
}
unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
unquantized_model.to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
quantized_model.to(torch_device)
quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
def test_keep_modules_in_fp32(self):
r"""
        A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32.
        Also ensures that inference works.
"""
_keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
model.to(torch_device)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
assert module.weight.dtype == torch.float32
self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
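    # Modules listed in `modules_to_not_convert` must keep their original (non-QLinear) class.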
def test_modules_to_not_convert(self):
init_kwargs = self.get_dummy_model_init_kwargs()
quantization_config_kwargs = self.get_dummy_init_kwargs()
quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
quantization_config = QuantoConfig(**quantization_config_kwargs)
init_kwargs.update({"quantization_config": quantization_config})
model = self.model_cls.from_pretrained(**init_kwargs)
model.to(torch_device)
for name, module in model.named_modules():
if name in self.modules_to_not_convert:
assert not isinstance(module, QLinear)
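    # Dtype casts (`.to(dtype)`, `.float()`, `.half()`) are not allowed on a quantized model;
    # only plain device moves should succeed.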
def test_dtype_assignment(self):
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
with self.assertRaises(ValueError):
# Tries with a `dtype`
model.to(torch.float16)
with self.assertRaises(ValueError):
# Tries with a `device` and `dtype`
device_0 = f"{torch_device}:0"
model.to(device=device_0, dtype=torch.float16)
with self.assertRaises(ValueError):
# Tries with a cast
model.float()
with self.assertRaises(ValueError):
# Tries with a cast
model.half()
# This should work
model.to(torch_device)
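    # Saving and reloading the quantized model must reproduce the original outputs.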
def test_serialization(self):
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
inputs = self.get_dummy_inputs()
model.to(torch_device)
with torch.no_grad():
model_output = model(**inputs)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
saved_model = self.model_cls.from_pretrained(
tmp_dir,
torch_dtype=torch.bfloat16,
)
saved_model.to(torch_device)
with torch.no_grad():
saved_model_output = saved_model(**inputs)
assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
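    # Outputs of the compiled quantized model should match the eager model (cosine-similarity
    # distance below 1e-3); skipped unless `_test_torch_compile` is set on the test class.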
def test_torch_compile(self):
if not self._test_torch_compile:
return
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)
model.to(torch_device)
with torch.no_grad():
model_output = model(**self.get_dummy_inputs()).sample
compiled_model.to(torch_device)
with torch.no_grad():
compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample
model_output = model_output.detach().float().cpu().numpy()
compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
assert max_diff < 1e-3
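    # An unsupported `device_map` (here a memory-budget dict rather than a module-to-device
    # mapping) must raise a ValueError.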
def test_device_map_error(self):
with self.assertRaises(ValueError):
_ = self.model_cls.from_pretrained(
**self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
)
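

# Shared configuration and dummy inputs for the tiny Flux transformer used by the dtype-specific
# test classes below; adds CPU-offload and LoRA training tests on top of QuantoBaseTesterMixin.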
class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
model_id = "hf-internal-testing/tiny-flux-transformer"
model_cls = FluxTransformer2DModel
pipeline_cls = FluxPipeline
torch_dtype = torch.bfloat16
keep_in_fp32_module = "proj_out"
modules_to_not_convert = ["proj_out"]
_test_torch_compile = False
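    # Dummy inference inputs for `hf-internal-testing/tiny-flux-transformer`: 4096 latent tokens
    # of dim 64, 512 text tokens of dim 4096, and pooled projections of dim 768.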
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_projections": torch.randn(
(1, 768),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
"img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
}
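    # Smaller randomized inputs used by `test_training`; fixed seeds keep them deterministic.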
def get_dummy_training_inputs(self, device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
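    # The quantized transformer should still run inside a full FluxPipeline with model
    # CPU offloading enabled.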
def test_model_cpu_offload(self):
init_kwargs = self.get_dummy_init_kwargs()
transformer = self.model_cls.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
quantization_config=QuantoConfig(**init_kwargs),
subfolder="transformer",
torch_dtype=torch.bfloat16,
)
pipe = self.pipeline_cls.from_pretrained(
"hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=torch_device)
_ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
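    # LoRA-style training on top of the frozen quantized model: wrap the attention q/k/v
    # projections with `LoRALayer` adapters and check that their weights receive gradients.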
def test_training(self):
quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
quantized_model = self.model_cls.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for param in quantized_model.parameters():
# freeze the model as only adapter layers will be trained
param.requires_grad = False
if param.ndim == 1:
param.data = param.data.to(torch.float32)
for _, module in quantized_model.named_modules():
if isinstance(module, Attention):
module.to_q = LoRALayer(module.to_q, rank=4)
module.to_k = LoRALayer(module.to_k, rank=4)
module.to_v = LoRALayer(module.to_v, rank=4)
with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
inputs = self.get_dummy_training_inputs(torch_device)
output = quantized_model(**inputs)[0]
output.norm().backward()
for module in quantized_model.modules():
if isinstance(module, LoRALayer):
self.assertTrue(module.adapter[1].weight.grad is not None)
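

# Concrete test classes: each pins a quanto weights dtype and the minimum expected
# unquantized/quantized peak-memory ratio; int4 and int2 require CUDA compute capability >= 8.0.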
class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.6
def get_dummy_init_kwargs(self):
return {"weights_dtype": "float8"}


class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.6
_test_torch_compile = True
def get_dummy_init_kwargs(self):
return {"weights_dtype": "int8"}


@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.55
def get_dummy_init_kwargs(self):
return {"weights_dtype": "int4"}


@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.65
def get_dummy_init_kwargs(self):
return {"weights_dtype": "int2"}