# Copyright (c) OpenMMLab. All rights reserved.
import contextlib
from typing import Optional

import transformers
from mmengine.registry import Registry
from transformers import AutoConfig, PreTrainedModel
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from mmpretrain.registry import MODELS, TOKENIZER


def register_hf_tokenizer(
    cls: Optional[type] = None,
    registry: Registry = TOKENIZER,
):
    """Register HuggingFace-style PreTrainedTokenizerBase class.

    A builder function is registered under ``cls.__name__``. The builder
    accepts keyword arguments only, requires either
    ``pretrained_model_name_or_path`` or its shorter alias ``name_or_path``,
    and forwards everything to ``cls.from_pretrained``.

    Args:
        cls (type, optional): The class to register. If None, a decorator
            is returned so this can be used as ``@register_hf_tokenizer()``.
        registry (Registry): The registry the builder is added to.
            Defaults to ``TOKENIZER``.

    Returns:
        The class itself, or a decorator when ``cls`` is None.
    """
    if cls is None:

        # use it as a decorator: @register_hf_tokenizer()
        def _register(cls):
            # Forward `registry` so a custom registry given to the
            # decorator form is not silently replaced by the default one.
            register_hf_tokenizer(cls=cls, registry=registry)
            return cls

        return _register

    def from_pretrained(**kwargs):
        if ('pretrained_model_name_or_path' not in kwargs
                and 'name_or_path' not in kwargs):
            raise TypeError(
                f'{cls.__name__}.from_pretrained() missing required '
                "argument 'pretrained_model_name_or_path' or 'name_or_path'.")
        # `pretrained_model_name_or_path` is too long for config,
        # add an alias name `name_or_path` here.
        # Pop both keys with an explicit default: a nested
        # `kwargs.pop('name_or_path')` used as the default argument is
        # evaluated eagerly and raises KeyError when only the long name
        # is given. The long name wins when both are present.
        name_or_path = kwargs.pop('name_or_path', None)
        name_or_path = kwargs.pop('pretrained_model_name_or_path',
                                  name_or_path)
        return cls.from_pretrained(name_or_path, **kwargs)

    registry._register_module(module=from_pretrained, module_name=cls.__name__)
    return cls


# Module-level switch read by the builders created in `register_hf_model`.
_load_hf_pretrained_model = True


@contextlib.contextmanager
def no_load_hf_pretrained_model():
    """Temporarily disable loading of HuggingFace pretrained weights.

    While the context is active, builders registered by
    ``register_hf_model`` construct the model from its config instead of
    calling ``from_pretrained`` on the model class.
    """
    global _load_hf_pretrained_model
    # Remember the previous value so nested contexts restore correctly.
    previous = _load_hf_pretrained_model
    _load_hf_pretrained_model = False
    try:
        yield
    finally:
        # Restore even if the body raised, otherwise the flag would stay
        # False for the rest of the process.
        _load_hf_pretrained_model = previous


def register_hf_model(
    cls: Optional[type] = None,
    registry: Registry = MODELS,
):
    """Register HuggingFace-style PreTrainedModel class.

    A builder function is registered under ``cls.__name__``. The builder
    accepts keyword arguments only and requires either
    ``pretrained_model_name_or_path`` or its shorter alias ``name_or_path``.
    An optional ``load_pretrained`` keyword (default True) selects between
    ``cls.from_pretrained`` and building the model from its config only.

    Args:
        cls (type, optional): The class to register. If None, a decorator
            is returned so this can be used as ``@register_hf_model()``.
        registry (Registry): The registry the builder is added to.
            Defaults to ``MODELS``.

    Returns:
        The class itself, or a decorator when ``cls`` is None.

    Raises:
        TypeError: If ``cls`` is neither an auto model class nor a
            ``PreTrainedModel`` subclass.
    """
    if cls is None:

        # use it as a decorator: @register_hf_model()
        def _register(cls):
            # Forward `registry` so a custom registry given to the
            # decorator form is not silently replaced by the default one.
            register_hf_model(cls=cls, registry=registry)
            return cls

        return _register

    # Pick how to obtain a config and how to build a model from it.
    if issubclass(cls, _BaseAutoModelClass):
        get_config = AutoConfig.from_pretrained
        from_config = cls.from_config
    elif issubclass(cls, PreTrainedModel):
        get_config = cls.config_class.from_pretrained
        from_config = cls
    else:
        raise TypeError('Not auto model nor pretrained model of huggingface.')

    def build(**kwargs):
        if ('pretrained_model_name_or_path' not in kwargs
                and 'name_or_path' not in kwargs):
            raise TypeError(
                f'{cls.__name__} missing required argument '
                '`pretrained_model_name_or_path` or `name_or_path`.')
        # `pretrained_model_name_or_path` is too long for config,
        # add an alias name `name_or_path` here.
        # Pop both keys with an explicit default: a nested
        # `kwargs.pop('name_or_path')` used as the default argument is
        # evaluated eagerly and raises KeyError when only the long name
        # is given. The long name wins when both are present.
        name_or_path = kwargs.pop('name_or_path', None)
        name_or_path = kwargs.pop('pretrained_model_name_or_path',
                                  name_or_path)

        # `load_pretrained` is always popped so it is never forwarded.
        if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
            return cls.from_pretrained(name_or_path, **kwargs)
        else:
            cfg = get_config(name_or_path, **kwargs)
            return from_config(cfg)

    registry._register_module(module=build, module_name=cls.__name__)
    return cls


register_hf_model(transformers.AutoModelForCausalLM)