# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"HQQ (Half-Quadratic Quantization) integration file"

from ..utils import is_hqq_available, is_torch_available, logging


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


# Name all modules inside the model
def autoname_modules(model):
    for name, module in model.named_modules():
        module.name = name


# Get the linear_tag from a module name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj
def name_to_linear_tag(name):
    return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))])


# Get all linear tags available in the model
def get_linear_tags(model):
    if is_hqq_available():
        from hqq.core.quantize import HQQLinear

    linear_tags = set()
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, HQQLinear)):
            linear_tags.add(name_to_linear_tag(name))
    return list(linear_tags)


def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None):
    # Recursively tag nn.Linear children with the quant config matching their linear_tag
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, torch.nn.Linear):
            # Get linear tag
            linear_tag = name_to_linear_tag(module.name)

            # We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param()
            if linear_tag in patch_params:
                if patch_params[linear_tag] is not None:
                    model._modules[name].quant_config = patch_params[linear_tag]
                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)

            has_been_replaced = True

        if len(list(module.children())) > 0:
            _, has_been_replaced = _prepare_for_hqq_linear(
                module,
                patch_params=patch_params,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)

    return model, has_been_replaced


def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False):
    """
    Prepares nn.Linear layers for HQQ quantization.
    Since each layer type can have separate quantization parameters, we need to do the following:
        1- tag each module with its name via autoname_modules()
        2- Extract linear_tags (e.g. ['self_attn.q_proj', ...])
        3- Map quantization parameters as a dictionary linear_tag -> quant_params as HQQLinear expects it, this is referred to as patch_params
    """

    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

    # Add name to module
    autoname_modules(model)

    # Get linear tags. This allows us to use different quant params for different layer types
    linear_tags = get_linear_tags(model)

    # Convert quantization_config to layer-wise config
    skip_modules = quantization_config.skip_modules
    quant_config = quantization_config.to_dict()
    linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

    if any(key in linear_tags for key in quant_config.keys()):
        # If the user doesn't specify a key from get_linear_tags, the corresponding layer is not quantized: it is mapped to (key, None)
        patch_params = {key: None for key in linear_tags}
        patch_params.update(quant_config)
    else:
        # Same quant_config for all layers
        patch_params = {k: quant_config for k in linear_tags}

    model, has_been_replaced = _prepare_for_hqq_linear(
        model, patch_params=patch_params, has_been_replaced=has_been_replaced
    )

    # We store the quantization config as linear_tag -> hqq quant config
    model.config.quantization_config = patch_params

    if not has_been_replaced:
        logger.warning("No linear modules were found in your model for quantization.")

    return model
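# ---------------------------------------------------------------------------
# Illustrative usage sketch (a minimal example, not part of this module's API).
# It assumes `hqq` is installed and that `transformers.HqqConfig` is available;
# the toy module layout below is hypothetical and only mimics a decoder naming
# scheme so that name_to_linear_tag() yields tags such as "self_attn.q_proj".
# Guarded so it only runs when the module is executed directly, never on import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import HqqConfig

    class TinyAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = torch.nn.Linear(8, 8)
            self.k_proj = torch.nn.Linear(8, 8)

    class TinyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.self_attn = TinyAttention()

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList([TinyBlock() for _ in range(2)])
            self.config = SimpleNamespace()  # stand-in for a PretrainedConfig

    # Apply the same quant config (assumed 4-bit, group_size=64) to every linear tag
    model = prepare_for_hqq_linear(TinyModel(), quantization_config=HqqConfig(nbits=4, group_size=64))

    # Each eligible nn.Linear now carries .name, .quant_config and .source_cls, and
    # model.config.quantization_config maps linear_tag -> hqq quant config, e.g.
    # {'self_attn.q_proj': {...}, 'self_attn.k_proj': {...}}
    print(sorted(model.config.quantization_config.keys()))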