"""Register the BTLM model family with the Hugging Face ``transformers`` Auto* factories.

Importing this module has side effects. It maps:

* the model type string ``"btlm"`` to ``BTLMConfig`` (``AutoConfig``);
* ``BTLMConfig`` to ``BTLMModel`` (``AutoModel``);
* ``BTLMConfig`` to ``BTLMLMHeadModel`` (``AutoModelForCausalLM``);
* ``BTLMConfig`` to the generic ``PreTrainedTokenizerFast`` (``AutoTokenizer``).

After the import, the ``Auto*.from_pretrained`` entry points can resolve BTLM
checkpoints.
"""

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from configuration_btlm import BTLMConfig
from modeling_btlm import BTLMModel, BTLMLMHeadModel

# Register the config first; the model and tokenizer registrations below are
# keyed on BTLMConfig.
# NOTE(review): transformers validates that BTLMConfig.model_type matches
# "btlm" and rejects duplicate registrations. Confirm that configuration_btlm
# sets model_type = "btlm".
AutoConfig.register("btlm", BTLMConfig)

# Base transformer (no head) and causal-LM head variants.
AutoModel.register(BTLMConfig, BTLMModel)
AutoModelForCausalLM.register(BTLMConfig, BTLMLMHeadModel)

# BTLM has no dedicated tokenizer class. Use the generic fast tokenizer as the
# fast implementation.
AutoTokenizer.register(BTLMConfig, fast_tokenizer_class=PreTrainedTokenizerFast)