# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
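"""Tests for pipeline-level quantization via `PipelineQuantizationConfig`, covering both the
single-backend `quant_kwargs` path and the granular per-component `quant_mapping` path."""
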
import tempfile
import unittest

import torch

from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
    is_transformers_available,
    require_accelerate,
    require_bitsandbytes_version_greater,
    require_quanto,
    require_torch,
    require_torch_accelerator,
    slow,
    torch_device,
)


if is_transformers_available():
    from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
else:
    TranBitsAndBytesConfig = None


@require_bitsandbytes_version_greater("0.43.2")
@require_quanto
@require_accelerate
@require_torch
@require_torch_accelerator
@slow
class PipelineQuantizationTests(unittest.TestCase):
    model_name = "hf-internal-testing/tiny-flux-pipe"
    prompt = "a beautiful sunset amidst the mountains."
    num_inference_steps = 10
    seed = 0

    def test_quant_config_set_correctly_through_kwargs(self):
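        """A single `quant_backend` plus `quant_kwargs` should quantize every component in `components_to_quantize`."""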
        components_to_quantize = ["transformer", "text_encoder_2"]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=components_to_quantize,
        )
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for name, component in pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config
                self.assertTrue(quantization_config.load_in_4bit)
                self.assertTrue(quantization_config.quant_method == "bitsandbytes")

        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

    def test_quant_config_set_correctly_through_granular(self):
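        """A granular `quant_mapping` should apply each per-component config class exactly as given."""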
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig(weights_dtype="int8"),
                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
            }
        )
        components_to_quantize = list(quant_config.quant_mapping.keys())
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for name, component in pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config
                if name == "text_encoder_2":
                    self.assertTrue(quantization_config.load_in_4bit)
                    self.assertTrue(quantization_config.quant_method == "bitsandbytes")
                else:
                    self.assertTrue(quantization_config.quant_method == "quanto")

        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

    def test_raises_error_for_invalid_config(self):
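        """Passing both `quant_backend` and `quant_mapping` is ambiguous and must raise a ValueError."""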
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": QuantoConfig(weights_dtype="int8"),
                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
                },
                quant_backend="bitsandbytes_4bit",
            )

        self.assertTrue(
            str(err_context.exception)
            == "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
        )

    def test_validation_for_kwargs(self):
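        """`quant_kwargs` that don't fit the `__init__` signatures of the backend's config classes must be rejected."""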
        components_to_quantize = ["transformer", "text_encoder_2"]
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_backend="quanto",
                quant_kwargs={"weights_dtype": "int8"},
                components_to_quantize=components_to_quantize,
            )

        self.assertTrue(
            "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
        )

    def test_raises_error_for_wrong_config_class(self):
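        """A plain dict is not a valid `quantization_config`; only `PipelineQuantizationConfig` is accepted."""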
        quant_config = {
            "transformer": QuantoConfig(weights_dtype="int8"),
            "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
        }
        with self.assertRaises(ValueError) as err_context:
            _ = DiffusionPipeline.from_pretrained(
                self.model_name,
                quantization_config=quant_config,
                torch_dtype=torch.bfloat16,
            )

        self.assertTrue(
            str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
        )

    def test_validation_for_mapping(self):
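        """Every `quant_mapping` value must be a recognized quantization config; arbitrary objects must raise."""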
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": DiffusionPipeline(),
                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
                }
            )

        self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))

    def test_saving_loading(self):
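        """A quantized pipeline should round-trip through `save_pretrained`/`from_pretrained` with identical outputs."""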
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig(weights_dtype="int8"),
                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
            }
        )
        components_to_quantize = list(quant_config.quant_mapping.keys())
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)
        pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
        output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)

        for name, component in loaded_pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config
                if name == "text_encoder_2":
                    self.assertTrue(quantization_config.load_in_4bit)
                    self.assertTrue(quantization_config.quant_method == "bitsandbytes")
                else:
                    self.assertTrue(quantization_config.quant_method == "quanto")

        output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        self.assertTrue(torch.allclose(output_1, output_2))