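"""
Puncta benchmarking script.

Embeds fusion oncoprotein (FO) sequences with each benchmarked model
(FusOn-pLM checkpoints, ESM, ProtT5, and the FO-puncta-ML baseline, per the
flags in config.py), trains an XGBoost predictor for each task ('nucleus',
'cytoplasm', 'formation'), evaluates on the verification split, and writes
per-task result CSVs plus summary bar charts under results/<timestamp>/.

Input ('splits.csv') and output ('results/') paths are resolved relative to
the working directory, so run this script from the puncta benchmark folder.
"""
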
import torch
import time
import pandas as pd
import numpy as np
import pickle
import os

from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
import fuson_plm.benchmarking.puncta.config as config
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy, get_local_time, CustomParams

def check_splits(df):
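    """Sanity-check the splits dataframe: every 'aa_seq' must be unique and assigned to 'train' or 'test'."""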
    # make sure everything has a split
    if df['split'].isna().any():
        raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)")
    # make sure the only split values are train and test
    if set(df['split'].unique()) != {'train', 'test'}:
        raise Exception("Error: splits column should only have 'train' and 'test'.")
    # make sure there are no duplicate sequences
    if df['aa_seq'].duplicated().any():
        raise Exception("Error: duplicate sequences provided")
    
def train_and_evaluate_puncta_predictor(details, splits_with_embeddings, outdir, task='nucleus', class1_thresh=0.5, n_estimators=50, tree_method="hist"):
    """
    Train an XGBoost predictor for one task and evaluate it on the verification (test) split.

    task = 'nucleus', 'cytoplasm', or 'formation'
    Results are appended to per-task CSVs in `outdir`.
    """
    # unpack the details dictionary
    benchmark_model_type = details['model_type']
    benchmark_model_name = details['model']
    benchmark_model_epoch = details['epoch']
    
    # prepare train and test sets for model
    train_split = splits_with_embeddings.loc[splits_with_embeddings['split']=='train'].reset_index(drop=True)
    test_split = splits_with_embeddings.loc[splits_with_embeddings['split']=='test'].reset_index(drop=True)
    
    X_train = np.array(train_split['embedding'].tolist())
    y_train = np.array(train_split[task].tolist())
    X_test = np.array(test_split['embedding'].tolist())
    y_test = np.array(test_split[task].tolist())
        
    # Train the final model on all the data
    clf = train_final_predictor(X_train, y_train, n_estimators=n_estimators, tree_method=tree_method)
    
    # Evaluate it
    automatic_stats_df, custom_stats_df = evaluate_predictor(clf, X_test, y_test, class1_thresh=class1_thresh)
    
    # Add the model details back in
    cols = list(automatic_stats_df.columns)
    automatic_stats_df['Model Type'] = [benchmark_model_type]
    automatic_stats_df['Model Name'] = [benchmark_model_name]
    automatic_stats_df['Model Epoch'] = [benchmark_model_epoch]
    newcols = ['Model Type','Model Name','Model Epoch'] + cols
    automatic_stats_df = automatic_stats_df[newcols]
    
    cols = list(custom_stats_df.columns)
    custom_stats_df['Model Type'] = [benchmark_model_type]
    custom_stats_df['Model Name'] = [benchmark_model_name]
    custom_stats_df['Model Epoch'] = [benchmark_model_epoch]
    newcols = ['Model Type','Model Name','Model Epoch'] + cols
    custom_stats_df = custom_stats_df[newcols]
    
    # Save automatic-threshold results (nucleus and cytoplasm); append so results accumulate across models
    if task != "formation":
        automatic_stats_path = f'{outdir}/{task}_verificationFOs_results.csv'
        if not os.path.exists(automatic_stats_path):
            automatic_stats_df.to_csv(automatic_stats_path, index=False)
        else:
            automatic_stats_df.to_csv(automatic_stats_path, mode='a', index=False, header=False)

    # Save custom-threshold results (formation only)
    if task == "formation":
        custom_stats_path = f'{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv'
        if not os.path.exists(custom_stats_path):
            custom_stats_df.to_csv(custom_stats_path, index=False)
        else:
            custom_stats_df.to_csv(custom_stats_path, mode='a', index=False, header=False)
    
def main():
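    """Run the full benchmark: embed with all models, train/evaluate each task, and plot summary charts."""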
    # make output directory for this run
    os.makedirs('results',exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir,exist_ok=True)
    
    with open_logfile(f'{output_dir}/puncta_benchmark_log.txt'):
        # print configurations 
        print_configpy(config)
        
        # Verify that the environment variables are set correctly 
        os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
        log_update("\nChecking on environment variables...")
        log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
        
        # make embeddings if needed
        all_embedding_paths = embed_dataset_for_benchmark(
            fuson_ckpts=config.FUSONPLM_CKPTS,
            input_data_path='splits.csv', input_fname='FOdb_puncta_sequences',
            average=True, seq_col='aa_seq',
            benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
            benchmark_esm=config.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=config.BENCHMARK_FO_PUNCTA_ML,
            benchmark_prott5=config.BENCHMARK_PROTT5,
            overwrite=config.PERMISSION_TO_OVERWRITE)
        
        # load the splits with labels
        splits = pd.read_csv('splits.csv')
        # perform some sanity checks on the splits
        check_splits(splits)
        n_train = len(splits.loc[splits['split']=='train'])
        n_test = len(splits.loc[splits['split']=='test'])
        log_update(f"\nSplit breakdown...\n\t{n_train} Training FOs\n\t{n_test} Verification FOs")
        
        # set training constants
        train_params = CustomParams(
            N_ESTIMATORS = 50,
            TREE_METHOD = "hist",
            CLASS1_THRESHOLDS = {
                'nucleus': 0.83,
                'cytoplasm': 0.83,
                'formation': 0.83
            },
        )
        log_update("\nTraining configs:")
        train_params.print_config(indent='\t')
        
        log_update("\nTraining models")
        # loop through the embedding paths and train each one
        for embedding_path, details in all_embedding_paths.items():
            log_update(f"\tBenchmarking embeddings at: {embedding_path}")
            try:
                with open(embedding_path, "rb") as f:
                    embeddings = pickle.load(f)
            except Exception as e:
                raise Exception(f"Cannot read embeddings from {embedding_path}") from e
            
            # combine the embeddings and splits into one dataframe
            splits_with_embeddings = pd.DataFrame(list(embeddings.items()), columns=['aa_seq', 'embedding'])
            splits_with_embeddings = pd.merge(splits_with_embeddings, splits, on='aa_seq',how='left')
            
            for task in ['nucleus','cytoplasm','formation']:
                log_update(f"\t\tTask: {task}")
                train_and_evaluate_puncta_predictor(details, splits_with_embeddings, output_dir, task=task,
                                                    class1_thresh=train_params.CLASS1_THRESHOLDS[task],
                                                    n_estimators=train_params.N_ESTIMATORS, tree_method=train_params.TREE_METHOD)
        
        log_update(f"\nMaking summary figures:\n")
        log_update(f"\tbar charts...")
        os.makedirs(f"{output_dir}/figures",exist_ok=True)
        make_all_final_bar_charts(output_dir)
        log_update(f"\tDone.")
            
if __name__ == '__main__':
    main()
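
# Typical invocation (the file name here is illustrative; run from the puncta
# benchmark directory so 'splits.csv' and 'results/' resolve correctly):
#   python benchmark.py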