# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
"""Re-usable :class:`.ComposerModel` for LLM HF Models."""
from __future__ import annotations
import inspect
from collections import UserDict
from typing import List, Mapping, Optional
import torch
import transformers
from composer.models.huggingface import HuggingFaceModel
from torchmetrics import Metric
from transformers import PreTrainedTokenizerBase
from transformers.utils.generic import ModelOutput
from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100
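# e.g. label positions set to -100 (padding or masked prompt tokens) are skipped by the
# HF cross-entropy loss; the same mask is reused for the z-loss computed below.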
class HuggingFaceModelWithZLoss(HuggingFaceModel):
    """Wrapper around HuggingFaceModel.

    This adds z-loss, which is used in some training contexts,
    and is a convenient way to patch features that are generically
    useful for HF models.

    See use of z_loss in PaLM: https://arxiv.org/abs/2204.02311v3, Section 5.

    Also, from https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666:
        Two uses of z_loss are:
        - To keep the logits from drifting too far from zero, which can cause
          unacceptable roundoff errors in bfloat16.
        - To encourage the logits to be normalized log-probabilities.

    Handles preparation for FSDP wrapping.
    """
    def __init__(self,
                 model: transformers.PreTrainedModel,
                 tokenizer: Optional[PreTrainedTokenizerBase] = None,
                 metrics: Optional[List[Metric]] = None,
                 eval_metrics: Optional[List[Metric]] = None,
                 z_loss: float = 0.0,
                 shift_labels: bool = False,
                 init_device: Optional[str] = None):
        super().__init__(model,
                         tokenizer,
                         use_logits=True,
                         metrics=metrics,
                         eval_metrics=eval_metrics,
                         shift_labels=shift_labels)
        self.z_loss = float(z_loss)
        if self.z_loss < 0.0:
            raise ValueError(f'z_loss(={z_loss}) cannot be negative.')

        self.model_forward_args = inspect.getfullargspec(
            self.model.forward).args
        # inspect.getfullargspec may not return the forward args correctly for
        # HuggingFace quantized models, so fall back to inspect.signature
        if not self.model_forward_args:
            self.model_forward_args = inspect.signature(
                self.model.forward).parameters.keys()

        # Note: We need to add the FSDP related attributes to the model AFTER the super init,
        # so that the (possible) embedding resizing doesn't destroy them
        prepare_hf_model_for_fsdp(self.model, init_device)

        # This provides support for meta initialization when using FSDP
        self.model.param_init_fn = lambda module: self.model._init_weights(
            module)
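
    # Hypothetical usage sketch (the model/tokenizer choice is illustrative only):
    #   model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    #   tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    #   composer_model = HuggingFaceModelWithZLoss(model, tokenizer, z_loss=1e-4)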
    def forward(self, batch: Mapping):
        if isinstance(batch, dict) or isinstance(batch, UserDict):
            # Further input validation is left to the huggingface forward call
            batch = {
                k: v for k, v in batch.items() if k in self.model_forward_args
            }
            output = self.model(**batch)  # type: ignore (thirdparty)
        else:
            raise ValueError(
                'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model'
            )
        return output
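
    # For reference, a typical training batch here (assuming a causal-LM data
    # collator) looks roughly like
    #   {'input_ids': ..., 'attention_mask': ..., 'labels': ...}
    # and ``forward`` drops any keys that ``self.model.forward`` does not accept.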
    def loss(self, outputs: ModelOutput, batch: Mapping):
        if self.config.use_return_dict:
            loss, logits = outputs['loss'], outputs['logits']
        else:
            # loss is at index 0 in the output tuple, logits are at index 1
            loss, logits = outputs[:2]
        if self.z_loss == 0.0:
            return loss

        # Add a z_loss to the standard loss
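        # Shape sketch (assuming logits of shape (batch, seq_len, vocab) and
        # labels of shape (batch, seq_len)):
        #   logits_flat: (batch * seq_len, vocab)
        #   labels_flat: (batch * seq_len,)
        #   log_z:       (num_tokens_not_ignored,)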
        logits_flat = logits.view(-1, logits.size(-1))
        labels_flat = batch['labels'].view(-1)
        log_z = torch.logsumexp(logits_flat[labels_flat != _HF_IGNORE_INDEX],
                                dim=1)
        log_z2 = log_z**2
        z_loss = log_z2.mean() * self.z_loss
        if self.config.use_return_dict:
            outputs['loss'] += z_loss
            return outputs['loss']
        else:
            outputs[0] += z_loss
            return outputs[0]