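"""
Split clustered fusion protein sequences into train/val/test sets.

Reads the fusion database and cluster assignments specified in
fuson_plm.data.config.SPLIT, splits at the cluster level (so every member of a
cluster lands in the same split), reserves all benchmark-containing clusters
for the test set, validates and visualizes the splits, and writes the cluster
splits and ESM-finetuning dataframes to CSV.
"""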
import pandas as pd
import os
from fuson_plm.data.config import SPLIT
from fuson_plm.utils.logging import log_update, open_logfile
from fuson_plm.utils.splitting import split_clusters, check_split_validity
from fuson_plm.utils.visualizing import set_font, visualize_splits
       
def get_benchmark_data(fuson_db_path, clusters):
    """
    """
    # Read the fusion database 
    fuson_db = pd.read_csv(fuson_db_path)
    
    # Get original benchmark sequences, and benchmark sequences that were clustered
    original_benchmark_sequences = fuson_db.loc[fuson_db['benchmark'].notna()]
    benchmark_sequences = fuson_db.loc[
        (fuson_db['benchmark'].notna()) &                               # it's a benchmark sequence
        (fuson_db['aa_seq'].isin(list(clusters['member seq'])))         # it was clustered (it's under the length limit specified for clustering)
    ]['aa_seq'].to_list()
    
    # Get the sequence IDs of all clustered benchmark sequences. 
    benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']
    
    # Use benchmark_seq_ids to find which clusters contain benchmark sequences.
    benchmark_cluster_reps = clusters.loc[clusters['member seq_id'].isin(benchmark_seq_ids)]['representative seq_id'].unique().tolist()
    log_update(f"\t{len(benchmark_sequences)}/{len(original_benchmark_sequences)} benchmarking sequences (only those shorter than config.CLUSTERING[\'max_seq_length\']) were grouped into {len(benchmark_cluster_reps)} clusters. These will be reserved for the test set.")
    
    return benchmark_cluster_reps, benchmark_sequences

def get_training_dfs(train, val, test):
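    """
    Convert the train/val/test cluster splits into dataframes for ESM finetuning:
    drop the cluster bookkeeping columns and rename 'member seq' to 'sequence'.
    """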
    log_update('\nMaking dataframes for ESM finetuning...')
    
    # Drop cluster bookkeeping columns we don't need and standardize on a 'sequence' column
    drop_cols = ['representative seq_id', 'member seq_id', 'representative seq']
    train, val, test = [
        df.drop(columns=drop_cols).rename(columns={'member seq': 'sequence'})
        for df in (train, val, test)
    ]

    return train, val, test
    
def main():
    """
    """
    # Read all the input files
    LOG_PATH = "splitting_log.txt"
    FUSON_DB_PATH = SPLIT.FUSON_DB_PATH
    CLUSTER_OUTPUT_PATH = SPLIT.CLUSTER_OUTPUT_PATH
    RANDOM_STATE_1 = SPLIT.RANDOM_STATE_1
    TEST_SIZE_1 = SPLIT.TEST_SIZE_1
    RANDOM_STATE_2 = SPLIT.RANDOM_STATE_2
    TEST_SIZE_2 = SPLIT.TEST_SIZE_2
    
    # set font
    set_font()
    
    # Prepare the log file
    with open_logfile(LOG_PATH):
    
        log_update("Loaded data-splitting configurations from config.py")
        SPLIT.print_config(indent='\t')
            
        # Prepare directory to save results
        os.makedirs("splits", exist_ok=True)
            
        # Read the clusters and get a list of the representative IDs for splitting
        clusters = pd.read_csv(CLUSTER_OUTPUT_PATH)
        reps = clusters['representative seq_id'].unique().tolist()
        log_update(f"\nPreparing clusters...\n\tCollected {len(reps)} clusters for splitting")
        
        # Get the benchmark cluster representatives and sequences
        benchmark_cluster_reps, benchmark_sequences = get_benchmark_data(FUSON_DB_PATH, clusters)
        
        # Make the splits and extract the results
        splits = split_clusters(reps, benchmark_cluster_reps=benchmark_cluster_reps,
                                random_state_1 = RANDOM_STATE_1, random_state_2 = RANDOM_STATE_2, test_size_1 = TEST_SIZE_1, test_size_2 = TEST_SIZE_2) 
        X_train = splits['X_train']
        X_val = splits['X_val']
        X_test = splits['X_test']
        
        # Make slices of clusters dataframe for train, val, and test
        train_clusters = clusters.loc[clusters['representative seq_id'].isin(X_train)].reset_index(drop=True)
        val_clusters = clusters.loc[clusters['representative seq_id'].isin(X_val)].reset_index(drop=True)
        test_clusters = clusters.loc[clusters['representative seq_id'].isin(X_test)].reset_index(drop=True)
        
        # Check validity
        check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=benchmark_sequences)
        
        # Log min and max sequence lengths for each split
        log_update("\nLength breakdown summary...")
        for name, df in [('Train', train_clusters), ('Val', val_clusters), ('Test', test_clusters)]:
            seqlens = df['member seq'].str.len()
            log_update(f"\t{name}: min seq length = {seqlens.min()}, max seq length = {seqlens.max()}")
        
        # Make plots to visualize the splits
        visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)
        
        # Save the cluster splits (columns: representative seq_id, member seq_id, representative seq, member seq)
        train_clusters.to_csv("splits/train_cluster_split.csv", index=False)
        val_clusters.to_csv("splits/val_cluster_split.csv", index=False)
        test_clusters.to_csv("splits/test_cluster_split.csv", index=False)
        log_update('\nSaved cluster splits to splits/train_cluster_split.csv, splits/val_cluster_split.csv, splits/test_cluster_split.csv')
        cols=','.join(list(train_clusters.columns))
        log_update(f'\tColumns: {cols}')
        
        # Make train_df, val_df, test_df: the dataframes that will be fed to the training script
        train_df, val_df, test_df = get_training_dfs(train_clusters, val_clusters, test_clusters)
        train_df.to_csv("splits/train_df.csv", index=False)
        val_df.to_csv("splits/val_df.csv", index=False)
        test_df.to_csv("splits/test_df.csv", index=False)
        log_update('\nSaved training dataframes to splits/train_df.csv, splits/val_df.csv, splits/test_df.csv')
        cols=','.join(list(train_df.columns))
        log_update(f'\tColumns: {cols}')

if __name__ == "__main__":
    main()
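
# Example invocation (a sketch; assumes this file is saved as split.py, that
# fuson_plm.data.config defines the SPLIT settings imported above, and that the
# script is run from the directory where the "splits" output folder should live):
#   python split.py
# Outputs: splitting_log.txt, splits/{train,val,test}_cluster_split.csv, and
# splits/{train,val,test}_df.csv.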