# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import warnings

import torch


# Add or edit model card to have `library_name: peft`
def add_library_to_model_card(output_dir):
    if os.path.exists(os.path.join(output_dir, "README.md")):
        with open(os.path.join(output_dir, "README.md"), "r") as f:
            lines = f.readlines()
        # check if the first line is `---`
        if len(lines) > 0 and lines[0].startswith("---"):
            for i, line in enumerate(lines[1:]):
                # check if line starts with `library_name`, if yes, update it
                if line.startswith("library_name"):
                    lines[i + 1] = "library_name: peft\n"
                    break
                elif line.startswith("---"):
                    # insert `library_name: peft` before the last `---`
                    lines.insert(i + 1, "library_name: peft\n")
                    break
        else:
            lines = ["---\n", "library_name: peft\n", "---\n"] + lines
    else:
        lines = ["---\n", "library_name: peft\n", "---\n"]
    # write the lines back to README.md
    with open(os.path.join(output_dir, "README.md"), "w") as f:
        f.writelines(lines)


# needed for prefix-tuning of bloom model
def bloom_model_postprocess_past_key_value(past_key_values):
    past_key_values = torch.cat(past_key_values)
    total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
    keys = past_key_values[: total_layers // 2]
    keys = keys.transpose(2, 3).reshape(
        total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
    )
    values = past_key_values[total_layers // 2 :]
    values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)

    return tuple(zip(keys, values))


def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    r"""
    This method wraps the entire protocol for preparing a model before running training. This includes:
        1- casting the layernorm in fp32
        2- making the output embedding layer require grads
        3- upcasting the lm head to fp32

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
    """
    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)

    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

    # cast all non INT8 parameters to fp32
    for param in model.parameters():
        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
            param.data = param.data.to(torch.float32)

    if loaded_in_kbit and use_gradient_checkpointing:
        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()

    return model


# For backward compatibility
def prepare_model_for_int8_training(*args, **kwargs):
    warnings.warn(
        "prepare_model_for_int8_training is deprecated and will be removed in a future version. "
        "Use prepare_model_for_kbit_training instead.",
        FutureWarning,
    )
    return prepare_model_for_kbit_training(*args, **kwargs)
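

# Illustrative sketch (not part of the public API): a typical way to combine
# `prepare_model_for_kbit_training` with a quantized `transformers` model before
# attaching PEFT adapters. It assumes `transformers` and `bitsandbytes` are installed;
# the checkpoint name "facebook/opt-350m" is only an example. The function is defined
# here for illustration and is never called from this module.
def _example_prepare_kbit_model():
    from transformers import AutoModelForCausalLM

    # load the base model in 8-bit precision (requires bitsandbytes)
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_8bit=True, device_map="auto")
    # freeze the base weights, upcast non-quantized params to fp32 and enable
    # gradient checkpointing before adding trainable adapters
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    return model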


# copied from transformers.models.bart.modeling_bart
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
        pad_token_id (`int`): The id of the `padding` token.
        decoder_start_token_id (`int`): The id of the `start` token.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


class ModulesToSaveWrapper(torch.nn.Module):
    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict({})
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))

    def forward(self, *args, **kwargs):
        if self.active_adapter not in self.modules_to_save:
            return self.original_module(*args, **kwargs)
        return self.modules_to_save[self.active_adapter](*args, **kwargs)


def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name


def _freeze_adapter(model, adapter_name):
    for n, p in model.named_parameters():
        if adapter_name in n:
            p.requires_grad = False


def _set_trainable(model, adapter_name):
    key_list = [key for key, _ in model.named_modules()]
    for key in key_list:
        target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
        if target_module_found:
            parent, target, target_name = _get_submodules(model, key)
            if isinstance(target, ModulesToSaveWrapper):
                target.update(adapter_name)
            else:
                for param in target.parameters():
                    param.requires_grad = True
                setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))


def _set_adapter(model, adapter_name):
    for module in model.modules():
        if isinstance(module, ModulesToSaveWrapper):
            module.active_adapter = adapter_name


def fsdp_auto_wrap_policy(model):
    import functools
    import os

    from accelerate import FullyShardedDataParallelPlugin
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            FullyShardedDataParallelPlugin.get_module_class_from_name(
                model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
            ),
        ),
    )

    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy


def transpose(weight, fan_in_fan_out):
    return weight.T if fan_in_fan_out else weight
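

# Illustrative sketch (not part of the public API): `fsdp_auto_wrap_policy` is intended
# to be handed to accelerate's FSDP plugin so that prompt-tuning modules and the
# transformer blocks are wrapped as separate FSDP units. The accelerator/plugin wiring
# below is an assumption about the caller's setup, not something this module does
# itself; the function is defined for illustration and is never called here.
def _example_fsdp_auto_wrap(model):
    from accelerate import Accelerator

    accelerator = Accelerator()
    # only relevant when accelerate was configured with FSDP enabled
    if getattr(accelerator.state, "fsdp_plugin", None) is not None:
        accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
    return accelerator.prepare(model)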


TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "blip-2": ["q", "v", "q_proj", "v_proj"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
    "gpt_neo": ["q_proj", "v_proj"],
    "bert": ["query", "value"],
    "roberta": ["query", "value"],
    "xlm-roberta": ["query", "value"],
    "electra": ["query", "value"],
    "deberta-v2": ["query_proj", "value_proj"],
    "deberta": ["in_proj"],
    "layoutlm": ["query", "value"],
    "llama": ["q_proj", "v_proj"],
    "chatglm": ["query_key_value"],
    "gpt_bigcode": ["c_attn"],
    "mpt": ["Wqkv"],
}

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]

TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "k", "v", "o", "wi", "wo"],
    "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
    "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    # "gpt2": ["c_attn"],
    # "bloom": ["query_key_value"],
    "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    # "gptj": ["q_proj", "v_proj"],
    # "gpt_neox": ["query_key_value"],
    # "gpt_neo": ["q_proj", "v_proj"],
    # "bert": ["query", "value"],
    "roberta": ["query", "key", "value", "dense"],
    # "xlm-roberta": ["query", "value"],
    # "electra": ["query", "value"],
    "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
    # "deberta": ["in_proj"],
    # "layoutlm": ["query", "value"],
}

TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
    "bloom": bloom_model_postprocess_past_key_value,
}

WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
CONFIG_NAME = "adapter_config.json"
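

# Illustrative sketch (not part of the public API): how the default LoRA target modules
# can be looked up for a given `transformers` model when the user does not specify
# `target_modules` explicitly. The `model_type` attribute comes from the model config;
# the helper name is hypothetical and is never called from this module.
def _example_default_lora_target_modules(model):
    model_type = getattr(model.config, "model_type", None)
    if model_type not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
        raise ValueError(f"No default LoRA target modules known for model type {model_type!r}.")
    return TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type]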