|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from scipy.stats import entropy |
|
from sklearn.manifold import TSNE |
|
import pickle |
|
import pandas as pd |
|
import os |
|
from fuson_plm.utils.logging import log_update |
|
from fuson_plm.utils.visualizing import set_font |
|
|
|
def calculate_aa_composition(sequences):
    """Return the relative frequency of each amino acid across all sequences.

    Args:
        sequences (iterable of str): amino-acid sequences.

    Returns:
        dict: residue -> fraction of all residues (values sum to 1 for
        non-empty input; empty input yields an empty dict).
    """
    total_residues = sum(len(seq) for seq in sequences)

    # Count occurrences of every residue in a single pass.
    counts = {}
    for seq in sequences:
        for residue in seq:
            counts[residue] = counts.get(residue, 0) + 1

    # Normalize raw counts into relative frequencies.
    return {residue: count / total_residues for residue, count in counts.items()}
|
|
|
def calculate_shannon_entropy(sequence):
    """
    Calculate the Shannon entropy for a given sequence.

    Args:
        sequence (str): A sequence of characters (e.g., amino acids or nucleotides).

    Returns:
        float: Shannon entropy value (in bits).
    """
    # Build character frequencies in one pass; scipy's entropy() normalizes
    # raw counts into probabilities itself.
    frequencies = {}
    for char in sequence:
        frequencies[char] = frequencies.get(char, 0) + 1
    return entropy(list(frequencies.values()), base=2)
|
|
|
def visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap, savepath='../data/splits/length_distributions.png', axes=None):
    """Plot side-by-side sequence-length histograms for the three splits.

    Args:
        train_lengths, val_lengths, test_lengths (list of int): per-sequence lengths.
        colormap (dict): maps 'train'/'val'/'test' to bar colors.
        savepath (str or None): figure output path; skipped when None.
        axes: optional row of three matplotlib Axes to draw into.
    """
    log_update('\nMaking histogram of length distributions')

    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    xlabel, ylabel = 'Sequence Length (AA)', 'Frequency'

    # One identically styled panel per split.
    panels = [
        ('train', train_lengths, 'Train Set'),
        ('val', val_lengths, 'Validation Set'),
        ('test', test_lengths, 'Test Set'),
    ]
    for panel_ax, (split, lengths, title) in zip(axes, panels):
        panel_ax.hist(lengths, bins=20, edgecolor='k', color=colormap[split])
        panel_ax.set_xlabel(xlabel, fontsize=24)
        panel_ax.set_ylabel(ylabel, fontsize=24)
        panel_ax.set_title(f'{title} Length Distribution (n={len(lengths)})', fontsize=24)
        panel_ax.grid(True)
        panel_ax.set_axisbelow(True)
        panel_ax.tick_params(axis='x', labelsize=24)
        panel_ax.tick_params(axis='y', labelsize=24)

    if savepath is not None:
        plt.tight_layout()
        plt.savefig(savepath)
|
|
|
def visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap, savepath='../data/splits/scatterplot.png', ax=None):
    # Scatterplot: for each split, the x-axis is cluster size (members per
    # representative) and the y-axis is the fraction of that split's proteins
    # living in clusters of that size. Test clusters containing benchmark
    # sequences are plotted as a separate, black-outlined series.
    log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")

    # Cluster size per representative: count member rows per representative id.
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})

    # Total test proteins is captured BEFORE the benchmark clusters are split
    # out, so both test and benchmark percentages share this denominator.
    total_test_proteins = sum(test_clustersgb['member count'])
    test_clustersgb['benchmark cluster'] = test_clustersgb['representative seq_id'].isin(benchmark_cluster_reps)
    benchmark_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']].reset_index(drop=True)
    test_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']==False].reset_index(drop=True)

    # Collapse to (cluster size -> number of clusters of that size).
    # NOTE(review): renaming 'index' assumes the pre-pandas-2.0 behavior of
    # value_counts().reset_index(); on pandas >= 2.0 the columns come back as
    # ('member count', 'count') instead — verify against the pinned version.
    train_clustersgb = train_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
    val_clustersgb = val_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
    test_clustersgb = test_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
    benchmark_clustersgb = benchmark_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})

    # proteins at a given cluster size = size * number of such clusters;
    # convert to a fraction of the split (test/benchmark share the full test
    # denominator computed above).
    train_clustersgb['n_proteins'] = train_clustersgb['cluster size (n_members)']*train_clustersgb['n_clusters']
    train_clustersgb['percent_proteins'] = train_clustersgb['n_proteins']/sum(train_clustersgb['n_proteins'])
    val_clustersgb['n_proteins'] = val_clustersgb['cluster size (n_members)']*val_clustersgb['n_clusters']
    val_clustersgb['percent_proteins'] = val_clustersgb['n_proteins']/sum(val_clustersgb['n_proteins'])
    test_clustersgb['n_proteins'] = test_clustersgb['cluster size (n_members)']*test_clustersgb['n_clusters']
    test_clustersgb['percent_proteins'] = test_clustersgb['n_proteins']/total_test_proteins
    benchmark_clustersgb['n_proteins'] = benchmark_clustersgb['cluster size (n_members)']*benchmark_clustersgb['n_clusters']
    benchmark_clustersgb['percent_proteins'] = benchmark_clustersgb['n_proteins']/total_test_proteins

    if ax is None:
        fig, ax = plt.subplots(figsize=(18, 6))

    # One dot series per split; benchmark clusters get a black edge so they
    # stand out against the regular test series.
    ax.plot(train_clustersgb['cluster size (n_members)'],train_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['train'],label='train')
    ax.plot(val_clustersgb['cluster size (n_members)'],val_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['val'],label='val')
    ax.plot(test_clustersgb['cluster size (n_members)'],test_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['test'],label='test')
    ax.plot(benchmark_clustersgb['cluster size (n_members)'],benchmark_clustersgb['percent_proteins'],
            marker='o',
            linestyle='None',
            markerfacecolor=colormap['test'],
            markeredgecolor='black',
            markeredgewidth=1.5,
            label='benchmark'
            )
    ax.set_ylabel('Percentage of Proteins in Dataset', fontsize=24)
    ax.set_xlabel('Cluster Size', fontsize=24)
    ax.tick_params(axis='x', labelsize=24)
    ax.tick_params(axis='y', labelsize=24)

    ax.legend(fontsize=24,markerscale=4)

    if savepath is not None:
        plt.tight_layout()
        plt.savefig(savepath)
        log_update(f"\tSaved figure to {savepath}")
|
|
|
def get_avg_embeddings_for_tsne(train_sequences, val_sequences, test_sequences, embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
    """Load precomputed average embeddings and partition them by split.

    Args:
        train_sequences, val_sequences, test_sequences: collections of
            sequence strings; membership decides which split each embedding
            belongs to.
        embedding_path (str): path to a pickled dict mapping sequence -> embedding.

    Returns:
        tuple: (train_embeddings, val_embeddings, test_embeddings) lists, or
        None when the embeddings file cannot be read (preserving the original
        best-effort contract).
    """
    try:
        # Keep the try body minimal: only the file read/deserialize can fail
        # in the ways we want to tolerate. The original bare `except:` also
        # swallowed programming errors; catch only load failures instead.
        with open(embedding_path, 'rb') as f:
            embeddings = pickle.load(f)
    except (OSError, pickle.UnpicklingError, EOFError):
        print("could not open embeddings")
        return None

    # Sets give O(1) membership tests instead of scanning lists per key.
    train_set = set(train_sequences)
    val_set = set(val_sequences)
    test_set = set(test_sequences)

    train_embeddings = [v for k, v in embeddings.items() if k in train_set]
    val_embeddings = [v for k, v in embeddings.items() if k in val_set]
    test_embeddings = [v for k, v in embeddings.items() if k in test_set]

    return train_embeddings, val_embeddings, test_embeddings
|
|
|
|
|
def visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap, esm_type="esm2_t33_650M_UR50D", embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl", savepath='../data/splits/tsne_plot.png',ax=None):
    """
    Generate a t-SNE plot of embeddings for train, test, and validation.

    Args:
        train_sequences, val_sequences, test_sequences: sequences per split,
            used to look up their precomputed embeddings.
        colormap (dict): maps 'train'/'val'/'test' to point colors.
        esm_type (str): model name, used only in the plot title.
        embedding_path (str): path to the pickled sequence -> embedding dict.
        savepath (str or None): figure output path; skipped when falsy.
        ax: optional matplotlib Axes to draw into.
    """
    log_update('\nMaking t-SNE plot of train, val, and test embeddings')

    train_embeddings, val_embeddings, test_embeddings = get_avg_embeddings_for_tsne(train_sequences, val_sequences, test_sequences, embedding_path=embedding_path)
    embeddings = np.concatenate([train_embeddings, val_embeddings, test_embeddings])

    # Labels align with the concatenation order above.
    labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings) + ['test'] * len(test_embeddings)

    # Fixed random_state so layouts are reproducible across runs.
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
    tsne_df['label'] = labels

    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 8))

    for label, color in colormap.items():
        subset = tsne_df[tsne_df['label'] == label].reset_index(drop=True)
        ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)

    ax.set_title(f't-SNE of {esm_type} Embeddings')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.legend(fontsize=24, markerscale=2)
    ax.grid(True)

    if savepath:
        plt.tight_layout()
        # BUG FIX: `fig` was only bound when ax was created in this function;
        # saving via the axes' parent figure also works for caller-supplied axes.
        ax.get_figure().savefig(savepath)
|
|
|
def visualize_splits_shannon_entropy(train_sequences, val_sequences, test_sequences, colormap, savepath='../data/splits/shannon_entropy_plot.png',axes=None):
    """
    Generate Shannon entropy plots for train, validation, and test sets.
    """
    log_update('\nMaking histogram of Shannon Entropy distributions')

    # Per-sequence entropies for each split, in panel order.
    panels = [
        ('train', 'Train Set', [calculate_shannon_entropy(seq) for seq in train_sequences]),
        ('val', 'Validation Set', [calculate_shannon_entropy(seq) for seq in val_sequences]),
        ('test', 'Test Set', [calculate_shannon_entropy(seq) for seq in test_sequences]),
    ]

    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    for idx, (split, title, entropies) in enumerate(panels):
        panel_ax = axes[idx]
        panel_ax.hist(entropies, bins=20, edgecolor='k', color=colormap[split])
        panel_ax.set_title(f'{title} (n={len(entropies)})', fontsize=24)
        panel_ax.set_xlabel('Shannon Entropy', fontsize=24)
        if idx == 0:
            # Only the leftmost panel carries the y-axis label.
            panel_ax.set_ylabel('Frequency', fontsize=24)
        panel_ax.grid(True)
        panel_ax.set_axisbelow(True)
        panel_ax.tick_params(axis='x', labelsize=24)
        panel_ax.tick_params(axis='y', labelsize=24)

    if savepath is not None:
        plt.tight_layout()
        plt.savefig(savepath)
|
|
|
def visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences,colormap, savepath='../data/splits/aa_comp.png',ax=None):
    """Bar plot comparing amino-acid composition across the three splits."""
    log_update('\nMaking bar plot of AA composition across each set')

    # Rows = splits, then transpose so rows become residues and columns splits.
    compositions = pd.DataFrame(
        [
            calculate_aa_composition(train_sequences),
            calculate_aa_composition(val_sequences),
            calculate_aa_composition(test_sequences),
        ],
        index=['train', 'val', 'test'],
    ).T
    bar_colors = [colormap[split] for split in compositions.columns]

    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))
    else:
        fig = ax.get_figure()

    compositions.plot(kind='bar', color=bar_colors, ax=ax)
    ax.set_title('Amino Acid Composition Across Datasets', fontsize=24)
    ax.set_xlabel('Amino Acid', fontsize=24)
    ax.set_ylabel('Relative Frequency', fontsize=24)
    ax.tick_params(axis='x', labelsize=24)
    ax.tick_params(axis='y', labelsize=24)
    ax.legend(fontsize=16, markerscale=2)

    if savepath is not None:
        fig.savefig(savepath)
|
|
|
def visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, train_color='#0072B2',val_color='#009E73',test_color='#E69F00',esm_embeddings_path=None, onehot_embeddings_path=None):
    """Render all split-QC figures: a combined 3x3 panel plus individual plots.

    Args:
        train_clusters, val_clusters, test_clusters (pd.DataFrame): cluster
            tables with 'member seq' and 'member seq_id' columns; a
            'member length' column is added in place.
        benchmark_cluster_reps (list): representative seq_ids of clusters
            containing benchmark sequences.
        train_color, val_color, test_color (str): hex colors per split.
        esm_embeddings_path (str or None): pickled embeddings for the t-SNE
            panel; the panel is skipped when missing.
        onehot_embeddings_path: accepted for interface compatibility but
            currently unused.
    """
    colormap = {
        'train': train_color,
        'val': val_color,
        'test': test_color
    }

    # Derive per-sequence lengths for the histograms (mutates inputs in place,
    # as the original did).
    train_clusters['member length'] = train_clusters['member seq'].str.len()
    val_clusters['member length'] = val_clusters['member seq'].str.len()
    test_clusters['member length'] = test_clusters['member seq'].str.len()

    train_lengths = train_clusters['member length'].tolist()
    val_lengths = val_clusters['member length'].tolist()
    test_lengths = test_clusters['member length'].tolist()
    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    # Combined figure: row 0 = length histograms, row 1 = entropy histograms,
    # row 2 = cluster-size scatter, AA composition, and (optionally) t-SNE.
    fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))

    visualize_splits_hist(train_lengths,val_lengths,test_lengths,colormap, savepath=None,axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences,val_sequences,test_sequences,colormap,savepath=None,axes=axs[1])
    visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap, savepath=None, ax=axs[2, 0])
    visualize_splits_aa_composition(train_sequences,val_sequences,test_sequences, colormap, savepath=None, ax=axs[2, 1])
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        # BUG FIX: forward esm_embeddings_path; previously the t-SNE call fell
        # back to its hard-coded default path despite gating on this argument.
        visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap, embedding_path=esm_embeddings_path, savepath=None, ax=axs[2, 2])
    else:
        # No embeddings available: leave the last panel blank.
        axs[2, 2].axis('off')

    plt.tight_layout()
    fig_combined.savefig('../data/splits/combined_plot.png')

    # Also save each plot individually at its default savepath.
    visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap)
    visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap)
    visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences, colormap)
    visualize_splits_shannon_entropy(train_sequences,val_sequences,test_sequences,colormap)
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        # Same fix as above: use the caller-provided embeddings path.
        visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap, embedding_path=esm_embeddings_path)
|
|
|
def main():
    """Load cluster splits and the fusion database, then render all split figures."""
    set_font()

    train_clusters = pd.read_csv('splits/train_cluster_split.csv')
    val_clusters = pd.read_csv('splits/val_cluster_split.csv')
    test_clusters = pd.read_csv('splits/test_cluster_split.csv')

    # All clusters together, for locating benchmark members regardless of split.
    all_clusters = pd.concat([train_clusters, val_clusters, test_clusters])

    fuson_db = pd.read_csv('fuson_db.csv')

    # seq_ids of database entries flagged as benchmark sequences.
    benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']

    # Representatives of every cluster that contains at least one benchmark sequence.
    benchmark_cluster_reps = all_clusters.loc[all_clusters['member seq_id'].isin(benchmark_seq_ids)]['representative seq_id'].unique().tolist()

    visualize_splits(
        train_clusters, val_clusters, test_clusters, benchmark_cluster_reps,
        esm_embeddings_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl',
        onehot_embeddings_path=None,
    )
|
|
|
# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()