# FusOn-pLM: fuson_plm/data/split_vis.py
# Hugging Face repository tags: Fill-Mask, Transformers, Safetensors, esm
# Commit 6efd653 (svincoff): "data cleaning, blast, and splitting code with
# source data, also deleting unnecessary files"
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import entropy
from sklearn.manifold import TSNE
import pickle
import pandas as pd
import os
from fuson_plm.utils.logging import log_update
from fuson_plm.utils.visualizing import set_font, visualize_splits
def main():
    """Visualize the train/val/test cluster splits and tag their source data with seq_ids.

    Steps:
      1. Load the train/val/test cluster split CSVs from ``splits/`` and the
         full database ``fuson_db.csv`` (paths relative to the working directory).
      2. Find every cluster that contains at least one benchmark sequence
         (a ``fuson_db`` row whose ``benchmark`` value is non-null). Pass the
         representative seq_ids of those clusters to ``visualize_splits``.
      3. For every ``.csv`` file in ``splits/split_vis`` that has a ``sequence``
         column, add a ``seq_id`` column looked up from ``fuson_db``
         (``aa_seq`` -> ``seq_id``), drop ``sequence``, and overwrite the file.
         Files without a ``sequence`` column are left untouched, so re-running
         the script does not re-process already converted files.
    """
    set_font()

    split_dir = "splits"
    vis_dir = os.path.join(split_dir, "split_vis")

    train_clusters = pd.read_csv(os.path.join(split_dir, "train_cluster_split.csv"))
    val_clusters = pd.read_csv(os.path.join(split_dir, "val_cluster_split.csv"))
    test_clusters = pd.read_csv(os.path.join(split_dir, "test_cluster_split.csv"))
    clusters = pd.concat([train_clusters, val_clusters, test_clusters])

    fuson_db = pd.read_csv("fuson_db.csv")

    # Sequence IDs of all benchmark sequences (rows with a non-null 'benchmark' value).
    # A single .loc[mask, col] avoids pandas chained indexing.
    is_benchmark = fuson_db["benchmark"].notna()
    benchmark_seq_ids = fuson_db.loc[is_benchmark, "seq_id"]

    # Representatives of every cluster that has at least one benchmark member.
    has_benchmark_member = clusters["member seq_id"].isin(benchmark_seq_ids)
    benchmark_cluster_reps = (
        clusters.loc[has_benchmark_member, "representative seq_id"].unique().tolist()
    )

    visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)

    ## Add seq_id to every source data file that is saved from visualize_splits
    seq_to_id_dict = dict(zip(fuson_db["aa_seq"], fuson_db["seq_id"]))

    # os.listdir order is arbitrary; sort so logs and processing order are reproducible.
    files_to_edit = sorted(f for f in os.listdir(vis_dir) if f.endswith(".csv"))
    log_update(f"Adding seq_ids to the following files: {files_to_edit}")

    for fname in files_to_edit:
        fpath = os.path.join(vis_dir, fname)
        source_data_file = pd.read_csv(fpath)

        # Only files that still carry raw sequences need converting.
        if "sequence" not in source_data_file.columns:
            continue

        source_data_file["seq_id"] = source_data_file["sequence"].map(seq_to_id_dict)

        # Series.map yields NaN for sequences absent from fuson_db; surface that
        # instead of silently writing empty seq_ids.
        n_unmapped = int(source_data_file["seq_id"].isna().sum())
        if n_unmapped > 0:
            log_update(
                f"\tWARNING: {n_unmapped} row(s) in {fname} have a sequence not found "
                f"in fuson_db.csv; their seq_id will be empty"
            )

        source_data_file.drop(columns=["sequence"]).to_csv(fpath, index=False)
if __name__ == "__main__":
    # Entry point: run the split visualization and seq_id tagging only when
    # this file is executed as a script (not when imported).
    main()