import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from umap import UMAP
from transformers import AutoModel, AutoTokenizer

path = "/workspace/sg666/MDpLM/benchmarks/Generation"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
esm_model_path = "facebook/esm2_t33_650M_UR50D"

# Load the ESM-2 model and tokenizer used to embed the sequences
def load_esm2_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    return tokenizer, model
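# Note: from_pretrained() already returns the model in eval mode, so no
# explicit model.eval() call is needed before inference.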

# Mean-pool the final hidden states into one embedding per sequence.
# Truncation to 1024 tokens is an added safeguard for sequences longer than
# ESM-2's training context; note the mean also includes the special tokens.
def get_latents(model, tokenizer, sequence):
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy().tolist()
    return embeddings
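# Illustrative usage (hypothetical sequence): ESM2-650M has a hidden size of
# 1280, so each sequence maps to a single 1280-dim vector:
#   vec = get_latents(model, tokenizer, "MKTAYIAKQRQISFVK")  # len(vec) == 1280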

# Load a random set of 100 human, reviewed sequences from a UniProt FASTA file
def parse_fasta_file(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()

    sequences = []
    current_seq = []
    current_type = "UniProt"
    
    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            if current_seq:
                sequences.append(("".join(current_seq), current_type))
                current_seq = []
        else:
            current_seq.append(line)
    if current_seq:
        sequences.append(("".join(current_seq), current_type))
    
    return pd.DataFrame(sequences, columns=["Sequence", "Sequence Source"]).sample(100).reset_index(drop=True)
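# The parser assumes standard FASTA layout (illustrative record):
#   >sp|P12345|EXAMPLE_HUMAN Example protein
#   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ...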


# Obtain and clean sequences generated by ProtGPT2 fine-tuned on membrane sequences
protgpt2_sequences = pd.read_csv(path + "/ProtGPT2/protgpt2_generated_sequences.csv")
# Strip generation artifacts: end-of-text markers, stray quotes, and newlines
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('<|ENDOFTEXT|>', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('""', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('\n', '', regex=False)
# Map unknown residues (X) to glycine so only standard amino acids remain
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('X', 'G', regex=False)
protgpt2_sequences.drop(columns=['Perplexity'], inplace=True)
protgpt2_sequences['Sequence Source'] = "ProtGPT2"
# Drop any sequence containing non-standard residues (B, U, Z, O)
protgpt2_sequences = protgpt2_sequences[~protgpt2_sequences['Sequence'].str.contains(r'[BUZO]')]
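# Note: the end-of-text marker is matched exactly as stored in the CSV; if your
# CSV keeps ProtGPT2's canonical GPT-2 token '<|endoftext|>' (lowercase),
# adjust the pattern above accordingly.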


# Load MeMDLM-generated sequences
memdlm_sequences = pd.read_csv(path + "/mdlm_de-novo_generation_results.csv")
memdlm_sequences.rename(columns={"Generated Sequence": "Sequence"}, inplace=True)
memdlm_sequences.drop(columns=['Perplexity'], inplace=True)
memdlm_sequences['Sequence Source'] = "MeMDLM"
memdlm_sequences.reset_index(drop=True, inplace=True)

# Optionally, load UniProt sequences instead of the test set:
# fasta_file_path = path + "/uniprot_human_and_reviewed.fasta"
# other_sequences = parse_fasta_file(fasta_file_path)

# Load test set sequences
other_sequences = pd.read_csv("/workspace/sg666/MDpLM/data/membrane/test.csv")
other_sequences['Sequence Source'] = "Test Set"
other_sequences = other_sequences.sample(100)
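# Note: test.csv is assumed to contain a 'Sequence' column; only 'Sequence'
# and the 'Sequence Source' label added above are used downstream.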

# Combine all sequences; reset the index so the three frames' indices don't collide
data = pd.concat([memdlm_sequences, protgpt2_sequences, other_sequences]).reset_index(drop=True)


# Load ESM model and tokenizer for embeddings (load_esm2_model already moves the model to `device`)
tokenizer, model = load_esm2_model(esm_model_path)


# Embed the sequences (one mean-pooled ESM-2 vector per sequence)
data['Embeddings'] = data['Sequence'].apply(lambda sequence: get_latents(model, tokenizer, sequence))
embedding_df = pd.DataFrame(data['Embeddings'].tolist())


# Project the embeddings to 2D with UMAP
umap = UMAP(n_components=2)
umap_features = umap.fit_transform(embedding_df)
umap_df = pd.DataFrame({'UMAP1': umap_features[:, 0], 'UMAP2': umap_features[:, 1]})
umap_df['Sequence Source'] = data['Sequence Source'].values
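# Note: UMAP is stochastic; pass a seed (e.g., UMAP(n_components=2, random_state=42))
# if the layout needs to be reproducible across runs.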

# Visualize the UMAP projection, colored by sequence source
plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(x='UMAP1', y='UMAP2', hue='Sequence Source', data=umap_df, palette=['#297272', '#ff7477', '#9A77D0'], s=100)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title('ESM-650M Embeddings of Membrane Protein Sequences')
plt.savefig('esm_umap.png')
plt.show()