|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from scipy.stats import entropy |
|
from sklearn.manifold import TSNE |
|
import pickle |
|
import pandas as pd |
|
import os |
|
from fuson_plm.utils.logging import log_update |
|
from fuson_plm.utils.visualizing import set_font, visualize_splits |
|
|
|
def main():
    """Plot the cluster splits, then replace raw sequences with seq_ids in the split_vis CSVs.

    Steps:
      1. Read the train/val/test cluster split tables and ``fuson_db.csv``.
      2. Collect the representative seq_ids of every cluster that has at least
         one member whose ``benchmark`` entry in ``fuson_db`` is non-null.
      3. Call ``visualize_splits`` with the three split tables and those
         representatives.
      4. For every ``.csv`` in ``splits/split_vis`` that has a ``sequence``
         column, add a ``seq_id`` column (looked up via ``aa_seq`` -> ``seq_id``
         from ``fuson_db``), drop ``sequence``, and overwrite the file in place.

    All paths are relative to the current working directory.
    """
    set_font()

    # Load the three split tables; the concatenation is only used to look up
    # which clusters contain benchmark sequences.
    train_clusters = pd.read_csv('splits/train_cluster_split.csv')
    val_clusters = pd.read_csv('splits/val_cluster_split.csv')
    test_clusters = pd.read_csv('splits/test_cluster_split.csv')
    clusters = pd.concat([train_clusters, val_clusters, test_clusters])

    fuson_db = pd.read_csv('fuson_db.csv')

    # seq_ids of database rows with a non-null 'benchmark' value.
    benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna(), 'seq_id']

    # Representatives of every cluster that has a benchmark sequence as a member.
    has_benchmark_member = clusters['member seq_id'].isin(benchmark_seq_ids)
    benchmark_cluster_reps = (
        clusters.loc[has_benchmark_member, 'representative seq_id'].unique().tolist()
    )

    visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)

    # NOTE(review): the CSVs below are presumably written by visualize_splits;
    # that is not visible from this function -- confirm.
    # If an aa_seq occurs more than once in fuson_db, the last seq_id wins.
    seq_to_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
    vis_dir = "splits/split_vis"
    # os.listdir order is arbitrary; sort so the log and processing order are stable.
    files_to_edit = sorted(x for x in os.listdir(vis_dir) if x.endswith(".csv"))
    log_update(f"Adding seq_ids to the following files: {files_to_edit}")

    for fname in files_to_edit:
        csv_path = f"{vis_dir}/{fname}"
        source_data_file = pd.read_csv(csv_path)

        # Files without a 'sequence' column are left untouched.
        if "sequence" not in source_data_file.columns:
            continue

        source_data_file["seq_id"] = source_data_file["sequence"].map(seq_to_id_dict)

        # Series.map yields NaN for sequences missing from the lookup. Because
        # the 'sequence' column is dropped right after, such rows would become
        # unidentifiable without any trace -- so report them.
        n_unmapped = int(source_data_file["seq_id"].isna().sum())
        if n_unmapped:
            log_update(
                f"WARNING: {n_unmapped} row(s) in {csv_path} have no seq_id after mapping"
            )

        source_data_file.drop(columns=['sequence']).to_csv(csv_path, index=False)
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()