Fill-Mask
Transformers
Safetensors
esm
File size: 1,816 Bytes
1e6a1f0
 
 
 
 
 
 
 
6efd653
1e6a1f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6efd653
 
 
 
 
 
 
 
 
 
 
 
 
1e6a1f0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import entropy
from sklearn.manifold import TSNE
import pickle
import pandas as pd
import os
from fuson_plm.utils.logging import log_update
from fuson_plm.utils.visualizing import set_font, visualize_splits

def main():
    """Plot the train/val/test cluster splits and tag the plot source data with seq_ids.

    Steps:
      1. Load the three cluster split tables from ``splits/`` and stack them.
      2. Load ``fuson_db.csv`` and collect the ``seq_id`` of every row whose
         ``benchmark`` column is not null.
      3. Find the cluster representatives whose clusters contain at least one
         of those benchmark sequences, and pass them to ``visualize_splits``
         together with the three split tables.
      4. For every ``.csv`` in ``splits/split_vis`` that has a ``sequence``
         column, add a ``seq_id`` column looked up from ``fuson_db``, drop the
         ``sequence`` column, and overwrite the file in place.

    Sequences that are absent from ``fuson_db['aa_seq']`` get a missing (NaN)
    ``seq_id``. Files without a ``sequence`` column are read but not rewritten.
    """
    set_font()

    # Load each split separately; visualize_splits takes them individually.
    train_df = pd.read_csv('splits/train_cluster_split.csv')
    val_df = pd.read_csv('splits/val_cluster_split.csv')
    test_df = pd.read_csv('splits/test_cluster_split.csv')
    all_clusters = pd.concat([train_df, val_df, test_df])

    db = pd.read_csv('fuson_db.csv')

    # seq_ids of every database entry that carries a benchmark annotation.
    is_benchmark = db['benchmark'].notna()
    benchmark_ids = db.loc[is_benchmark, 'seq_id']

    # Representatives of the clusters that hold at least one benchmark member.
    holds_benchmark = all_clusters['member seq_id'].isin(benchmark_ids)
    benchmark_reps = (
        all_clusters.loc[holds_benchmark, 'representative seq_id']
        .unique()
        .tolist()
    )

    visualize_splits(train_df, val_df, test_df, benchmark_reps)

    # Swap raw sequences for seq_ids in the CSVs found in splits/split_vis
    # (presumably the source-data files written by visualize_splits).
    id_by_sequence = dict(zip(db['aa_seq'], db['seq_id']))
    csv_names = [
        entry for entry in os.listdir("splits/split_vis")
        if entry.endswith(".csv")
    ]
    log_update(f"Adding seq_ids to the following files: {csv_names}")

    for csv_name in csv_names:
        csv_path = f"splits/split_vis/{csv_name}"
        table = pd.read_csv(csv_path)

        # Nothing to map in tables that do not store sequences.
        if "sequence" not in table.columns:
            continue

        tagged = table.assign(seq_id=table["sequence"].map(id_by_sequence))
        tagged.drop(columns=['sequence']).to_csv(csv_path, index=False)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()