"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`.""" |
|
|
|
import logging |
|
import os |
|
from typing import Mapping, Union |
|
|
|
|
|
import transformers |
|
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy, |
|
InContextLearningLMAccuracy, |
|
InContextLearningLMExpectedCalibrationError, |
|
InContextLearningMCExpectedCalibrationError, |
|
InContextLearningMultipleChoiceAccuracy, |
|
InContextLearningQAAccuracy, |
|
LanguageCrossEntropy, LanguagePerplexity) |
|
from composer.utils import dist |
|
from omegaconf import DictConfig |
|
from torch import nn |
|
from transformers import (AutoConfig, AutoModelForCausalLM, |
|
PreTrainedTokenizerBase) |
|
|
|
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device |
|
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss |
|
from llmfoundry.models.layers.llama_attention_monkeypatch import \ |
|
get_llama_attention_patch_fn |
|
from llmfoundry.models.utils import init_empty_weights |
|
|
|
try:
    from peft.peft_model import PeftModel
    model_types = (PeftModel, transformers.PreTrainedModel)
except ImportError:
    model_types = (transformers.PreTrainedModel,)

__all__ = ['ComposerHFCausalLM']

log = logging.getLogger(__name__)


class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
    """Configures a :class:`.HuggingFaceModel` around a Causal LM.

    Args:
        om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel):
            Either an omegaconf dictionary used to configure the model, or an
            instantiated model object from the peft or transformers library.
            If a DictConfig, the following keys are used:

            cfg.pretrained_model_name_or_path (str): The name of or local path to
                the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
            cfg.config_overrides (dict, optional): An optional dictionary of keyword
                arguments that override the default configuration associated with
                cfg.pretrained_model_name_or_path.
            cfg.pretrained (bool): Whether to instantiate the model with pre-trained
                weights coming from cfg.pretrained_model_name_or_path. If ``True``,
                cfg.config_overrides must be compatible with the pre-trained weights.
            cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
                initialize the model on. Currently, `meta` is only supported when
                cfg.pretrained is ``False``. Default: ``'cpu'``.
        tokenizer (PreTrainedTokenizerBase): The tokenizer that the model will use.
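
    Example:
        A minimal construction sketch. Here ``gpt2`` is only an illustrative
        checkpoint, and ``tokenizer`` is assumed to be a previously built
        :class:`~transformers.PreTrainedTokenizerBase`::

            from omegaconf import OmegaConf

            cfg = OmegaConf.create({
                'pretrained_model_name_or_path': 'gpt2',
                'pretrained': True,
                'init_device': 'cpu',
            })
            model = ComposerHFCausalLM(cfg, tokenizer)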
    """

    def __init__(self, om_model_config: Union[DictConfig,
                                              transformers.PreTrainedModel,
                                              nn.Module],
                 tokenizer: PreTrainedTokenizerBase):
        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
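
        # The in-context-learning metrics are only computed during evaluation,
        # so they appear in eval_metrics but not in train_metrics.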
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
            InContextLearningCodeEvalAccuracy(),
            InContextLearningLMExpectedCalibrationError(),
            InContextLearningMCExpectedCalibrationError(),
        ]

        if isinstance(om_model_config, DictConfig):
            if not om_model_config.get(
                    'trust_remote_code', True
            ) and om_model_config.pretrained_model_name_or_path.startswith(
                    'mosaicml/mpt'):
                raise ValueError(
                    'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers '
                    +
                    'library rather than the code released by MosaicML, and will be significantly slower and not compatible with the LLM foundry training code.'
                )

            if not om_model_config.get('use_train_metrics', True):
                train_metrics = []

            trust_remote_code = om_model_config.get('trust_remote_code', True)
            use_auth_token = om_model_config.get('use_auth_token', False)
            config = AutoConfig.from_pretrained(
                om_model_config.pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                use_auth_token=use_auth_token,
            )
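
            # Apply any user-supplied config_overrides on top of the defaults
            # loaded from the checkpoint, validating nested dict overrides
            # key-by-key so that unknown keys fail loudly.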
            for k, v in om_model_config.get('config_overrides', {}).items():
                if not hasattr(config, k):
                    raise ValueError(
                        f'config does not have attribute "{k}" to override ({k}: {v}).'
                    )

                attr = getattr(config, k)

                if isinstance(attr, Mapping):
                    extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
                    if extra_keys:
                        raise ValueError(
                            f'Config dict override got unknown keys. ' +
                            f'Extra keys: {extra_keys}. ' +
                            f'Expected (a subset of) keys: {list(attr.keys())}.'
                        )
                    getattr(config, k).update(v)
                elif attr is None and isinstance(v, Mapping):
                    setattr(config, k, {})
                    getattr(config, k).update(v)
                else:
                    setattr(config, k, v)

            load_in_8bit = om_model_config.get('load_in_8bit', False)

            init_device = om_model_config.get('init_device', 'cpu')
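
            # Resolve the logical init_device ('cpu', 'meta', or 'mixed') to
            # the concrete device Hugging Face should initialize weights on.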
            resolved_init_device = hf_get_init_device(init_device)
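
            # With init_device='mixed', only local rank zero loads the
            # pretrained weights; the remaining ranks construct the model
            # without them and receive the weights when they are distributed.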
            if dist.get_local_rank() != 0 and init_device == 'mixed':
                om_model_config.pretrained = False

            if resolved_init_device == 'cpu':
                if om_model_config.pretrained:
                    model = AutoModelForCausalLM.from_pretrained(
                        om_model_config.pretrained_model_name_or_path,
                        trust_remote_code=trust_remote_code,
                        use_auth_token=use_auth_token,
                        load_in_8bit=load_in_8bit,
                        config=config)
                else:
                    model = AutoModelForCausalLM.from_config(
                        config,
                        trust_remote_code=trust_remote_code,
                    )
            elif resolved_init_device == 'meta':
                if om_model_config.pretrained:
                    raise ValueError(
                        'Setting cfg.pretrained=True is not supported when init_device="meta".'
                    )
                with init_empty_weights(include_buffers=False):
                    model = AutoModelForCausalLM.from_config(
                        config,
                        trust_remote_code=trust_remote_code,
                    )
            else:
                raise ValueError(
                    f'init_device="{init_device}" must be either "cpu" or "meta".'
                )
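
            # Local rank zero writes a signal file once its checkpoint
            # download has finished; the other local ranks wait on that file
            # before entering the barrier (so large downloads do not hit the
            # barrier timeout), and the barrier then syncs all nodes. Rank
            # zero removes the file afterwards.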
            signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
            if dist.get_local_rank() == 0:
                with open(signal_file_path, 'wb') as f:
                    f.write(b'local_rank0_completed_download')

            with dist.local_rank_zero_download_and_wait(signal_file_path):
                dist.barrier()

            if dist.get_local_rank() == 0:
                os.remove(signal_file_path)

            z_loss = om_model_config.get('z_loss', 0.0)

            attention_patch_type = om_model_config.get(
                'attention_patch_type', None)
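
            # Optionally monkeypatch LlamaAttention.forward (llama models
            # only). use_cache is disabled below because the patched attention
            # functions are not expected to support the KV cache.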
            if attention_patch_type is not None:
                if model.config.model_type != 'llama':
                    raise ValueError(
                        f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
                    )

                log.debug(
                    f'Patching llama attention with {attention_patch_type} attention'
                )
                from transformers.models.llama.modeling_llama import \
                    LlamaAttention
                LlamaAttention.forward = get_llama_attention_patch_fn(
                    attention_patch_type)
                model.config.use_cache = False

        elif isinstance(om_model_config, model_types):
            model = om_model_config
            init_device = 'cpu'
            z_loss = 0.0

        else:
            raise ValueError(
                f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
            )

        super().__init__(model=model,
                         shift_labels=True,
                         tokenizer=tokenizer,
                         metrics=train_metrics,
                         eval_metrics=eval_metrics,
                         z_loss=z_loss,
                         init_device=init_device)