"""Helpers for loading a tokenizer together with a masked-language-model."""

import torch
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
    AutoTokenizer,
    EsmForMaskedLM,
)


def load_esm2_model(esm_model_path):
    """Load a tokenizer and a masked-LM model from the same location.

    Args:
        esm_model_path: Value handed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForMaskedLM.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order. The model is returned
        exactly as ``from_pretrained`` produced it; this function does not
        move it to a device or change its train/eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
    return tokenizer, model


def load_mlm_model(esm_model_path, ckpt_path):
    """Load a tokenizer from one location and a masked-LM model from another.

    The tokenizer comes from ``esm_model_path`` while the model is loaded
    from ``ckpt_path``.
    NOTE(review): presumably ``ckpt_path`` is a fine-tuned checkpoint that
    ships without tokenizer files -- confirm against callers.

    Args:
        esm_model_path: Value handed to ``AutoTokenizer.from_pretrained``.
        ckpt_path: Value handed to ``AutoModelForMaskedLM.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order. The model is returned
        exactly as ``from_pretrained`` produced it; this function does not
        move it to a device or change its train/eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    model = AutoModelForMaskedLM.from_pretrained(ckpt_path)
    return tokenizer, model