"""Re-usable :class:`.ComposerModel` for LLM HF Models.""" |
|
|
|
from __future__ import annotations |
|
|
|
import inspect |
|
from collections import UserDict |
|
from typing import List, Mapping, Optional |
|
|
|
import torch |
|
import transformers |
|
from composer.models.huggingface import HuggingFaceModel |
|
from torchmetrics import Metric |
|
from transformers import PreTrainedTokenizerBase |
|
from transformers.utils.generic import ModelOutput |
|
|
|
from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp |
|
|
|
|
|
_HF_IGNORE_INDEX = -100 |
|
|
|
|
|
class HuggingFaceModelWithZLoss(HuggingFaceModel):
    """Wrapper around HuggingFaceModel.

    This adds z-loss, which is used in some training contexts, and is a
    convenient place to patch in features that are generically useful for
    HF models.

    See the use of z-loss in PaLM (https://arxiv.org/abs/2204.02311v3,
    Section 5). Also, from
    https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666,
    two uses of z-loss are:

    - To keep the logits from drifting too far from zero, which can cause
      unacceptable roundoff errors in bfloat16.
    - To encourage the logits to be normalized log-probabilities.
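
    Concretely, :meth:`loss` below adds ``z_loss * mean(log(Z)**2)`` to the
    model loss, where ``log(Z)`` is the logsumexp of the logits over the
    vocabulary and the mean runs over non-padding tokens; PaLM uses a
    coefficient of ``1e-4``.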

    Handles preparation for FSDP wrapping.
    """

    def __init__(self,
                 model: transformers.PreTrainedModel,
                 tokenizer: Optional[PreTrainedTokenizerBase] = None,
                 metrics: Optional[List[Metric]] = None,
                 eval_metrics: Optional[List[Metric]] = None,
                 z_loss: float = 0.0,
                 shift_labels: bool = False,
                 init_device: Optional[str] = None):
        super().__init__(model,
                         tokenizer,
                         use_logits=True,
                         metrics=metrics,
                         eval_metrics=eval_metrics,
                         shift_labels=shift_labels)
        self.z_loss = float(z_loss)
        if self.z_loss < 0.0:
            raise ValueError(f'z_loss(={z_loss}) cannot be negative.')

        # Record the argument names accepted by the wrapped model's forward,
        # so that `forward` can filter batch keys down to just those.
        self.model_forward_args = inspect.getfullargspec(
            self.model.forward).args
        # `getfullargspec` reports no args for some wrapped forwards; fall
        # back to inspecting the signature directly.
        if not self.model_forward_args:
            self.model_forward_args = inspect.signature(
                self.model.forward).parameters.keys()

        prepare_hf_model_for_fsdp(self.model, init_device)

        # Support meta-device initialization: FSDP calls `param_init_fn` on
        # each module to materialize parameters, and here that defers to the
        # model's own `_init_weights`.
        self.model.param_init_fn = lambda module: self.model._init_weights(
            module)

    def forward(self, batch: Mapping):
        if isinstance(batch, (dict, UserDict)):
            # Keep only the keys the wrapped model's forward accepts;
            # unexpected keyword arguments would raise a TypeError.
            batch = {
                k: v for k, v in batch.items() if k in self.model_forward_args
            }
            output = self.model(**batch)
        else:
            raise ValueError(
                'Unexpected batch type. Expected a dictionary with keys '
                'corresponding to the inputs to the forward function of the '
                'Hugging Face model.')
        return output

    def loss(self, outputs: ModelOutput, batch: Mapping):
        if self.config.use_return_dict:
            loss, logits = outputs['loss'], outputs['logits']
        else:
            # When `use_return_dict` is False, HF outputs are tuples ordered
            # as (loss, logits, ...).
            loss, logits = outputs[:2]
        if self.z_loss == 0.0:
            return loss

        # z-loss penalty: coeff * mean(log(Z)**2), where log(Z) is the
        # logsumexp of the logits over the vocabulary, computed only at
        # positions whose label is not the ignore index.
        logits_flat = logits.view(-1, logits.size(-1))
        labels_flat = batch['labels'].view(-1)
        log_z = torch.logsumexp(logits_flat[labels_flat != _HF_IGNORE_INDEX],
                                dim=1)
        log_z2 = log_z**2
        z_loss = log_z2.mean() * self.z_loss
        if self.config.use_return_dict:
            outputs['loss'] += z_loss
            return outputs['loss']
        else:
            outputs[0] += z_loss
            return outputs[0]
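

# A minimal usage sketch (illustrative only, guarded so it does not run on
# import). Assumes network access to download the checkpoint; `gpt2` is an
# arbitrary small causal LM, and the 1e-4 coefficient follows PaLM.
if __name__ == '__main__':
    hf_model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    hf_tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    composer_model = HuggingFaceModelWithZLoss(hf_model,
                                               hf_tokenizer,
                                               z_loss=1e-4)
    print(f'Wrapped {type(hf_model).__name__} with z_loss='
          f'{composer_model.z_loss}')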
|
|