import torch
import torch.nn as nn
import deepspeed
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
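# The wildcard import below supplies the Trainer internals used in this file:
# typing aliases (Dict, Optional, ...), os, logger, unwrap_model, the smp_*
# SageMaker model-parallel helpers, amp, safetensors, and the checkpoint name
# constants (WEIGHTS_NAME, SAFE_WEIGHTS_NAME, TRAINING_ARGS_NAME).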
from transformers.trainer import *
from transformers.integrations import is_deepspeed_zero3_enabled
class CPMTrainer(Trainer):
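    """Trainer subclass for CPM-style multimodal models whose forward pass
    takes the whole prepared batch through a single `data` keyword argument
    instead of unpacked keyword arguments."""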
def compute_loss(self, model, inputs, return_outputs=False):
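        """Compute the loss. If `labels` are present in `inputs`, apply a
        token-level cross-entropy over the flattened logits; otherwise rely
        on the loss returned by the model itself."""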
if "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
if not self.args.use_lora:
            outputs = self.model(data=inputs, use_cache=False)
else:
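            # LoRA path: call through the PEFT base model while keeping the
            # adapter forward hooks active for this batch.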
with self.model._enable_peft_forward_hooks(**inputs):
                outputs = self.model.base_model(data=inputs, use_cache=False)
if labels is not None:
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
labels = labels.view(-1).long().contiguous()
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to evaluate.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (`bool`):
Whether or not to return the loss only.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Return:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = (
False
if len(self.label_names) == 0
else all(inputs.get(k) is not None for k in self.label_names)
)
# For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified in `inputs` or is `None`, we fall back to
        # `self.can_return_loss`, i.e. whether `model.forward` returns a loss by default.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
return_loss = self.can_return_loss
        loss_without_labels = len(self.label_names) == 0 and return_loss
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(
self.model.config, "keys_to_ignore_at_inference", []
)
else:
ignore_keys = []
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels or loss_without_labels:
labels = nested_detach(tuple(inputs.get(name)
for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
labels = None
with torch.no_grad():
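            # SageMaker model parallelism: outputs come back as microbatches
            # that need to be reduced (loss) and concatenated (logits).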
if is_sagemaker_mp_enabled():
raw_outputs = smp_forward_only(model, inputs)
if has_labels or loss_without_labels:
if isinstance(raw_outputs, dict):
loss_mb = raw_outputs["loss"]
logits_mb = tuple(
v
for k, v in raw_outputs.items()
if k not in ignore_keys + ["loss"]
)
else:
loss_mb = raw_outputs[0]
logits_mb = raw_outputs[1:]
loss = loss_mb.reduce_mean().detach().cpu()
logits = smp_nested_concat(logits_mb)
else:
loss = None
if isinstance(raw_outputs, dict):
logits_mb = tuple(
v for k, v in raw_outputs.items() if k not in ignore_keys
)
else:
logits_mb = raw_outputs
logits = smp_nested_concat(logits_mb)
else:
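                # Standard path: reuse compute_loss when labels are available,
                # otherwise run a plain forward pass.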
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(
model, inputs, return_outputs=True
)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(
v
for k, v in outputs.items()
if k not in ignore_keys + ["loss"]
)
else:
logits = outputs[1:]
else:
loss = None
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(
v for k, v in outputs.items() if k not in ignore_keys
)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]
if prediction_loss_only:
return (loss, None, None)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
return (loss, logits, labels)
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
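        # SageMaker model parallelism runs forward and backward in one helper
        # and returns the loss already split across microbatches.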
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
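        # Drop the batch references and release cached blocks; emptying the
        # CUDA cache every step lowers the memory high-water mark at some
        # throughput cost.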
del inputs
torch.cuda.empty_cache()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
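            # The wrapper (e.g. DDP or a PEFT shell) may hide a supported
            # model; if so, save through its `save_pretrained`.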
if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))