SlimPajama-DC / register_btlm.py
Jason0214's picture
Upload folder using huggingface_hub
b02c73e verified
raw
history blame
475 Bytes
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from configuration_btlm import BTLMConfig
from modeling_btlm import BTLMModel, BTLMLMHeadModel
# Importing this module registers the BTLM classes with the transformers
# Auto* factories. The config is registered first, under the model-type
# string "btlm"; every later registration is keyed on BTLMConfig.
AutoConfig.register("btlm", BTLMConfig)

# Pair each auto-model factory with the BTLM class registered for BTLMConfig.
for _auto_factory, _btlm_class in (
    (AutoModel, BTLMModel),
    (AutoModelForCausalLM, BTLMLMHeadModel),
):
    _auto_factory.register(BTLMConfig, _btlm_class)
del _auto_factory, _btlm_class  # do not leave loop temporaries in the module namespace

# Only a fast tokenizer class is registered; no slow tokenizer class is given.
AutoTokenizer.register(BTLMConfig, fast_tokenizer_class=PreTrainedTokenizerFast)