crystal-technologies's picture
Upload 303 files
de4ade4
raw
history blame contribute delete
10.8 kB
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
"""Implements a Hugging Face Causal LM wrapped inside a :class:`.ComposerModel`."""
import logging
import os
from typing import Mapping, Union
# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity)
from composer.utils import dist
from omegaconf import DictConfig
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM,
PreTrainedTokenizerBase)
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from llmfoundry.models.utils import init_empty_weights
# peft is an optional dependency: when it can be imported, PeftModel is
# accepted in addition to transformers.PreTrainedModel in isinstance checks.
try:
    from peft.peft_model import PeftModel
except ImportError:
    model_types = transformers.PreTrainedModel
else:
    model_types = (PeftModel, transformers.PreTrainedModel)

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Public API of this module.
__all__ = ['ComposerHFCausalLM']
class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
    """Configures a :class:`.HuggingFaceModel` around a Causal LM.

    Args:
        om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel):
            Either an omegaconf dictionary used to configure the model, or an
            instantiated model object from the peft or transformers library.
            If a DictConfig, the following keys are read:
                cfg.pretrained_model_name_or_path (str): The name of or local path to
                    the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
                cfg.config_overrides (dict, optional): An optional dictionary of keyword
                    arguments that override the default configuration associated with
                    cfg.pretrained_model_name_or_path. Every key must already exist
                    on the loaded config; for dict-valued attributes, only known
                    sub-keys may be overridden.
                cfg.pretrained (bool): Whether to instantiate the model with pre-trained
                    weights coming from cfg.pretrained_model_name_or_path. If ``True``,
                    cfg.config_overrides must be compatible with the pre-trained weights.
                cfg.init_device ('cpu' | 'meta' | 'mixed'): Which device to initialize
                    the model on. The value is resolved with ``hf_get_init_device``
                    and the resolved device must be ``'cpu'`` or ``'meta'``.
                    ``'meta'`` is only supported when cfg.pretrained is ``False``.
                    With ``'mixed'``, non-zero local ranks are built without
                    pre-trained weights. Default: ``'cpu'``.
                cfg.trust_remote_code (bool, optional): Forwarded to the transformers
                    loaders. Must be ``True`` for ``mosaicml/mpt*`` models.
                    Default: ``True``.
                cfg.use_auth_token (bool, optional): Forwarded to the transformers
                    ``from_pretrained`` calls. Default: ``False``.
                cfg.load_in_8bit (bool, optional): Forwarded to
                    ``AutoModelForCausalLM.from_pretrained``. Default: ``False``.
                cfg.use_train_metrics (bool, optional): If ``False``, no training
                    metrics are registered. Default: ``True``.
                cfg.z_loss (float, optional): Passed to the parent class as
                    ``z_loss``. Default: ``0.0``.
                cfg.attention_patch_type (str, optional): If set, the forward method
                    of ``LlamaAttention`` is replaced with the function returned by
                    ``get_llama_attention_patch_fn`` and ``use_cache`` is disabled
                    on the model config. Only valid for llama models.
        tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.

    Raises:
        ValueError: If ``trust_remote_code`` is false for an MPT model, if a
            config override names an unknown attribute or nested key, if
            cfg.pretrained is ``True`` with a ``'meta'`` init device, if the
            resolved init device is unsupported, if ``attention_patch_type``
            is used with a non-llama model, or if ``om_model_config`` is of an
            unsupported type.
    """

    def __init__(self, om_model_config: Union[DictConfig,
                                              transformers.PreTrainedModel,
                                              nn.Module],
                 tokenizer: PreTrainedTokenizerBase):
        # set up training and eval metrics
        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
            InContextLearningCodeEvalAccuracy(),
            InContextLearningLMExpectedCalibrationError(),
            InContextLearningMCExpectedCalibrationError()
        ]

        # if we are passed a DictConfig, we need to instantiate the model
        if isinstance(om_model_config, DictConfig):
            trust_remote_code = om_model_config.get('trust_remote_code', True)
            model_name = om_model_config.get('pretrained_model_name_or_path',
                                             None)
            # Only inspect the name when it is actually a string, so a missing
            # key does not surface as an AttributeError on None here.
            if (not trust_remote_code and isinstance(model_name, str) and
                    model_name.startswith('mosaicml/mpt')):
                raise ValueError(
                    'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
                    'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
                )

            if not om_model_config.get('use_train_metrics', True):
                train_metrics = []

            # load the model config
            use_auth_token = om_model_config.get('use_auth_token', False)
            config = AutoConfig.from_pretrained(
                om_model_config.pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                use_auth_token=use_auth_token,
            )

            # set config overrides
            for k, v in om_model_config.get('config_overrides', {}).items():
                if not hasattr(config, k):
                    raise ValueError(
                        f'config does not have attribute "{k}" to override ({k}: {v}).'
                    )

                attr = getattr(config, k)
                # attempt to disallow typos in nested configs
                if isinstance(attr, Mapping):
                    extra_keys = [_k for _k in v.keys() if _k not in attr]
                    if extra_keys:
                        raise ValueError(
                            'Config dict override got unknown keys. '
                            f'Extra keys: {extra_keys}. '
                            f'Expected (a subset of) keys: {list(attr.keys())}.'
                        )
                    attr.update(v)
                # necessary case to allow for rope_scaling to be overridden in llama config
                elif attr is None and isinstance(v, Mapping):
                    setattr(config, k, {})
                    getattr(config, k).update(v)
                else:
                    setattr(config, k, v)

            load_in_8bit = om_model_config.get('load_in_8bit', False)

            # below we set up the device to initialize the model on
            init_device = om_model_config.get('init_device', 'cpu')

            # Get the device we want to initialize, and use the
            # resolved version to initialize the HF model
            resolved_init_device = hf_get_init_device(init_device)

            # We need to have all non-zero local ranks be not-pretrained
            # Rank 0 will still be pretrained, and distribute the weights appropriately.
            # Tracked in a local so the caller's config object is not mutated.
            if dist.get_local_rank() != 0 and init_device == 'mixed':
                pretrained = False
            else:
                pretrained = om_model_config.pretrained

            # initialize the model on the correct device
            if resolved_init_device == 'cpu':
                if pretrained:
                    model = AutoModelForCausalLM.from_pretrained(
                        om_model_config.pretrained_model_name_or_path,
                        trust_remote_code=trust_remote_code,
                        use_auth_token=use_auth_token,
                        load_in_8bit=load_in_8bit,
                        config=config)
                else:
                    model = AutoModelForCausalLM.from_config(
                        config,
                        trust_remote_code=trust_remote_code,
                    )
            elif resolved_init_device == 'meta':
                if pretrained:
                    raise ValueError(
                        'Setting cfg.pretrained=True is not supported when init_device="meta".'
                    )
                with init_empty_weights(include_buffers=False):
                    model = AutoModelForCausalLM.from_config(
                        config,
                        trust_remote_code=trust_remote_code,
                    )
            else:
                raise ValueError(
                    f'init_device="{init_device}" must be either "cpu" or "meta".'
                )

            # Local rank zero writes a marker file once its model is built.
            signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
            if dist.get_local_rank() == 0:
                with open(signal_file_path, 'wb') as f:
                    f.write(b'local_rank0_completed_download')

            # Avoid the collective call until the local rank zero has finished trying to download the checkpoint
            # so that we don't timeout for large downloads. This syncs all processes on the node
            with dist.local_rank_zero_download_and_wait(signal_file_path):
                # Then, wait to ensure every node has finished downloading the checkpoint
                dist.barrier()

            if dist.get_local_rank() == 0:
                os.remove(signal_file_path)

            z_loss = om_model_config.get('z_loss', 0.0)

            attention_patch_type = om_model_config.get('attention_patch_type',
                                                       None)
            if attention_patch_type is not None:
                if model.config.model_type != 'llama':
                    raise ValueError(
                        f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
                    )

                log.debug('Patching llama attention with %s attention',
                          attention_patch_type)
                from transformers.models.llama.modeling_llama import \
                    LlamaAttention

                # NOTE: this replaces the method on the LlamaAttention class
                # itself, not only on this model instance.
                LlamaAttention.forward = get_llama_attention_patch_fn(
                    attention_patch_type)
                model.config.use_cache = False

        # elif the model is either a PeftModel or a PreTrainedModel
        elif isinstance(om_model_config, model_types):
            model = om_model_config
            init_device = 'cpu'
            z_loss = 0.0

        # else, unsupported type
        else:
            raise ValueError(
                f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
            )

        # __init__ must return None, so the parent initializer's result is
        # neither captured nor returned.
        super().__init__(model=model,
                         shift_labels=True,
                         tokenizer=tokenizer,
                         metrics=train_metrics,
                         eval_metrics=eval_metrics,
                         z_loss=z_loss,
                         init_device=init_device)