# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest
from typing import List

import numpy as np

from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
    TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    is_torch_available,
    is_torchao_available,
    nightly,
    require_torch,
    require_torch_gpu,
    require_torchao_version_greater_or_equal,
    slow,
    torch_device,
)


enable_full_determinism()


if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

        Taken from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)


if is_torchao_available():
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
    from torchao.utils import get_model_size_in_bytes


class TorchAoConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Makes sure the config format is properly set
        """
        quantization_config = TorchAoConfig("int4_weight_only")
        torchao_orig_config = quantization_config.to_dict()

        for key in torchao_orig_config:
            self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])

    def test_post_init_check(self):
        """
        Test kwargs validations in TorchAoConfig
        """
        _ = TorchAoConfig("int4_weight_only")
        with self.assertRaisesRegex(ValueError, "is not supported yet"):
            _ = TorchAoConfig("uint8")

        with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
            _ = TorchAoConfig("int4_weight_only", group_size1=32)

    def test_repr(self):
        """
        Check that there is no error in the repr
        """
        quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
        expected_repr = """TorchAoConfig {
            "modules_to_not_convert": [
                "conv"
            ],
            "quant_method": "torchao",
            "quant_type": "int4_weight_only",
            "quant_type_kwargs": {
                "group_size": 8
            }
        }""".replace(" ", "").replace("\n", "")
        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
        self.assertEqual(quantization_repr, expected_repr)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
class TorchAoTest(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_components(
        self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
    ):
        transformer = FluxTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
        text_encoder_2 = T5EncoderModel.from_pretrained(
            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator().manual_seed(seed)

        inputs = {
            "prompt": "an astronaut riding a horse in space",
            "height": 32,
            "width": 32,
            "num_inference_steps": 2,
            "output_type": "np",
            "generator": generator,
        }

        return inputs
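
    # Builds raw tensor inputs matching the tiny Flux transformer's forward signature
    # (hidden_states, encoder_hidden_states, pooled_projections, txt_ids, img_ids, timestep),
    # for tests that call the transformer directly instead of going through the pipeline.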
    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )

        torch.manual_seed(seed)
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }
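
    # Runs the full pipeline with the given quantization config and compares a 3x3 slice of the
    # generated image against reference values recorded on the CI runners.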
    def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str):
        components = self.get_dummy_components(quantization_config, model_id)
        pipe = FluxPipeline(**components)
        pipe.to(device=torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]
        output_slice = output[-1, -1, -3:, -3:].flatten()

        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

    def test_quantization(self):
        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
            # fmt: off
            QUANTIZATION_TYPES_TO_TEST = [
                ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
                ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
                ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
                ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
                ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
                ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
            ]

            if TorchAoConfig._is_cuda_capability_atleast_8_9():
                QUANTIZATION_TYPES_TO_TEST.extend([
                    ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
                    ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
                    # =====
                    # The following lead to an internal torch error:
                    #     RuntimeError: mat2 shape (32x4 must be divisible by 16
                    # Skip these for now; TODO(aryan): investigate later
                    # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                    # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                    # =====
                    # Cutlass fails to initialize for below
                    # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                    # =====
                    ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
                    ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
                ])
            # fmt: on

            for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
                quant_kwargs = {}
                if quantization_name in ["uint4wo", "uint7wo"]:
                    # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
                    quant_kwargs.update({"group_size": 16})
                quantization_config = TorchAoConfig(
                    quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
                )
                self._test_quant_type(quantization_config, expected_slice, model_id)

    def test_int4wo_quant_bfloat16_conversion(self):
        """
        Tests whether the dtype of the model is modified to bfloat16 for int4 weight-only quantization.
        """
        quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, AffineQuantizedTensor))
        self.assertEqual(weight.quant_min, 0)
        self.assertEqual(weight.quant_max, 15)

    def test_device_map(self):
        """
        Test whether the int4 weight-only quantized model works properly with "auto" and custom device maps.
        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
        correctly set (in the `hf_device_map` attribute of the model).
        """
        # Note: we were not checking whether the weight tensors were AffineQuantizedTensor instances before.
        # If we had, it would have errored out. Now, we do. So, device_map effectively never worked, with or
        # without sharded checkpoints. This will need to be supported in the future (TODO(aryan))
        custom_device_map_dict = {
            "time_text_embed": torch_device,
            "context_embedder": torch_device,
            "x_embedder": torch_device,
            "transformer_blocks.0": "cpu",
            "single_transformer_blocks.0": "disk",
            "norm_out": torch_device,
            "proj_out": "cpu",
        }
        device_maps = ["auto", custom_device_map_dict]

        # inputs = self.get_dummy_tensor_inputs(torch_device)
        # expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])

        for device_map in device_maps:
            # device_map_to_compare = {"": 0} if device_map == "auto" else device_map

            # Test non-sharded model - should work
            with self.assertRaises(NotImplementedError):
                with tempfile.TemporaryDirectory() as offload_folder:
                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
                    _ = FluxTransformer2DModel.from_pretrained(
                        "hf-internal-testing/tiny-flux-pipe",
                        subfolder="transformer",
                        quantization_config=quantization_config,
                        device_map=device_map,
                        torch_dtype=torch.bfloat16,
                        offload_folder=offload_folder,
                    )

                    # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
                    # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
                    # self.assertTrue(isinstance(weight, AffineQuantizedTensor))

                    # output = quantized_model(**inputs)[0]
                    # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
                    # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

            # Test sharded model - should not work
            with self.assertRaises(NotImplementedError):
                with tempfile.TemporaryDirectory() as offload_folder:
                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
                    _ = FluxTransformer2DModel.from_pretrained(
                        "hf-internal-testing/tiny-flux-sharded",
                        subfolder="transformer",
                        quantization_config=quantization_config,
                        device_map=device_map,
                        torch_dtype=torch.bfloat16,
                        offload_folder=offload_folder,
                    )

                    # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
                    # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
                    # self.assertTrue(isinstance(weight, AffineQuantizedTensor))

                    # output = quantized_model(**inputs)[0]
                    # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
                    # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
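
    # Layers listed in `modules_to_not_convert` should remain regular bf16 nn.Linear layers,
    # and skipping them should make the model larger than a fully quantized one.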
    def test_modules_to_not_convert(self):
        quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
        quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )

        unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
        self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
        self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
        self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)

        quantized_layer = quantized_model_with_not_convert.proj_out
        self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))

        quantization_config = TorchAoConfig("int8_weight_only")
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )

        size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
        size_quantized = get_model_size_in_bytes(quantized_model)

        self.assertTrue(size_quantized < size_quantized_with_not_convert)
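
    # Freezes the quantized weights, wraps the attention projections in trainable LoRA adapters,
    # and checks that gradients flow into the adapter weights after a backward pass.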
    def test_training(self):
        quantization_config = TorchAoConfig("int8_weight_only")
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # freeze the model as only adapter layers will be trained
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if isinstance(module, Attention):
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_tensor_inputs(torch_device)
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)

    def test_torch_compile(self):
        r"""Test that verifies if torch.compile works with torchao quantization."""
        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
            quantization_config = TorchAoConfig("int8_weight_only")
            components = self.get_dummy_components(quantization_config, model_id=model_id)
            pipe = FluxPipeline(**components)
            pipe.to(device=torch_device)

            inputs = self.get_dummy_inputs(torch_device)
            normal_output = pipe(**inputs)[0].flatten()[-32:]

            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
            inputs = self.get_dummy_inputs(torch_device)
            compile_output = pipe(**inputs)[0].flatten()[-32:]

            # Note: Seems to require higher tolerance
            self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))

    def test_memory_footprint(self):
        r"""
        A simple test to check whether the model conversion has been done correctly by checking the
        memory footprint of the converted model and the class type of its linear layers.
        """
        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
            transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
            transformer_int4wo_gs32 = self.get_dummy_components(
                TorchAoConfig("int4wo", group_size=32), model_id=model_id
            )["transformer"]
            transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
            transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]

            # int4wo will not quantize all the layers by default because the model weight shapes are not divisible by group_size=64
            for block in transformer_int4wo.transformer_blocks:
                self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
                self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))

            # int4wo with group_size=32 will quantize all the linear layers except x_embedder
            for name, module in transformer_int4wo_gs32.named_modules():
                if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
                    self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))

            # int8wo will quantize all the linear layers
            for module in transformer_int8wo.modules():
                if isinstance(module, nn.Linear):
                    self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))

            total_int4wo = get_model_size_in_bytes(transformer_int4wo)
            total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
            total_int8wo = get_model_size_in_bytes(transformer_int8wo)
            total_bf16 = get_model_size_in_bytes(transformer_bf16)

            # TODO: refactor to align with other quantization tests
            # The latter has a smaller group size, so more groups -> more scales and zero points
            self.assertTrue(total_int4wo < total_int4wo_gs32)
            # int8 quantizes more layers compared to int4 with the default group size
            self.assertTrue(total_int8wo < total_int4wo)
            # int4wo does not quantize many layers with the default group size, but the layers it does quantize
            # carry the additional overhead of scales and zero points
            self.assertTrue(total_bf16 < total_int4wo)

    def test_wrong_config(self):
        with self.assertRaises(ValueError):
            self.get_dummy_components(TorchAoConfig("int42"))

    def test_sequential_cpu_offload(self):
        r"""
        A test that checks if inference runs as expected when sequential cpu offloading is enabled.
        """
        quantization_config = TorchAoConfig("int8wo")
        components = self.get_dummy_components(quantization_config)
        pipe = FluxPipeline(**components)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_dummy_inputs(torch_device)
        _ = pipe(**inputs)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
class TorchAoSerializationTest(unittest.TestCase):
    model_name = "hf-internal-testing/tiny-flux-pipe"

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        return quantized_model.to(device)

    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }
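
    # The first helper checks the output of a freshly quantized model against a reference slice;
    # the second saves and reloads the same model and verifies that the reloaded weights are still
    # torchao tensor subclasses and that the outputs match the same reference slice.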
    def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
        inputs = self.get_dummy_tensor_inputs(torch_device)
        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

    def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)

        with tempfile.TemporaryDirectory() as tmp_dir:
            quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
            loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
            ).to(device=torch_device)

        inputs = self.get_dummy_tensor_inputs(torch_device)
        output = loaded_quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

        self.assertTrue(
            isinstance(
                loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
            )
        )
        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
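
    # The tests below cover int8 dynamic-activation/int8-weight and int8 weight-only quantization,
    # with the model placed on CUDA or on CPU before being serialized and reloaded.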
    def test_int_a8w8_cuda(self):
        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
        device = "cuda"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

    def test_int_a16w8_cuda(self):
        quant_method, quant_method_kwargs = "int8_weight_only", {}
        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
        device = "cuda"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

    def test_int_a8w8_cpu(self):
        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
        device = "cpu"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

    def test_int_a16w8_cpu(self):
        quant_method, quant_method_kwargs = "int8_weight_only", {}
        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
        device = "cpu"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
class SlowTorchAoTests(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_components(self, quantization_config: TorchAoConfig):
        # This is just for convenience, so that we can modify it in one place for custom environments and local testing
        cache_dir = None
        model_id = "black-forest-labs/FLUX.1-dev"
        transformer = FluxTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
        )
        text_encoder_2 = T5EncoderModel.from_pretrained(
            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator().manual_seed(seed)

        inputs = {
            "prompt": "an astronaut riding a horse in space",
            "height": 512,
            "width": 512,
            "num_inference_steps": 20,
            "output_type": "np",
            "generator": generator,
        }

        return inputs

    def _test_quant_type(self, quantization_config, expected_slice):
        components = self.get_dummy_components(quantization_config)
        pipe = FluxPipeline(**components)
        pipe.enable_model_cpu_offload()

        weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0].flatten()
        output_slice = np.concatenate((output[:16], output[-16:]))

        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

    def test_quantization(self):
        # fmt: off
        QUANTIZATION_TYPES_TO_TEST = [
            ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
            ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
        ]

        if TorchAoConfig._is_cuda_capability_atleast_8_9():
            QUANTIZATION_TYPES_TO_TEST.extend([
                ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
            ])
        # fmt: on

        for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
            quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
            self._test_quant_type(quantization_config, expected_slice)
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
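
    # Round-trips an int8 weight-only quantized transformer through save_pretrained/from_pretrained
    # (non-safetensors serialization) and checks that outputs before and after reloading stay close.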
    def test_serialization_int8wo(self):
        quantization_config = TorchAoConfig("int8wo")
        components = self.get_dummy_components(quantization_config)
        pipe = FluxPipeline(**components)
        pipe.enable_model_cpu_offload()

        weight = pipe.transformer.x_embedder.weight
        self.assertTrue(isinstance(weight, AffineQuantizedTensor))

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0].flatten()[:128]

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
            pipe.remove_all_hooks()
            del pipe.transformer
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            transformer = FluxTransformer2DModel.from_pretrained(
                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
            )
            pipe.transformer = transformer
            pipe.enable_model_cpu_offload()

        weight = transformer.x_embedder.weight
        self.assertTrue(isinstance(weight, AffineQuantizedTensor))

        loaded_output = pipe(**inputs)[0].flatten()[:128]
        # Seems to require a higher tolerance depending on which machine it is run on.
        # A difference of 0.06 in normalized pixel space (-1 to 1) corresponds to a difference of
        # 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
        # on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
        self.assertTrue(np.allclose(output, loaded_output, atol=0.06))

    def test_memory_footprint_int4wo(self):
        # The original checkpoints are in bf16 and about 24 GB
        expected_memory_in_gb = 6.0
        quantization_config = TorchAoConfig("int4wo")
        cache_dir = None
        transformer = FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
        self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)

    def test_memory_footprint_int8wo(self):
        # The original checkpoints are in bf16 and about 24 GB
        expected_memory_in_gb = 12.0
        quantization_config = TorchAoConfig("int8wo")
        cache_dir = None
        transformer = FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
        self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
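

# Tests loading a transformer that was quantized with torchao and serialized ahead of time
# (hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer), rather than quantizing at load time.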
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator().manual_seed(seed)

        inputs = {
            "prompt": "an astronaut riding a horse in space",
            "height": 512,
            "width": 512,
            "num_inference_steps": 20,
            "output_type": "np",
            "generator": generator,
        }

        return inputs

    def test_transformer_int8wo(self):
        # fmt: off
        expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703])
        # fmt: on

        # This is just for convenience, so that we can modify it in one place for custom environments and local testing
        cache_dir = None
        transformer = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
            cache_dir=cache_dir,
        )
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
        )
        pipe.enable_model_cpu_offload()

        # Verify that all linear layer weights are quantized
        for name, module in pipe.transformer.named_modules():
            if isinstance(module, nn.Linear):
                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))

        # Verify outputs match expected slice
        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0].flatten()
        output_slice = np.concatenate((output[:16], output[-16:]))
        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))