import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM


def load_esm2_model(esm_model_path):
    """Load a pretrained ESM-2 masked-LM and its tokenizer from a local path or hub ID."""
    tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
    return tokenizer, model


def load_mlm_model(esm_model_path, ckpt_path):
    """Load the ESM-2 tokenizer, but restore masked-LM weights from a fine-tuned checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    model = AutoModelForMaskedLM.from_pretrained(ckpt_path)
    return tokenizer, model
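

# Usage sketch (illustrative, not part of the original file): shows how the loaders
# above might be called for masked-token prediction. The hub ID "facebook/esm2_t12_35M_UR50D"
# and the example sequence are assumptions for demonstration; load_mlm_model is used the
# same way, with a fine-tuned checkpoint directory as its second argument.
if __name__ == "__main__":
    tokenizer, model = load_esm2_model("facebook/esm2_t12_35M_UR50D")
    model.eval()

    # Score a masked position in a short protein sequence.
    sequence = "MKTAYIAK<mask>RQISFVKSHFSRQLEERLGLIEVQ"
    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Report the most likely amino acid at the masked position.
    mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    predicted_id = logits[0, mask_index].argmax(dim=-1)
    print(tokenizer.decode(predicted_id))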