import torch | |
from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, AutoModelForMaskedLM | |
def load_esm2_model(esm_model_path):
    """Load a tokenizer and masked-LM model from the same checkpoint.

    Both objects are resolved from ``esm_model_path`` via the
    ``transformers`` Auto classes. Going by the function name, this is
    presumably an ESM-2 checkpoint; the code itself accepts any checkpoint
    that ``AutoModelForMaskedLM`` can resolve.

    Args:
        esm_model_path: Hub model id or local directory, in any form
            ``from_pretrained`` accepts.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer comes from
        ``AutoTokenizer`` and the model from ``AutoModelForMaskedLM``.
    """
    tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
    return tokenizer, model
def load_mlm_model(esm_model_path, ckpt_path):
    """Load a base tokenizer together with masked-LM weights from a checkpoint.

    The tokenizer and the model deliberately come from different locations.
    The tokenizer is taken from the base model at ``esm_model_path``. The
    model weights are taken from ``ckpt_path``, presumably a fine-tuned
    checkpoint that may not ship its own tokenizer files (TODO confirm).

    Args:
        esm_model_path: Hub model id or local directory of the base model.
            It is used only to load the tokenizer.
        ckpt_path: Hub model id or local directory that
            ``AutoModelForMaskedLM.from_pretrained`` can load. It must
            provide a model config plus weights.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    # The model architecture and weights are resolved from the checkpoint,
    # not from the base model path.
    model = AutoModelForMaskedLM.from_pretrained(ckpt_path)
    return tokenizer, model