# coding=utf-8 # Copyright 2023-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import inspect import os import warnings from contextlib import contextmanager from copy import deepcopy from typing import Any, Dict, Optional, Union import torch from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PreTrainedModel from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput from transformers.utils import PushToHubMixin from . import __version__ from .tuners import ( AdaLoraModel, AdaptionPromptModel, LoraModel, PrefixEncoder, PromptEmbedding, PromptEncoder, MoeLoraModel, ) from .utils import ( SAFETENSORS_WEIGHTS_NAME, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, WEIGHTS_NAME, PeftConfig, PeftType, PromptLearningConfig, TaskType, _set_adapter, _set_trainable, add_library_to_model_card, get_peft_model_state_dict, hub_file_exists, set_peft_model_state_dict, shift_tokens_right, ) PEFT_TYPE_TO_MODEL_MAPPING = { PeftType.LORA: LoraModel, PeftType.PROMPT_TUNING: PromptEmbedding, PeftType.P_TUNING: PromptEncoder, PeftType.PREFIX_TUNING: PrefixEncoder, PeftType.ADALORA: AdaLoraModel, PeftType.ADAPTION_PROMPT: AdaptionPromptModel, PeftType.MOELORA: MoeLoraModel, } class PeftModel(PushToHubMixin, torch.nn.Module): """ Base model encompassing various Peft methods. Args: model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft. peft_config ([`PeftConfig`]): The configuration of the Peft model. **Attributes**: - **base_model** ([`~transformers.PreTrainedModel`]) -- The base transformer model used for Peft. - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model. - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when saving the model. - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if using [`PromptLearningConfig`]. - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if using [`PromptLearningConfig`]. - **transformer_backbone_name** (`str`) -- The name of the transformer backbone in the base model if using [`PromptLearningConfig`]. - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone in the base model if using [`PromptLearningConfig`]. 
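
    Example (a minimal, illustrative sketch; `"gpt2"` and the `"./my_adapter"` directory below are placeholders, and
    the adapter is assumed to have been produced earlier with `save_pretrained`):

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import PeftModel

        >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> peft_model = PeftModel.from_pretrained(base_model, "./my_adapter")  # hypothetical adapter directory
        >>> peft_model.print_trainable_parameters()
        ```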
""" def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default"): super().__init__() self.base_model = model self.config = self.base_model.config self.modules_to_save = None self.peft_config = {} self.active_adapter = adapter_name self.peft_type = peft_config.peft_type self.base_model_torch_dtype = getattr(model, "dtype", None) if not isinstance(peft_config, PromptLearningConfig): self.peft_config[adapter_name] = peft_config self.base_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]( self.base_model, self.peft_config, adapter_name ) self.set_additional_trainable_modules(peft_config, adapter_name) else: self.add_adapter(adapter_name, peft_config) if getattr(model, "is_gradient_checkpointing", True): model = self._prepare_model_for_gradient_checkpointing(model) def save_pretrained(self, save_directory: str, safe_serialization: bool = False, **kwargs: Any): r""" This function saves the adapter model and the adapter configuration files to a directory, so that it can be reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`] method. Args: save_directory (`str`): Directory where the adapter model and configuration files will be saved (will be created if it does not exist). kwargs (additional keyword arguments, *optional*): Additional keyword arguments passed along to the `push_to_hub` method. """ if os.path.isfile(save_directory): raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") os.makedirs(save_directory, exist_ok=True) self.create_or_update_model_card(save_directory) for adapter_name, peft_config in self.peft_config.items(): # save only the trainable weights output_state_dict = get_peft_model_state_dict( self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name ) output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory os.makedirs(output_dir, exist_ok=True) if safe_serialization: safe_save_file( output_state_dict, os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), metadata={"format": "pt"} ) else: torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) # save the config and change the inference mode to `True` if peft_config.base_model_name_or_path is None: peft_config.base_model_name_or_path = ( self.base_model.__dict__.get("name_or_path", None) if isinstance(peft_config, PromptLearningConfig) else self.base_model.model.__dict__.get("name_or_path", None) ) inference_mode = peft_config.inference_mode peft_config.inference_mode = True peft_config.save_pretrained(output_dir) peft_config.inference_mode = inference_mode @classmethod def from_pretrained( cls, model: PreTrainedModel, model_id: Union[str, os.PathLike], adapter_name: str = "default", is_trainable: bool = False, config: Optional[PeftConfig] = None, **kwargs: Any, ): r""" Instantiate a [`LoraModel`] from a pretrained Lora configuration and weights. Args: model ([`~transformers.PreTrainedModel`]): The model to be adapted. The model should be initialized with the [`~transformers.PreTrainedModel.from_pretrained`] method from the 🤗 Transformers library. model_id (`str` or `os.PathLike`): The name of the Lora configuration to use. Can be either: - A string, the `model id` of a Lora configuration hosted inside a model repo on the Hugging Face Hub. - A path to a directory containing a Lora configuration file saved using the `save_pretrained` method (`./my_lora_config_directory/`). 
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to be loaded. This is useful for loading multiple adapters.
            is_trainable (`bool`, *optional*, defaults to `False`):
                Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and used for
                inference.
            config ([`~peft.PeftConfig`], *optional*):
                The configuration object to use instead of an automatically loaded configuration. This configuration
                object is mutually exclusive with `model_id` and `kwargs`. This is useful when the configuration is
                already loaded before calling `from_pretrained`.
            kwargs: (`optional`):
                Additional keyword arguments passed along to the specific Lora configuration class.
        """
        from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING

        # load the config
        if config is None:
            config = PEFT_TYPE_TO_CONFIG_MAPPING[
                PeftConfig._get_peft_type(
                    model_id,
                    subfolder=kwargs.get("subfolder", None),
                    revision=kwargs.get("revision", None),
                    cache_dir=kwargs.get("cache_dir", None),
                )
            ].from_pretrained(model_id, subfolder=kwargs.get("subfolder", None), **kwargs)
        elif isinstance(config, PeftConfig):
            config.inference_mode = not is_trainable
        else:
            raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")

        if (getattr(model, "hf_device_map", None) is not None) and len(
            set(model.hf_device_map.values()).intersection({"cpu", "disk"})
        ) > 0:
            remove_hook_from_submodules(model)

        if isinstance(config, PromptLearningConfig) and is_trainable:
            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
        else:
            config.inference_mode = not is_trainable

        if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
            model = cls(model, config, adapter_name)
        else:
            model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name)
        model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
        return model

    def _setup_prompt_encoder(self, adapter_name: str):
        config = self.peft_config[adapter_name]
        self.prompt_encoder = torch.nn.ModuleDict({})
        self.prompt_tokens = {}
        transformer_backbone = None
        for name, module in self.base_model.named_children():
            for param in module.parameters():
                param.requires_grad = False
            if isinstance(module, PreTrainedModel):
                # Make sure to freeze the Transformers model
                if transformer_backbone is None:
                    transformer_backbone = module
                    self.transformer_backbone_name = name

        if config.num_transformer_submodules is None:
            config.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1

        for named_param, value in list(transformer_backbone.named_parameters()):
            if value.shape[0] == self.base_model.config.vocab_size:
                self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
                break

        if config.peft_type == PeftType.PROMPT_TUNING:
            prompt_encoder = PromptEmbedding(config, self.word_embeddings)
        elif config.peft_type == PeftType.P_TUNING:
            prompt_encoder = PromptEncoder(config)
        elif config.peft_type == PeftType.PREFIX_TUNING:
            prompt_encoder = PrefixEncoder(config)
        else:
            raise ValueError("Not supported")
        self.prompt_encoder.update(torch.nn.ModuleDict({adapter_name: prompt_encoder}))
        self.prompt_tokens[adapter_name] = torch.arange(
            config.num_virtual_tokens * config.num_transformer_submodules
        ).long()

    def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel):
        r"""
        Prepares the model for gradient checkpointing if necessary
        """
        if not (getattr(model, "is_loaded_in_8bit", False) or getattr(model,
"is_loaded_in_4bit", False)): if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) return model def get_prompt_embedding_to_save(self, adapter_name: str): """ Returns the prompt embedding to save when saving the model. Only applicable when `peft_config.peft_type != PeftType.LORA`. """ prompt_encoder = self.prompt_encoder[adapter_name] prompt_tokens = ( self.prompt_tokens[adapter_name].unsqueeze(0).expand(1, -1).to(prompt_encoder.embedding.weight.device) ) if self.peft_config[adapter_name].peft_type == PeftType.PREFIX_TUNING: prompt_tokens = prompt_tokens[:, : self.peft_config[adapter_name].num_virtual_tokens] prompt_embeddings = prompt_encoder(prompt_tokens) return prompt_embeddings[0].detach().cpu() def get_prompt(self, batch_size: int): """ Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`. """ peft_config = self.active_peft_config prompt_encoder = self.prompt_encoder[self.active_adapter] prompt_tokens = ( self.prompt_tokens[self.active_adapter] .unsqueeze(0) .expand(batch_size, -1) .to(prompt_encoder.embedding.weight.device) ) if peft_config.peft_type == PeftType.PREFIX_TUNING: prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens] if peft_config.inference_mode: past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1) else: past_key_values = prompt_encoder(prompt_tokens) if self.base_model_torch_dtype is not None: past_key_values = past_key_values.to(self.base_model_torch_dtype) past_key_values = past_key_values.view( batch_size, peft_config.num_virtual_tokens, peft_config.num_layers * 2, peft_config.num_attention_heads, peft_config.token_dim // peft_config.num_attention_heads, ) if peft_config.num_transformer_submodules == 2: past_key_values = torch.cat([past_key_values, past_key_values], dim=2) past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split( peft_config.num_transformer_submodules * 2 ) if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None: post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type] past_key_values = post_process_fn(past_key_values) return past_key_values else: if peft_config.inference_mode: prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1) else: prompts = prompt_encoder(prompt_tokens) return prompts def print_trainable_parameters(self): """ Prints the number of trainable parameters in the model. """ trainable_params = 0 all_param = 0 for _, param in self.named_parameters(): num_params = param.numel() # if using DS Zero 3 and the weights are initialized empty if num_params == 0 and hasattr(param, "ds_numel"): num_params = param.ds_numel all_param += num_params if param.requires_grad: trainable_params += num_params print( f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}" ) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" try: return super().__getattr__(name) # defer to nn.Module's logic except AttributeError: return getattr(self.base_model, name) def forward(self, *args: Any, **kwargs: Any): """ Forward pass of the model. """ return self.get_base_model()(*args, **kwargs) @contextmanager def disable_adapter(self): """ Disables the adapter module. 
""" try: if isinstance(self.peft_config[self.active_adapter], PromptLearningConfig): old_forward = self.forward self.forward = self.base_model.forward else: self.base_model.disable_adapter_layers() yield finally: if isinstance(self.peft_config[self.active_adapter], PromptLearningConfig): self.forward = old_forward else: self.base_model.enable_adapter_layers() def get_base_model(self): """ Returns the base model. """ return self.base_model if isinstance(self.active_peft_config, PromptLearningConfig) else self.base_model.model def add_adapter(self, adapter_name: str, peft_config: PeftConfig): if peft_config.peft_type != self.peft_type: raise ValueError( f"Cannot combine adapters with different peft types. " f"Found {self.peft_type} and {peft_config.peft_type}." ) self.peft_config[adapter_name] = peft_config if isinstance(peft_config, PromptLearningConfig): self._setup_prompt_encoder(adapter_name) else: self.base_model.add_adapter(adapter_name, peft_config) self.set_additional_trainable_modules(peft_config, adapter_name) def set_additional_trainable_modules(self, peft_config, adapter_name): if getattr(peft_config, "modules_to_save", None) is not None: if self.modules_to_save is None: self.modules_to_save = set(peft_config.modules_to_save) else: self.modules_to_save.update(peft_config.modules_to_save) _set_trainable(self, adapter_name) @classmethod def _split_kwargs(cls, kwargs: Dict[str, Any]): hf_hub_download_kwargs = {} other_kwargs = {} for key, value in kwargs.items(): if key in inspect.signature(hf_hub_download).parameters: hf_hub_download_kwargs[key] = value else: other_kwargs[key] = value return hf_hub_download_kwargs, other_kwargs def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any): from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs) if adapter_name not in self.peft_config: # load the config peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[ PeftConfig._get_peft_type( model_id, subfolder=kwargs.get("subfolder", None), revision=kwargs.get("revision", None), cache_dir=kwargs.get("cache_dir", None), ) ].from_pretrained( model_id, subfolder=kwargs.get("subfolder", None), revision=kwargs.get("revision", None), cache_dir=kwargs.get("cache_dir", None), ) if isinstance(peft_config, PromptLearningConfig) and is_trainable: raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") else: peft_config.inference_mode = not is_trainable self.add_adapter(adapter_name, peft_config) # load weights if any path = os.path.join(model_id, kwargs["subfolder"]) if kwargs.get("subfolder", None) is not None else model_id if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) use_safetensors = True elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): filename = os.path.join(path, WEIGHTS_NAME) use_safetensors = False else: has_remote_safetensors_file = hub_file_exists( model_id, SAFETENSORS_WEIGHTS_NAME, revision=kwargs.get("revision", None) ) use_safetensors = has_remote_safetensors_file if has_remote_safetensors_file: # Priority 1: load safetensors weights filename = hf_hub_download( model_id, SAFETENSORS_WEIGHTS_NAME, subfolder=kwargs.get("subfolder", None), **hf_hub_download_kwargs, ) else: try: filename = hf_hub_download( model_id, WEIGHTS_NAME, subfolder=kwargs.get("subfolder", None), **hf_hub_download_kwargs ) except EntryNotFoundError: raise ValueError( f"Can't find weights for {model_id} in 
{model_id} or in the Hugging Face Hub. " f"Please check that the file {WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}." ) if use_safetensors: adapters_weights = safe_load_file(filename, device="cuda" if torch.cuda.is_available() else "cpu") else: adapters_weights = torch.load( filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu") ) # load the weights into the model load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name) if ( (getattr(self, "hf_device_map", None) is not None) and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0) and len(self.peft_config) == 1 ): device_map = kwargs.get("device_map", "auto") max_memory = kwargs.get("max_memory", None) offload_dir = kwargs.get("offload_folder", None) offload_index = kwargs.get("offload_index", None) dispatch_model_kwargs = {} # Safety checker for previous `accelerate` versions # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/ if "offload_index" in inspect.signature(dispatch_model).parameters: dispatch_model_kwargs["offload_index"] = offload_index no_split_module_classes = self._no_split_modules if device_map != "sequential": max_memory = get_balanced_memory( self, max_memory=max_memory, no_split_module_classes=no_split_module_classes, low_zero=(device_map == "balanced_low_0"), ) if isinstance(device_map, str): device_map = infer_auto_device_map( self, max_memory=max_memory, no_split_module_classes=no_split_module_classes ) dispatch_model( self, device_map=device_map, offload_dir=offload_dir, **dispatch_model_kwargs, ) hook = AlignDevicesHook(io_same_device=True) if isinstance(self.peft_config[adapter_name], PromptLearningConfig): remove_hook_from_submodules(self.prompt_encoder) add_hook_to_module(self.get_base_model(), hook) # Set model in evaluation mode to deactivate Dropout modules by default if not is_trainable: self.eval() return load_result def set_adapter(self, adapter_name: str): """ Sets the active adapter. """ if adapter_name not in self.peft_config: raise ValueError(f"Adapter {adapter_name} not found.") self.active_adapter = adapter_name if not isinstance(self.peft_config[adapter_name], PromptLearningConfig): self.base_model.set_adapter(adapter_name) _set_adapter(self, adapter_name) @property def active_peft_config(self): return self.peft_config[self.active_adapter] def create_or_update_model_card(self, output_dir: str): """ Updates or create model card to include information about peft: 1. Adds `peft` library tag 2. Adds peft version 3. 
Adds quantization information if it was used """ # Adds `peft` library tag add_library_to_model_card(output_dir) with open(os.path.join(output_dir, "README.md"), "r") as f: lines = f.readlines() quantization_config = None if hasattr(self.config, "quantization_config"): quantization_config = self.config.quantization_config.to_dict() training_config_text = "" # Adds quantization information if it was used if quantization_config is not None: training_config_text += "\nThe following `bitsandbytes` quantization config was used during training:\n" training_config_text += "\n".join([f"- {name}: {value}" for name, value in quantization_config.items()]) training_config_text += "\n" training_procedure_heading = "## Training procedure\n" if training_procedure_heading in lines: lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) else: lines.append(f"{training_procedure_heading}\n{training_config_text}") # Adds peft version framework_block_heading = "### Framework versions\n" if framework_block_heading in lines: lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}\n") else: lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}\n") # write the lines back to README.md with open(os.path.join(output_dir, "README.md"), "w") as f: f.writelines(lines) class PeftModelForSequenceClassification(PeftModel): """ Peft model for sequence classification tasks. Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. **Attributes**: - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. - **cls_layer_name** (`str`) -- The name of the classification layer. Example: ```py >>> from transformers import AutoModelForSequenceClassification >>> from peft import PeftModelForSequenceClassification, get_peft_config >>> config = { ... "peft_type": "PREFIX_TUNING", ... "task_type": "SEQ_CLS", ... "inference_mode": False, ... "num_virtual_tokens": 20, ... "token_dim": 768, ... "num_transformer_submodules": 1, ... "num_attention_heads": 12, ... "num_layers": 12, ... "encoder_hidden_size": 768, ... "prefix_projection": False, ... "postprocess_past_key_value_function": None, ... 
} >>> peft_config = get_peft_config(config) >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased") >>> peft_model = PeftModelForSequenceClassification(model, peft_config) >>> peft_model.print_trainable_parameters() trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117 ``` """ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"): super().__init__(model, peft_config, adapter_name) if self.modules_to_save is None: self.modules_to_save = {"classifier", "score"} else: self.modules_to_save.update({"classifier", "score"}) for name, _ in self.base_model.named_children(): if any(module_name in name for module_name in self.modules_to_save): self.cls_layer_name = name break # to make sure classifier layer is trainable _set_trainable(self, adapter_name) def forward( self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict peft_config = self.active_peft_config if not isinstance(peft_config, PromptLearningConfig): return self.base_model( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) batch_size = input_ids.shape[0] if attention_mask is not None: # concat prompt attention mask prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) if kwargs.get("position_ids", None) is not None: warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") kwargs["position_ids"] = None kwargs.update( { "attention_mask": attention_mask, "labels": labels, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, } ) if peft_config.peft_type == PeftType.PREFIX_TUNING: return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) else: if kwargs.get("token_type_ids", None) is not None: kwargs["token_type_ids"] = torch.cat( ( torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), kwargs["token_type_ids"], ), dim=1, ).long() if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) prompts = self.get_prompt(batch_size=batch_size) prompts = prompts.to(inputs_embeds.dtype) inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) def _prefix_tuning_forward( self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): batch_size = input_ids.shape[0] past_key_values = self.get_prompt(batch_size) fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) kwargs.update( { "input_ids": input_ids, "attention_mask": attention_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, "past_key_values": past_key_values, } ) if "past_key_values" in fwd_params: return self.base_model(labels=labels, **kwargs) else: transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) if "past_key_values" not in fwd_params: raise ValueError("Model does not support past key values which are required for prefix tuning.") outputs = transformer_backbone_name(**kwargs) pooled_output = outputs[1] if len(outputs) > 1 else outputs[0] if "dropout" in [name for name, _ in list(self.base_model.named_children())]: pooled_output = self.base_model.dropout(pooled_output) logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output) loss = None if labels is not None: if self.config.problem_type is None: if self.base_model.num_labels == 1: self.config.problem_type = "regression" elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.base_model.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class PeftModelForCausalLM(PeftModel): """ Peft model for causal language modeling. Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. 
    Example:

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import PeftModelForCausalLM, get_peft_config

        >>> config = {
        ...     "peft_type": "PREFIX_TUNING",
        ...     "task_type": "CAUSAL_LM",
        ...     "inference_mode": False,
        ...     "num_virtual_tokens": 20,
        ...     "token_dim": 1280,
        ...     "num_transformer_submodules": 1,
        ...     "num_attention_heads": 20,
        ...     "num_layers": 36,
        ...     "encoder_hidden_size": 1280,
        ...     "prefix_projection": False,
        ...     "postprocess_past_key_value_function": None,
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
        >>> peft_model = PeftModelForCausalLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
        ```
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        super().__init__(model, peft_config, adapter_name)
        # back up the base model's prepare_inputs_for_generation so it can be restored after `generate`
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model._validate_model_kwargs = self.base_model_validate_model_kwargs

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        peft_config = self.active_peft_config
        # added: debugging aid, left disabled
        # print("kwargs = ", kwargs)
        if not isinstance(peft_config, PromptLearningConfig):
            if self.base_model.config.model_type == "mpt":
                if inputs_embeds is not None:
                    raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = input_ids.shape[0]
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
        else:
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            # concat prompt labels
            if labels is not None:
                prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
                kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def generate(self, **kwargs):
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        if hasattr(self.base_model, "model"):
            self.base_model.model.generation_config = self.generation_config
        else:
            self.base_model.generation_config = self.generation_config
        try:
            # MoeLoRAModel.generate
            outputs = self.base_model.generate(**kwargs)
        except:
            # an exception was raised: restore the original hook before re-raising
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            raise
        else:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            return outputs

    def prepare_inputs_for_generation(self, *args, **kwargs):
        peft_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if isinstance(peft_config, PromptLearningConfig):
            if model_kwargs.get("attention_mask", None) is not None:
                prefix_attention_mask = torch.ones(
                    model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                ).to(model_kwargs["input_ids"].device)
                model_kwargs["attention_mask"] = torch.cat(
                    (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
                )

            if model_kwargs.get("position_ids", None) is not None:
                warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
                model_kwargs["position_ids"] = None

            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
                past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
                model_kwargs["past_key_values"] = past_key_values
            else:
                if model_kwargs["past_key_values"] is None:
                    inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
                    prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
                    prompts = prompts.to(inputs_embeds.dtype)
                    model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
                    model_kwargs["input_ids"] = None  # !!!
                    model_kwargs["user_embeds"] = None

        return model_kwargs  # !!!

    def base_model_validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # If a `Cache` instance is passed, checks whether the model is compatible with it.
        # `Cache` lives in `transformers.cache_utils` and is not imported at module level, so import it lazily and
        # skip the check on older `transformers` versions that do not provide it.
        try:
            from transformers.cache_utils import Cache
        except ImportError:
            Cache = None
        if (
            Cache is not None
            and isinstance(model_kwargs.get("past_key_values", None), Cache)
            and not getattr(self, "_supports_cache_class", False)
        ):
            raise ValueError(
                f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. 
Please " "check the model documentation for supported cache formats." ) # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: model_kwargs.pop(key, None) unused_model_args = [] model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) if "kwargs" in model_args or "model_kwargs" in model_args: model_args |= set(inspect.signature(self.forward).parameters) # Encoder-Decoder models may also need Encoder arguments from `model_kwargs` if self.config.is_encoder_decoder: base_model = getattr(self, self.base_model_prefix, None) # allow encoder kwargs encoder = getattr(self, "encoder", None) # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`. # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder` # TODO: A better way to handle this. if encoder is None and base_model is not None: encoder = getattr(base_model, "encoder", None) if encoder is not None: encoder_model_args = set(inspect.signature(encoder.forward).parameters) model_args |= encoder_model_args # allow decoder kwargs decoder = getattr(self, "decoder", None) if decoder is None and base_model is not None: decoder = getattr(base_model, "decoder", None) if decoder is not None: decoder_model_args = set(inspect.signature(decoder.forward).parameters) model_args |= {f"decoder_{x}" for x in decoder_model_args} # allow assistant_encoder_outputs to be passed if we're doing assisted generating if "assistant_encoder_outputs" in model_kwargs: model_args |= {"assistant_encoder_outputs"} for key, value in model_kwargs.items(): if value is not None and key not in model_args: unused_model_args.append(key) if unused_model_args: raise ValueError( f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" " generate arguments will also show up in this list)" ) class PeftModelForSeq2SeqLM(PeftModel): """ Peft model for sequence-to-sequence language modeling. Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. Example: ```py >>> from transformers import AutoModelForSeq2SeqLM >>> from peft import PeftModelForSeq2SeqLM, get_peft_config >>> config = { ... "peft_type": "LORA", ... "task_type": "SEQ_2_SEQ_LM", ... "inference_mode": False, ... "r": 8, ... "target_modules": ["q", "v"], ... "lora_alpha": 32, ... "lora_dropout": 0.1, ... "fan_in_fan_out": False, ... "enable_lora": None, ... "bias": "none", ... 
} >>> peft_config = get_peft_config(config) >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config) >>> peft_model.print_trainable_parameters() trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566 ``` """ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"): super().__init__(model, peft_config, adapter_name) self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation self.base_model_prepare_encoder_decoder_kwargs_for_generation = ( self.base_model._prepare_encoder_decoder_kwargs_for_generation ) def forward( self, input_ids=None, attention_mask=None, inputs_embeds=None, decoder_input_ids=None, decoder_attention_mask=None, decoder_inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): peft_config = self.active_peft_config if not isinstance(peft_config, PromptLearningConfig): return self.base_model( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, decoder_inputs_embeds=decoder_inputs_embeds, labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) batch_size = input_ids.shape[0] if decoder_attention_mask is not None: # concat prompt attention mask prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to( decoder_attention_mask.device ) decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1) if kwargs.get("position_ids", None) is not None: warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") kwargs["position_ids"] = None if kwargs.get("token_type_ids", None) is not None: warnings.warn("Token type ids are not supported for parameter efficient tuning. 
Ignoring token type ids") kwargs["token_type_ids"] = None kwargs.update( { "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "labels": labels, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, } ) if peft_config.peft_type == PeftType.PREFIX_TUNING: past_key_values = self.get_prompt(batch_size) return self.base_model( input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs ) elif peft_config.peft_type in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]: if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) if attention_mask is not None: # concat prompt attention mask prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to( attention_mask.device ) kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1) prompts = self.get_prompt(batch_size=batch_size) prompts = prompts.to(inputs_embeds.dtype) inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) else: if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) if decoder_inputs_embeds is None and decoder_input_ids is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) decoder_inputs_embeds = self.word_embeddings(decoder_input_ids) if attention_mask is not None: # concat prompt attention mask prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to( attention_mask.device ) kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1) # concat prompt labels if labels is not None: if peft_config.num_transformer_submodules == 1: kwargs["labels"] = labels elif peft_config.num_transformer_submodules == 2: prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device) kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1) prompts = self.get_prompt(batch_size=batch_size) prompts = prompts.to(inputs_embeds.dtype) inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1) if peft_config.num_transformer_submodules == 1: return self.base_model(inputs_embeds=inputs_embeds, **kwargs) elif peft_config.num_transformer_submodules == 2: decoder_inputs_embeds = torch.cat( (prompts[:, peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1 ) return self.base_model( inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs ) def generate(self, **kwargs): peft_config = self.active_peft_config self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation self.base_model._prepare_encoder_decoder_kwargs_for_generation = ( self._prepare_encoder_decoder_kwargs_for_generation ) try: if not isinstance(peft_config, PromptLearningConfig): outputs = self.base_model.generate(**kwargs) else: if "input_ids" not in kwargs: raise ValueError("input_ids must be provided for Peft model generation") if kwargs.get("position_ids", None) is not None: warnings.warn( "Position ids are not supported for parameter efficient tuning. Ignoring position ids." ) kwargs["position_ids"] = None if kwargs.get("token_type_ids", None) is not None: warnings.warn( "Token type ids are not supported for parameter efficient tuning. 
Ignoring token type ids"
                    )
                    kwargs["token_type_ids"] = None

                if peft_config.peft_type == PeftType.PREFIX_TUNING:
                    outputs = self.base_model.generate(**kwargs)
                elif peft_config.peft_type in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]:
                    kwargs = deepcopy(kwargs)

                    if "encoder_outputs" in kwargs:
                        # the key deleted here must match the key checked above (fixed "encoder_ouputs" typo)
                        del kwargs["encoder_outputs"]
                        warnings.warn(
                            "`encoder_outputs` should not be passed to `generate` when using prompt tuning. Ignoring it."
                        )

                    input_ids = kwargs.pop("input_ids")
                    inputs_embeds = self.word_embeddings(input_ids)
                    batch_size = inputs_embeds.shape[0]
                    prompts = self.get_prompt(batch_size=batch_size)
                    prompts = prompts.to(inputs_embeds.dtype)
                    inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
                    kwargs["inputs_embeds"] = inputs_embeds

                    if "attention_mask" in kwargs:
                        prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                            kwargs["attention_mask"].device
                        )
                        kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)

                    return self.base_model.generate(**kwargs)
                else:
                    raise NotImplementedError
        except:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            raise
        else:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            return outputs

    def prepare_inputs_for_generation(self, *args, **kwargs):
        peft_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
            batch_size = model_kwargs["decoder_input_ids"].shape[0]
            past_key_values = self.get_prompt(batch_size)
            model_kwargs["past_key_values"] = past_key_values

        return model_kwargs


class PeftModelForTokenClassification(PeftModel):
    """
    Peft model for token classification tasks.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.

    **Attributes**:
        - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
        - **cls_layer_name** (`str`) -- The name of the classification layer.

    Example:

        ```py
        >>> from transformers import AutoModelForTokenClassification
        >>> from peft import PeftModelForTokenClassification, get_peft_config

        >>> config = {
        ...     "peft_type": "PREFIX_TUNING",
        ...     "task_type": "TOKEN_CLS",
        ...     "inference_mode": False,
        ...     "num_virtual_tokens": 20,
        ...     "token_dim": 768,
        ...     "num_transformer_submodules": 1,
        ...     "num_attention_heads": 12,
        ...     "num_layers": 12,
        ...     "encoder_hidden_size": 768,
        ...     "prefix_projection": False,
        ...     "postprocess_past_key_value_function": None,
        ...
} >>> peft_config = get_peft_config(config) >>> model = AutoModelForTokenClassification.from_pretrained("bert-base-cased") >>> peft_model = PeftModelForTokenClassification(model, peft_config) >>> peft_model.print_trainable_parameters() trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117 ``` """ def __init__(self, model, peft_config: PeftConfig = None, adapter_name="default"): super().__init__(model, peft_config, adapter_name) if self.modules_to_save is None: self.modules_to_save = {"classifier", "score"} else: self.modules_to_save.update({"classifier", "score"}) for name, _ in self.base_model.named_children(): if any(module_name in name for module_name in self.modules_to_save): self.cls_layer_name = name break # to make sure classifier layer is trainable _set_trainable(self, adapter_name) def forward( self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): peft_config = self.active_peft_config return_dict = return_dict if return_dict is not None else self.config.use_return_dict if not isinstance(peft_config, PromptLearningConfig): return self.base_model( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) batch_size = input_ids.shape[0] if attention_mask is not None: # concat prompt attention mask prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) if kwargs.get("position_ids", None) is not None: warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") kwargs["position_ids"] = None kwargs.update( { "attention_mask": attention_mask, "labels": labels, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, } ) if peft_config.peft_type == PeftType.PREFIX_TUNING: return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) else: if kwargs.get("token_type_ids", None) is not None: kwargs["token_type_ids"] = torch.cat( ( torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), kwargs["token_type_ids"], ), dim=1, ).long() if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) prompts = self.get_prompt(batch_size=batch_size) prompts = prompts.to(inputs_embeds.dtype) inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) def _prefix_tuning_forward( self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): batch_size = input_ids.shape[0] past_key_values = self.get_prompt(batch_size) fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) kwargs.update( { "input_ids": input_ids, "attention_mask": attention_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, "past_key_values": past_key_values, } ) if "past_key_values" in fwd_params: return self.base_model(labels=labels, **kwargs) else: transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) if "past_key_values" not in fwd_params: raise ValueError("Model does not support past key values which are required for prefix tuning.") outputs = transformer_backbone_name(**kwargs) sequence_output = outputs[0] if "dropout" in [name for name, _ in list(self.base_model.named_children())]: sequence_output = self.base_model.dropout(sequence_output) logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class PeftModelForQuestionAnswering(PeftModel): """ Peft model for extractive question answering. Args: model ([`~transformers.PreTrainedModel`]): Base transformer model. peft_config ([`PeftConfig`]): Peft config. **Attributes**: - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model. - **cls_layer_name** (`str`) -- The name of the classification layer. Example: ```py >>> from transformers import AutoModelForQuestionAnswering >>> from peft import PeftModelForQuestionAnswering, get_peft_config >>> config = { ... "peft_type": "LORA", ... "task_type": "QUESTION_ANS", ... "inference_mode": False, ... "r": 16, ... "target_modules": ["query", "value"], ... "lora_alpha": 32, ... "lora_dropout": 0.05, ... "fan_in_fan_out": False, ... "bias": "none", ... 
} >>> peft_config = get_peft_config(config) >>> model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased") >>> peft_model = PeftModelForQuestionAnswering(model, peft_config) >>> peft_model.print_trainable_parameters() trainable params: 592900 || all params: 108312580 || trainable%: 0.5473971721475013 ``` """ def __init__(self, model, peft_config: PeftConfig = None, adapter_name="default"): super().__init__(model, peft_config, adapter_name) if self.modules_to_save is None: self.modules_to_save = {"qa_outputs"} else: self.modules_to_save.update({"qa_outputs"}) for name, _ in self.base_model.named_children(): if any(module_name in name for module_name in self.modules_to_save): self.cls_layer_name = name break # to make sure classifier layer is trainable _set_trainable(self, adapter_name) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): peft_config = self.active_peft_config return_dict = return_dict if return_dict is not None else self.config.use_return_dict if not isinstance(peft_config, PromptLearningConfig): return self.base_model( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, start_positions=start_positions, end_positions=end_positions, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) batch_size = input_ids.shape[0] if attention_mask is not None: # concat prompt attention mask prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) if kwargs.get("position_ids", None) is not None: warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") kwargs["position_ids"] = None kwargs.update( { "attention_mask": attention_mask, "start_positions": start_positions, "end_positions": end_positions, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, } ) if peft_config.peft_type == PeftType.PREFIX_TUNING: return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) else: if kwargs.get("token_type_ids", None) is not None: kwargs["token_type_ids"] = torch.cat( ( torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), kwargs["token_type_ids"], ), dim=1, ).long() if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) prompts = self.get_prompt(batch_size=batch_size) prompts = prompts.to(inputs_embeds.dtype) inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) def _prefix_tuning_forward( self, input_ids=None, attention_mask=None, inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): batch_size = input_ids.shape[0] past_key_values = self.get_prompt(batch_size) fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) kwargs.update( { "input_ids": input_ids, "attention_mask": attention_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, "past_key_values": past_key_values, } ) if "past_key_values" in fwd_params: return self.base_model(start_positions=start_positions, end_positions=end_positions, **kwargs) else: transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) if "past_key_values" not in fwd_params: raise ValueError("Model does not support past key values which are required for prefix tuning.") outputs = transformer_backbone_name(**kwargs) sequence_output = outputs[0] if "dropout" in [name for name, _ in list(self.base_model.named_children())]: sequence_output = self.base_model.dropout(sequence_output) logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )