import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from umap import UMAP
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer

path = "/workspace/sg666/MDpLM/benchmarks/Generation"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
esm_model_path = "facebook/esm2_t33_650M_UR50D"


def load_esm2_model(model_name):
    """Load an ESM-2 tokenizer and model, moving the model to the target device."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()  # inference only: disable dropout
    return tokenizer, model


def get_latents(model, tokenizer, sequence):
    """Return a mean-pooled ESM-2 embedding for a single sequence as a list of floats."""
    inputs = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Average over the token dimension (note: this includes the BOS/EOS special tokens)
    embeddings = outputs.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy().tolist()
    return embeddings
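

# The pooling above averages over every position, including the BOS/EOS special
# tokens that the ESM-2 tokenizer adds. A minimal sketch of a masked-pooling
# variant that excludes them (a hypothetical helper, not used in the pipeline
# below):
def get_latents_masked(model, tokenizer, sequence):
    """Mean-pool only over real residue positions (assumes a single, unpadded sequence)."""
    inputs = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # get_special_tokens_mask returns 1 for special tokens, 0 for residues
    special = tokenizer.get_special_tokens_mask(
        inputs["input_ids"][0].tolist(), already_has_special_tokens=True
    )
    keep = torch.tensor(special, device=device).eq(0)
    hidden = outputs.last_hidden_state[0]  # (seq_len, hidden_dim)
    return hidden[keep].mean(dim=0).cpu().numpy().tolist()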


def parse_fasta_file(file_path):
    """Parse a FASTA file and return 100 randomly sampled (sequence, source) rows.

    Note: this helper is not called anywhere in the pipeline below, and the
    source label is fixed to "UniProt" for every record.
    """
    with open(file_path, 'r') as file:
        lines = file.readlines()

    sequences = []
    current_seq = []
    current_type = "UniProt"

    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            # Header line: flush the previous record, if any
            if current_seq:
                sequences.append(("".join(current_seq), current_type))
                current_seq = []
        else:
            current_seq.append(line)
    if current_seq:
        sequences.append(("".join(current_seq), current_type))

    return pd.DataFrame(sequences, columns=["Sequence", "Sequence Source"]).sample(100).reset_index(drop=True)
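
# Hypothetical usage (the FASTA path is a placeholder, not part of the original
# pipeline): natural sequences could be pulled in for comparison with
#   uniprot_sequences = parse_fasta_file("path/to/uniprot.fasta")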


# Load and clean ProtGPT2-generated sequences
protgpt2_sequences = pd.read_csv(path + "/ProtGPT2/protgpt2_generated_sequences.csv")
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('<|ENDOFTEXT|>', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('""', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('\n', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('X', 'G', regex=False)  # map unknown residues to glycine
protgpt2_sequences.drop(columns=['Perplexity'], inplace=True)
protgpt2_sequences['Sequence Source'] = "ProtGPT2"

# Drop sequences containing ambiguous or non-standard residues (B, U, Z, O)
protgpt2_sequences = protgpt2_sequences[~protgpt2_sequences['Sequence'].str.contains('[BUZO]', regex=True)]


# Load MeMDLM de novo generation results and align column names
memdlm_sequences = pd.read_csv(path + "/mdlm_de-novo_generation_results.csv")
memdlm_sequences.rename(columns={"Generated Sequence": "Sequence"}, inplace=True)
memdlm_sequences.drop(columns=['Perplexity'], inplace=True)
memdlm_sequences['Sequence Source'] = "MeMDLM"
memdlm_sequences.reset_index(drop=True, inplace=True)


# Load natural membrane-protein sequences from the held-out test set
other_sequences = pd.read_csv("/workspace/sg666/MDpLM/data/membrane/test.csv")
other_sequences['Sequence Source'] = "Test Set"
other_sequences = other_sequences.sample(100)

# Combine generated and natural sequences into a single DataFrame
data = pd.concat([memdlm_sequences, protgpt2_sequences, other_sequences])


# Load ESM-2 650M (already moved to the device inside load_esm2_model)
tokenizer, model = load_esm2_model(esm_model_path)

# Embed every sequence: one row per sequence, one column per embedding dimension,
# with the source label kept in the index for plotting
data['Embeddings'] = data['Sequence'].apply(lambda sequence: get_latents(model, tokenizer, sequence))
data = data.reset_index(drop=True)
umap_df = pd.DataFrame(data['Embeddings'].tolist())
umap_df.index = data['Sequence Source']


# Project the embeddings to 2-D with UMAP
umap = UMAP(n_components=2)
umap_features = umap.fit_transform(umap_df)
umap_df['UMAP1'] = umap_features[:, 0]
umap_df['UMAP2'] = umap_features[:, 1]
umap_df = umap_df.reset_index()  # expose 'Sequence Source' as a column for seaborn

plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(x='UMAP1', y='UMAP2', hue='Sequence Source', data=umap_df,
                palette=['#297272', '#ff7477', '#9A77D0'], s=100)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title('ESM-650M Embeddings of Membrane Protein Sequences')
plt.savefig('esm_umap.png')
plt.show()
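

# The TSNE and PCA imports at the top are unused in the pipeline above. A
# minimal sketch of an analogous t-SNE projection for comparison with the UMAP
# view (an added illustration, not part of the original analysis; random_state
# is an assumption for reproducibility):
embedding_cols = [c for c in umap_df.columns
                  if c not in ('Sequence Source', 'UMAP1', 'UMAP2')]
tsne = TSNE(n_components=2, random_state=0)
tsne_features = tsne.fit_transform(umap_df[embedding_cols].to_numpy())
umap_df['TSNE1'] = tsne_features[:, 0]
umap_df['TSNE2'] = tsne_features[:, 1]

plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(x='TSNE1', y='TSNE2', hue='Sequence Source', data=umap_df,
                palette=['#297272', '#ff7477', '#9A77D0'], s=100)
plt.title('t-SNE of ESM-650M Embeddings')
plt.savefig('esm_tsne.png')
plt.show()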