# Provenance (copied from the hosting web page): uploaded by sgoel30 in
# commit d8ed92a ("Upload 34 files", verified); original file size 4.25 kB.
import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from umap import UMAP
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer
# Directory holding the generation-benchmark outputs (the ProtGPT2 and MeMDLM CSVs).
path = "/workspace/sg666/MDpLM/benchmarks/Generation"
# Hugging Face hub id of the ESM-2 650M checkpoint used to embed sequences.
esm_model_path = "facebook/esm2_t33_650M_UR50D"
# Prefer a CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_esm2_model(model_name):
    """Load the tokenizer and encoder used to embed sequences.

    Args:
        model_name: Hugging Face hub id or local path of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has already been moved to
        the module-level ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)
    return tokenizer, encoder.to(device)
def get_latents(model, tokenizer, sequence):
    """Return a single fixed-length embedding for one sequence.

    The sequence is tokenized, run through ``model`` without gradient
    tracking, and the final hidden states are averaged over dim 1 (the token
    axis, assuming the usual (batch, tokens, hidden) layout). Any special
    tokens the tokenizer adds are included in that average.

    Args:
        model: Encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``.
        sequence: A single sequence string.

    Returns:
        The pooled embedding as a plain Python list of floats.
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    pooled = hidden.mean(dim=1).squeeze(0)
    return pooled.cpu().numpy().tolist()
# Load a random set of 100 human and reviewed sequences from uniprot
def parse_fasta_file(file_path):
with open(file_path, 'r') as file:
lines = file.readlines()
sequences = []
current_seq = []
current_type = "UniProt"
for line in lines:
line = line.strip()
if line.startswith('>'):
if current_seq:
sequences.append(("".join(current_seq), current_type))
current_seq = []
else:
current_seq.append(line)
if current_seq:
sequences.append(("".join(current_seq), current_type))
return pd.DataFrame(sequences, columns=["Sequence", "Sequence Source"]).sample(100).reset_index(drop=True)
# Obtain/clean sequences generated from ProtGPT2 fine-tuned on membrane sequences
protgpt2_sequences = pd.read_csv(path + "/ProtGPT2/protgpt2_generated_sequences.csv")
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('<|ENDOFTEXT|>', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('""', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('\n', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('X', 'G', regex=False)
protgpt2_sequences.drop(columns=['Perplexity'], inplace=True)
protgpt2_sequences['Sequence Source'] = "ProtGPT2"
bad_sequences = []
for seq in protgpt2_sequences['Sequence']:
for residue in seq:
if residue in ['B', 'U', 'Z', 'O']:
bad_sequences.append(seq)
protgpt2_sequences = protgpt2_sequences[~protgpt2_sequences['Sequence'].isin(bad_sequences)]
# Load the MDpLM (MeMDLM) de novo generations:
# - rename the column to the shared "Sequence" name,
# - drop the perplexity score (not used in this script),
# - tag the source,
# - give the frame a fresh 0-based index.
memdlm_sequences = (
    pd.read_csv(path + "/mdlm_de-novo_generation_results.csv")
    .rename(columns={"Generated Sequence": "Sequence"})
    .drop(columns=['Perplexity'])
    .assign(**{'Sequence Source': "MeMDLM"})
    .reset_index(drop=True)
)
# Load UniProt sequences (alternative reference set; uncomment these two lines
# and remove the test-set loading below to compare against UniProt instead)
# fasta_file_path = path + "/uniprot_human_and_reviewed.fasta"
# other_sequences = parse_fasta_file(fasta_file_path)
# Load test set sequences
other_sequences = pd.read_csv("/workspace/sg666/MDpLM/data/membrane/test.csv")
other_sequences['Sequence Source'] = "Test Set"
# Draw up to 100 reference sequences. DataFrame.sample (without replacement)
# raises ValueError when asked for more rows than exist, so cap at the table size.
other_sequences = other_sequences.sample(min(100, len(other_sequences)))
# Combine all sequences. The inputs' row labels overlap: the MeMDLM frame is
# 0-based, the ProtGPT2 frame comes from read_csv, and the test-set sample
# keeps its original shuffled labels. ignore_index=True gives the combined
# frame a fresh 0..N-1 index instead of duplicate labels.
data = pd.concat(
    [memdlm_sequences, protgpt2_sequences, other_sequences],
    ignore_index=True,
)
# Load the ESM-2 encoder and its tokenizer. load_esm2_model already places the
# model on `device`; the explicit move is kept as a harmless safeguard.
tokenizer, model = load_esm2_model(esm_model_path)
model = model.to(device)

# Embed every sequence (one mean-pooled vector per row), then renumber rows 0..N-1.
data['Embeddings'] = data['Sequence'].map(
    lambda seq: get_latents(model, tokenizer, seq)
)
data = data.reset_index(drop=True)

# Feature matrix for UMAP: one row per sequence, one column per embedding
# dimension. Rows are labelled by "Sequence Source" (the index name), which the
# plot uses for colouring.
umap_df = pd.DataFrame(data['Embeddings'].tolist(), index=data['Sequence Source'])
# Project the embeddings to 2-D with UMAP (not PCA, despite the PCA import).
# NOTE(review): UMAP is stochastic; pass random_state=... for a reproducible layout.
umap = UMAP(n_components=2)
umap_features = umap.fit_transform(umap_df)
umap_df['UMAP1'] = umap_features[:, 0]
umap_df['UMAP2'] = umap_features[:, 1]

# Visualize the UMAP projection, coloured by sequence source.
# 'Sequence Source' is the *index name* of umap_df, not a column; seaborn also
# resolves variable keys against index level names.
# The palette gives one colour per source, in order of first appearance.
plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(
    x='UMAP1',
    y='UMAP2',
    hue='Sequence Source',
    data=umap_df,
    palette=['#297272', '#ff7477', "#9A77D0"],
    s=100,
)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title('ESM-650M Embeddings of Membrane Protein Sequences')
plt.savefig('esm_umap.png')
plt.show()