File size: 491 Bytes
d8ed92a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import torch
from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, AutoModelForMaskedLM

def load_esm2_model(esm_model_path):
    """Load a tokenizer and a masked-language-model head from one location.

    Both objects are resolved from the same ``esm_model_path`` (a hub model
    id or a local directory accepted by ``from_pretrained``).

    Args:
        esm_model_path: Identifier or directory passed unchanged to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForMaskedLM.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple, tokenizer first.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is
    # still loaded before the model.
    return (
        AutoTokenizer.from_pretrained(esm_model_path),
        AutoModelForMaskedLM.from_pretrained(esm_model_path),
    )

def load_mlm_model(esm_model_path, ckpt_path):
    """Load a tokenizer from a base model and masked-LM weights from a checkpoint.

    The tokenizer and the model come from two different locations. The
    tokenizer is read from ``esm_model_path``, while the model (config and
    weights) is read from ``ckpt_path``. Presumably ``ckpt_path`` is a
    fine-tuned checkpoint saved without tokenizer files; verify against
    callers.

    Args:
        esm_model_path: Hub id or directory used only for the tokenizer.
        ckpt_path: Hub id or directory used for ``AutoModelForMaskedLM``.

    Returns:
        A ``(tokenizer, model)`` tuple, tokenizer first.
    """
    tok_source, weights_source = esm_model_path, ckpt_path
    tokenizer_obj = AutoTokenizer.from_pretrained(tok_source)
    mlm = AutoModelForMaskedLM.from_pretrained(weights_source)
    return tokenizer_obj, mlm