# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
import logging
import os
from typing import Mapping, Union

# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity)
from composer.utils import dist
from omegaconf import DictConfig
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM,
                          PreTrainedTokenizerBase)

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from llmfoundry.models.utils import init_empty_weights

try:
from peft.peft_model import PeftModel
model_types = PeftModel, transformers.PreTrainedModel
except ImportError:
    model_types = transformers.PreTrainedModel

__all__ = ['ComposerHFCausalLM']

log = logging.getLogger(__name__)


class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
    """Configures a :class:`.HuggingFaceModel` around a Causal LM.

    Args:
om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel): either an omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library.
if DictConfig, the following keys are required:
cfg.pretrained_model_name_or_path (str): The name of or local path to
the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
cfg.config_overrides (dict, optional): An optional dictionary of keyword
arguments that override the default configuration associated with
cfg.pretrained_model_name_or_path.
cfg.pretrained (bool): Whether to instantiate the model with pre-trained
weights coming from cfg.pretrained_model_name_or_path. If ``True``,
cfg.config_overrides must be compatible with the pre-trained weights.
cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
initialize the model on. Currently, `meta` is only supported when
cfg.pretrained is ``False``. Default: ``'cpu'``.
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
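
    Example (an illustrative sketch; ``gpt2`` and the config values below are
    placeholders rather than recommendations):

    .. code-block:: python

        from omegaconf import DictConfig
        from transformers import AutoTokenizer

        from llmfoundry.models.hf import ComposerHFCausalLM

        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        model = ComposerHFCausalLM(
            om_model_config=DictConfig({
                'pretrained_model_name_or_path': 'gpt2',
                'pretrained': True,
                'init_device': 'cpu',
            }),
            tokenizer=tokenizer,
        )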
"""
def __init__(self, om_model_config: Union[DictConfig,
transformers.PreTrainedModel,
nn.Module],
tokenizer: PreTrainedTokenizerBase):
# set up training and eval metrics
train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
        ]

# if we are passed a DictConfig, we need to instantiate the model
if isinstance(om_model_config, DictConfig):
if not om_model_config.get('trust_remote_code',
True) and om_model_config.get(
'pretrained_model_name_or_path',
None).startswith('mosaicml/mpt'):
raise ValueError(
'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
+
                    'which is significantly slower and not compatible with the LLM foundry training code, rather than the code released by MosaicML.'
)
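
            # optionally skip training metrics (e.g. to reduce metric computation overhead)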
if not om_model_config.get('use_train_metrics', True):
train_metrics = []
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)
# set config overrides
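            # e.g. {'attn_pdrop': 0.0}, or a nested mapping such as
            # {'rope_scaling': {'type': 'dynamic', 'factor': 2.0}} (illustrative values only)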
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).'
)
attr = getattr(config, k)
# attempt to disallow typos in nested configs
if isinstance(attr, Mapping):
extra_keys = [
_k for _k in v.keys() if _k not in attr.keys()
]
if extra_keys:
raise ValueError(
f'Config dict override got unknown keys. ' +
f'Extra keys: {extra_keys}. ' +
f'Expected (a subset of) keys: {list(attr.keys())}.'
)
getattr(config, k).update(v)
                # necessary case to allow rope_scaling to be overridden in the llama config
elif attr is None and isinstance(v, Mapping):
setattr(config, k, {})
getattr(config, k).update(v)
else:
setattr(config, k, v)
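
            # optionally load the pretrained weights in 8-bit via bitsandbytes
            # (only used in the from_pretrained call below)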
load_in_8bit = om_model_config.get('load_in_8bit', False)
# below we set up the device to initialize the model on
init_device = om_model_config.get('init_device', 'cpu')
            # Get the device we want to initialize, and use the
            # resolved version to initialize the HF model
resolved_init_device = hf_get_init_device(init_device)
            # With init_device='mixed', only local rank 0 loads the pretrained weights;
            # all other local ranks initialize without them, and the weights are
            # distributed to those ranks appropriately afterwards
if dist.get_local_rank() != 0 and init_device == 'mixed':
om_model_config.pretrained = False
# initialize the model on the correct device
if resolved_init_device == 'cpu':
if om_model_config.pretrained:
model = AutoModelForCausalLM.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
load_in_8bit=load_in_8bit,
config=config)
else:
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
)
elif resolved_init_device == 'meta':
if om_model_config.pretrained:
raise ValueError(
'Setting cfg.pretrained=True is not supported when init_device="meta".'
)
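                # create the model with its parameters on the meta device; the
                # weights occupy no memory until they are materialized later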
with init_empty_weights(include_buffers=False):
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
)
else:
raise ValueError(
f'init_device="{init_device}" must be either "cpu" or "meta".'
)
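
            # signal file used to coordinate across ranks on this node: local rank 0
            # writes it once its (possible) checkpoint download above has completed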
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')
# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished downloading the checkpoint
dist.barrier()
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
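
            # optional auxiliary z-loss coefficient (see HuggingFaceModelWithZLoss);
            # 0.0 disables the extra loss term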
z_loss = om_model_config.get('z_loss', 0.0)
attention_patch_type = om_model_config.get('attention_patch_type',
None)
if attention_patch_type is not None:
if model.config.model_type != 'llama':
raise ValueError(
f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
)
log.debug(
f'Patching llama attention with {attention_patch_type} attention'
)
from transformers.models.llama.modeling_llama import \
LlamaAttention
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
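                # KV caching is not needed for training (and may not be supported by
                # the patched attention), so disable it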
model.config.use_cache = False
# elif the model is either a PeftModel or a PreTrainedModel
elif isinstance(om_model_config, model_types):
model = om_model_config
init_device = 'cpu'
z_loss = 0.0
# else, unsupported type
else:
raise ValueError(
f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
)

        super().__init__(model=model,
                         shift_labels=True,
                         tokenizer=tokenizer,
                         metrics=train_metrics,
                         eval_metrics=eval_metrics,
                         z_loss=z_loss,
                         init_device=init_device)