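"""Visualize the train/val/test cluster splits and replace raw sequences with
seq_ids in the source-data CSVs that visualize_splits writes to splits/split_vis/."""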
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import entropy
from sklearn.manifold import TSNE
import pickle
import pandas as pd
import os
from fuson_plm.utils.logging import log_update
from fuson_plm.utils.visualizing import set_font, visualize_splits

def main():
    set_font()
    train_clusters = pd.read_csv('splits/train_cluster_split.csv')
    val_clusters = pd.read_csv('splits/val_cluster_split.csv')
    test_clusters = pd.read_csv('splits/test_cluster_split.csv')
    clusters = pd.concat([train_clusters, val_clusters, test_clusters])
    fuson_db = pd.read_csv('fuson_db.csv')

    # Get the sequence IDs of all clustered benchmark sequences.
    benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']
    # Use benchmark_seq_ids to find which clusters contain benchmark sequences.
    benchmark_cluster_reps = clusters.loc[clusters['member seq_id'].isin(benchmark_seq_ids)]['representative seq_id'].unique().tolist()
    visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)

    ## Add seq_id to every source data file that is saved from visualize_splits
    seq_to_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
    files_to_edit = os.listdir("splits/split_vis")
    files_to_edit = [x for x in files_to_edit if x.endswith(".csv")]
    log_update(f"Adding seq_ids to the following files: {files_to_edit}")
    for fname in files_to_edit:
        source_data_file = pd.read_csv(f"splits/split_vis/{fname}")
        if "sequence" in list(source_data_file.columns):
            # Swap the raw sequence column for its seq_id, then overwrite the file in place.
            source_data_file["seq_id"] = source_data_file["sequence"].map(seq_to_id_dict)
            source_data_file.drop(columns=['sequence']).to_csv(f"splits/split_vis/{fname}", index=False)

if __name__ == "__main__":
    main()