import torch
import torch.nn as nn
import deepspeed
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled

# The wildcard import supplies the typing aliases (Dict, Optional, List, Tuple, Union, Any)
# and the helpers used below (os, logger, amp, unwrap_model, smp_forward_only,
# smp_nested_concat, smp_forward_backward, PreTrainedModel, PeftModel, is_peft_available,
# safetensors, SAFE_WEIGHTS_NAME, WEIGHTS_NAME, TRAINING_ARGS_NAME).
from transformers.trainer import *
from transformers.integrations import is_deepspeed_zero3_enabled


class CPMTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        if not self.args.use_lora:
            outputs = self.model(data=inputs, use_cache=False)
        else:
            with self.model._enable_peft_forward_hooks(**inputs):
                outputs = self.model.base_model(data=inputs, use_cache=False)

        if labels is not None:
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
            labels = labels.view(-1).long().contiguous()
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are "
                    f"{','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = (
            False
            if len(self.label_names) == 0
            else all(inputs.get(k) is not None for k in self.label_names)
        )
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of
        # `return_loss` is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = (
            True if len(self.label_names) == 0 and return_loss else False
        )

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(
                    self.model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(
                            v
                            for k, v in raw_outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(
                            v for k, v in raw_outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(
                            model, inputs, return_outputs=True
                        )
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(
                            v
                            for k, v in outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(
                            v for k, v in outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        del inputs
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    state_dict=state_dict,
                    safe_serialization=self.args.save_safetensors,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict,
                        os.path.join(output_dir, SAFE_WEIGHTS_NAME),
                        metadata={"format": "pt"},
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir,
                state_dict=state_dict,
                safe_serialization=self.args.save_safetensors,
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
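

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original trainer).
# CPMTrainer.compute_loss reads `self.args.use_lora`, so this sketch assumes a
# TrainingArguments subclass that adds that flag. The checkpoint name, dataset,
# and collator are placeholders and must match your own finetuning setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass, field

    from transformers import AutoModel, AutoTokenizer, TrainingArguments

    @dataclass
    class CPMTrainingArguments(TrainingArguments):
        # Hypothetical extension: exposes the flag CPMTrainer expects on `args`.
        use_lora: bool = field(default=False)

    # Placeholder checkpoint; substitute the model you are actually finetuning.
    model_name = "openbmb/MiniCPM-V-2"
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    args = CPMTrainingArguments(
        output_dir="output",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        use_lora=False,
    )

    # `train_dataset` / `data_collator` must yield the dict-of-tensors batches
    # (including "labels") that compute_loss above expects; they are omitted here.
    trainer = CPMTrainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        # train_dataset=...,
        # data_collator=...,
    )
    # trainer.train()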