import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM

def load_esm2_model(model_name):
    """Load an ESM-2 checkpoint from the Hugging Face Hub.

    Returns the tokenizer plus two views of the same weights: a masked-LM
    model (per-position amino-acid logits) and the bare encoder (embeddings).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
    embedding_model = AutoModel.from_pretrained(model_name)
    return tokenizer, masked_model, embedding_model

def get_latents(model, tokenizer, sequence, device):
    """Return per-token latents of shape (num_tokens, hidden_size).

    The caller is expected to have moved ``model`` to ``device``; only the
    tokenized inputs are moved here. ``num_tokens`` includes the special
    tokens the ESM tokenizer adds around the sequence.
    """
    inputs = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs).last_hidden_state.squeeze(0)
    return outputs
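
# Usage sketch, not part of the original module. Assumptions: the checkpoint
# name below is one of the public ESM-2 models on the Hugging Face Hub, and
# the protein sequence is an arbitrary illustrative example.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, masked_model, embedding_model = load_esm2_model(
        "facebook/esm2_t33_650M_UR50D"
    )
    embedding_model = embedding_model.to(device).eval()
    latents = get_latents(embedding_model, tokenizer, "MKTAYIAKQRQISFVKSHFSRQ", device)
    print(latents.shape)  # (sequence length + special tokens, hidden_size)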