# Source: MeMDLM/benchmarks/MLM/pretrained_models.py
# Uploaded by sgoel30 in verified commit d8ed92a ("Upload 34 files"); 491 bytes.
import torch
from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, AutoModelForMaskedLM
def load_esm2_model(esm_model_path):
    """Load a tokenizer and a masked-language-model from the same source.

    Both objects are built with ``from_pretrained`` on ``esm_model_path``, so
    the tokenizer vocabulary and the model weights come from one checkpoint.
    The name suggests ESM-2 checkpoints. However, any checkpoint that
    ``AutoTokenizer`` / ``AutoModelForMaskedLM`` can resolve will load.

    Args:
        esm_model_path: Local directory or Hugging Face Hub model identifier
            accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. ``model`` is whichever masked-LM class
        ``AutoModelForMaskedLM`` resolves for the checkpoint.
    """
    # Tuple elements evaluate left to right: tokenizer first, then the model,
    # matching the original load order.
    return (
        AutoTokenizer.from_pretrained(esm_model_path),
        AutoModelForMaskedLM.from_pretrained(esm_model_path),
    )
def load_mlm_model(esm_model_path, ckpt_path):
    """Pair a base-model tokenizer with masked-LM weights from a checkpoint.

    Unlike ``load_esm2_model``, the tokenizer and the model are read from
    different locations. The tokenizer comes from ``esm_model_path`` and the
    model weights from ``ckpt_path``.

    Args:
        esm_model_path: Local directory or Hub identifier used only for the
            tokenizer.
        ckpt_path: Local directory or Hub identifier holding the masked-LM
            weights (presumably a trained/fine-tuned checkpoint — confirm).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): assumes the checkpoint at ckpt_path uses the same
    # vocabulary as esm_model_path. Nothing here checks that they match.
    base_tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    checkpoint_mlm = AutoModelForMaskedLM.from_pretrained(ckpt_path)
    loaded = (base_tokenizer, checkpoint_mlm)
    return loaded