File size: 7,573 Bytes
8d9d9da c43fbc6 8d9d9da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import time
import pandas as pd
import numpy as np
import pickle
import os
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
import fuson_plm.benchmarking.puncta.config as config
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy, get_local_time, CustomParams
def check_splits(df):
    """Sanity-check the split assignments of the benchmarking sequences.

    Args:
        df: DataFrame with a ``split`` column and an ``aa_seq`` column.

    Raises:
        Exception: if any row has no split, if ``split`` contains a label
            other than ``'train'`` or ``'test'``, if either of those two
            splits is absent, or if ``aa_seq`` contains duplicates.
    """
    # make sure everything has a split
    if df['split'].isna().any():
        raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)")

    required_splits = {'train', 'test'}
    observed_splits = set(df['split'].unique())

    # make sure the only things are train and test
    # (observed - required catches stray labels such as 'val'; the reverse
    # difference alone would let them through)
    if observed_splits - required_splits:
        raise Exception("Error: splits column should only have 'train' and 'test'.")

    # make sure both splits are actually present
    missing_splits = required_splits - observed_splits
    if missing_splits:
        raise Exception(f"Error: splits column is missing required split(s): {sorted(missing_splits)}")

    # make sure there are no duplicate sequences
    if df['aa_seq'].duplicated().any():
        raise Exception("Error: duplicate sequences provided")
def _attach_model_details(stats_df, model_details):
    """Add one column per model detail to *stats_df* and move them to the front."""
    metric_columns = list(stats_df.columns)
    for column_name, column_value in model_details.items():
        stats_df[column_name] = [column_value]
    return stats_df[list(model_details) + metric_columns]


def _write_or_append_csv(stats_df, csv_path):
    """Create *csv_path* with a header, or append rows without one if it exists."""
    already_exists = os.path.exists(csv_path)
    stats_df.to_csv(
        csv_path,
        mode='a' if already_exists else 'w',
        index=False,
        header=not already_exists,
    )


def train_and_evaluate_puncta_predictor(details, splits_with_embeddings, outdir, task='nucleus',
                                        class1_thresh=0.5, n_estimators=50, tree_method="hist"):
    """Fit a predictor on the train rows, score it on the test rows, and save the stats.

    Args:
        details: dict with keys 'model_type', 'model' and 'epoch' describing
            the embedding model being benchmarked.
        splits_with_embeddings: DataFrame with 'split', 'embedding' and
            *task* columns.
        outdir: directory that receives the results CSV.
        task: label column to predict: 'nucleus', 'cytoplasm', or 'formation'.
        class1_thresh: threshold forwarded to ``evaluate_predictor``; also
            embedded in the 'formation' results file name.
        n_estimators: forwarded to ``train_final_predictor``.
        tree_method: forwarded to ``train_final_predictor``.

    Results are appended to the CSV when it already exists. For 'formation'
    the custom-threshold stats are written; for any other task the automatic
    stats are written.
    """
    # unpack the details dictionary
    model_details = {
        'Model Type': details['model_type'],
        'Model Name': details['model'],
        'Model Epoch': details['epoch'],
    }

    # prepare train and test sets for model
    def features_and_labels(split_name):
        rows = splits_with_embeddings.loc[splits_with_embeddings['split'] == split_name]
        rows = rows.reset_index(drop=True)
        return np.array(rows['embedding'].tolist()), np.array(rows[task].tolist())

    X_train, y_train = features_and_labels('train')
    X_test, y_test = features_and_labels('test')

    # Train the final model on all the training data
    clf = train_final_predictor(X_train, y_train, n_estimators=n_estimators, tree_method=tree_method)

    # Evaluate it
    automatic_stats_df, custom_stats_df = evaluate_predictor(clf, X_test, y_test, class1_thresh=class1_thresh)

    # Add the model details back in (both frames, whichever one gets saved)
    automatic_stats_df = _attach_model_details(automatic_stats_df, model_details)
    custom_stats_df = _attach_model_details(custom_stats_df, model_details)

    if task == "formation":
        # Save custom threshold results (only if it's formation)
        _write_or_append_csv(
            custom_stats_df,
            f'{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv',
        )
    else:
        # Save automatic results (for nucleus and cytoplasm)
        _write_or_append_csv(
            automatic_stats_df,
            f'{outdir}/{task}_verificationFOs_results.csv',
        )
def main():
    """Run the puncta benchmark end to end.

    Creates a timestamped directory under ``results/``, embeds the sequences
    in ``splits.csv`` with every model enabled in ``config``, trains and
    evaluates one predictor per embedding set and task ('nucleus',
    'cytoplasm', 'formation'), then draws the summary bar charts.

    Raises:
        Exception: if the splits fail ``check_splits`` or an embedding file
            cannot be read.
    """
    splits_path = 'splits.csv'

    # make output directory for this run
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f'{output_dir}/puncta_benchmark_log.txt'):
        # print configurations
        print_configpy(config)

        # Verify that the environment variables are set correctly.
        # os.environ only accepts strings, so coerce in case config holds an int.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(config.CUDA_VISIBLE_DEVICES)
        log_update("\nChecking on environment variables...")
        log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

        # make embeddings if needed
        all_embedding_paths = embed_dataset_for_benchmark(
            fuson_ckpts=config.FUSONPLM_CKPTS,
            input_data_path=splits_path, input_fname='FOdb_puncta_sequences',
            average=True, seq_col='aa_seq',
            benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
            benchmark_esm=config.BENCHMARK_ESM,
            benchmark_fo_puncta_ml=config.BENCHMARK_FO_PUNCTA_ML,
            benchmark_prott5=config.BENCHMARK_PROTT5,
            overwrite=config.PERMISSION_TO_OVERWRITE)

        # load the splits with labels
        splits = pd.read_csv(splits_path)

        # perform some sanity checks on the splits
        check_splits(splits)
        n_train = len(splits.loc[splits['split'] == 'train'])
        n_test = len(splits.loc[splits['split'] == 'test'])
        log_update(f"\nSplit breakdown...\n\t{n_train} Training FOs\n\t{n_test} Verification FOs")

        # set training constants
        train_params = CustomParams(
            N_ESTIMATORS=50,
            TREE_METHOD="hist",
            CLASS1_THRESHOLDS={
                'nucleus': 0.83,
                'cytoplasm': 0.83,
                'formation': 0.83
            },
        )
        log_update("\nTraining configs:")
        train_params.print_config(indent='\t')

        log_update("\nTraining models")

        # loop through the embedding paths and train each one
        for embedding_path, details in all_embedding_paths.items():
            log_update(f"\tBenchmarking embeddings at: {embedding_path}")

            # NOTE: pickle.load can execute arbitrary code; only load embedding
            # files that were produced by this pipeline.
            try:
                with open(embedding_path, "rb") as f:
                    embeddings = pickle.load(f)
            except Exception as err:
                # keep the underlying cause in the traceback, and let
                # KeyboardInterrupt / SystemExit propagate untouched
                raise Exception(f"Cannot read embeddings from {embedding_path}") from err

            # combine the embeddings and splits into one dataframe
            splits_with_embeddings = pd.DataFrame(
                list(embeddings.items()), columns=['aa_seq', 'embedding']
            )
            splits_with_embeddings = pd.merge(splits_with_embeddings, splits, on='aa_seq', how='left')

            for task in ['nucleus', 'cytoplasm', 'formation']:
                log_update(f"\t\tTask: {task}")
                train_and_evaluate_puncta_predictor(
                    details, splits_with_embeddings, output_dir, task=task,
                    class1_thresh=train_params.CLASS1_THRESHOLDS[task],
                    n_estimators=train_params.N_ESTIMATORS,
                    tree_method=train_params.TREE_METHOD)

        log_update("\nMaking summary figures:\n")
        log_update("\tbar charts...")
        os.makedirs(f"{output_dir}/figures", exist_ok=True)
        make_all_final_bar_charts(output_dir)
        log_update("\tDone.")
# Script entry point: run the benchmark only when executed directly.
if __name__ == '__main__':
    main()