Fill-Mask
Transformers
Safetensors
esm
from fuson_plm.training.model import FusOnpLM
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import logging
import torch
import os

os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model from the fine-tuned checkpoint
model_name = 'checkpoints/old_splits_snp_2000_ft_11layers_Q_b8_lr5e-05_mask0.15-08-12-2024-12:42:48/checkpoint_epoch_1.pth'
model = AutoModel.from_pretrained(model_name)           # load the pretrained model weights
tokenizer = AutoTokenizer.from_pretrained(model_name)   # load the matching tokenizer
model.eval()                                            # inference mode (disables dropout)
model.to(device)

# Example fusion oncoprotein sequence: MLLT10:PICALM, associated with Acute Myeloid Leukemia (LAML)  
# Amino acids 1-80 are derived from the head gene, MLLT10
# Amino acids 81-119 are derived from the tail gene, PICALM
sequence = "MVSSDRPVSLEDEVSHSMKEMIGGCCVCSDERGWAENPLVYCDGHGCSVAVHQACYGIVQVPTGPWFCRKCESQERAARVPPQMGSVPVMTQPTLIYSQPVMRPPNPFGPVSGAQIQFM"

# Tokenize the input sequence
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=2000)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get the embeddings
with torch.no_grad():
    outputs = model(**inputs)
    # The embeddings are in the last_hidden_state tensor
    embeddings = outputs.last_hidden_state
    # remove extra dimension
    embeddings = embeddings.squeeze(0)
    # remove BOS and EOS tokens
    embeddings = embeddings[1:-1, :]

# Convert embeddings to numpy array (if needed)
embeddings = embeddings.cpu().numpy()

print("Sequence length: ", len(sequence))
print("Per-residue embeddings shape:", embeddings.shape)