File size: 12,977 Bytes
0b11a42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244

import logging
import warnings
from argparse import ArgumentParser
from contextlib import redirect_stdout
from datetime import datetime
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
from hydra.utils import instantiate
from omegaconf import OmegaConf
from sklearn.preprocessing import StandardScaler
from umap import UMAP

from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split
from ..processing.seq_tokenizer import SeqTokenizer
from ..utils.file import load
from ..utils.tcga_post_analysis_utils import Results_Handler
from ..utils.utils import (get_model, infer_from_pd,
                           prepare_inference_results_tcga,
                           update_config_with_inference_params)

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Installs an "ignore" filter for every warning category. This runs at import
# time, so it affects the whole process, not only this module.
warnings.filterwarnings("ignore")

def aggregate_ensemble_model(lev_dist_df:pd.DataFrame):
    '''
    Appends one 'Ensemble' prediction per query sequence to the per-model predictions.

    For every sequence the non-baseline models are compared by their NLD:
    if the smallest NLD does not exceed that row's Novelty Threshold, the row of the
    model with the smallest NLD becomes the ensemble prediction; otherwise the row of
    the model with the largest NLD is used.

    Input:
        lev_dist_df: dataframe holding at least the columns 'Model', 'Sequence',
                     'NLD' and 'Novelty Threshold'
    Output:
        dataframe with a fresh 0..n-1 index containing, in this order, the non-baseline
        rows, the ensemble rows and finally the 'Baseline' rows
    '''
    #the baseline model takes no part in the vote; it is set aside and appended at the end
    baseline_rows = lev_dist_df[lev_dist_df['Model'] == 'Baseline'].reset_index(drop=True)
    model_rows = lev_dist_df[lev_dist_df['Model'] != 'Baseline'].reset_index(drop=True)

    #the index was just reset, so the labels returned by idxmin/idxmax are valid positions for iloc
    nld_per_sequence = model_rows.groupby('Sequence')['NLD']
    closest_rows = model_rows.iloc[nld_per_sequence.idxmin().values]
    farthest_rows = model_rows.iloc[nld_per_sequence.idxmax().values]

    #both selections are ordered by the groupby keys, so one positional mask serves both
    is_novel = (closest_rows['NLD'] > closest_rows['Novelty Threshold']).values

    #familiar sequences keep their closest match, novel ones take the farthest match
    ensemble_rows = pd.concat([closest_rows[~is_novel],farthest_rows[is_novel]])
    ensemble_rows['Model'] = 'Ensemble'

    combined = pd.concat([model_rows,ensemble_rows,baseline_rows])
    return combined.reset_index(drop=True)


def read_inference_model_config(model:str,mc_or_sc,trained_on:str,path_to_models:str):
    '''
    Loads the hp_settings.yaml file of a trained model with OmegaConf.

    Input:
        model: name of the model folder
        mc_or_sc: name of the folder level above the model folder
        trained_on: "full" selects the TransfoRNA_FULL folder, any other value selects TransfoRNA_ID
        path_to_models: directory that contains the TransfoRNA_FULL / TransfoRNA_ID folders
    Output:
        the configuration object returned by OmegaConf.load
    '''
    top_level_folder = "TransfoRNA_FULL" if trained_on == "full" else "TransfoRNA_ID"
    hp_settings_path = f"{path_to_models}/{top_level_folder}/{mc_or_sc}/{model}/meta/hp_settings.yaml"
    return OmegaConf.load(hp_settings_path)

def predict_transforna(sequences: List[str], model: str = "Seq-Rev", mc_or_sc:str='sub_class',
                       logits_flag:bool = False,attention_flag:bool = False,
                       similarity_flag:bool=False,n_sim:int=3,embedds_flag:bool = False,
                       umap_flag:bool = False,trained_on:str='full',path_to_models:str='') -> pd.DataFrame:
    '''
    This function predicts the major class or sub class of a list of sequences using the TransfoRNA model.
    Additionally, it can return logits, attention scores, similarity scores, gene embeddings or umap embeddings.

    Input:
        sequences: list of sequences to predict
        model: model to use for prediction; every "-" separated part is capitalized before use
        mc_or_sc: models trained on major class or sub class
        logits_flag: whether to return logits
        attention_flag: whether to return attention scores (obtained from the self-attention layer)
        similarity_flag: whether to return explanatory/similar sequences in the training set
        n_sim: number of similar sequences to return
        embedds_flag: whether to return embeddings of the sequences
        umap_flag: whether to return umap embeddings
        trained_on: whether to use the model trained on the full dataset or the ID dataset
        path_to_models: directory that contains the TransfoRNA_FULL / TransfoRNA_ID model folders
    Output:
        pd.DataFrame whose content depends on the flag that is set (at most one may be True):
            logits_flag: the logits columns of the inference table
            attention_flag: the attention scores returned by infer_from_pd
            embedds_flag: the embeddings returned by infer_from_pd
            similarity_flag: n_sim rows per query sequence with the columns 'Sequence',
                             'Explanatory Sequence', 'NLD', 'Explanatory Label' and 'Novelty Threshold'
            umap_flag: 'UMAP1'/'UMAP2' coordinates plus 'Net-Label', 'Is Familiar?',
                       'Explanatory Label', 'Explanatory Sequence' and 'Sequence'
            no flag: the inference table extended with 'Novelty Threshold', 'NLD',
                     'Explanatory Label' and 'Explanatory Sequence', with 'Sequence' as a column
    Raises:
        AssertionError: if more than one flag is True
    '''
    #asserts that at most one flag is True
    assert sum([logits_flag,attention_flag,similarity_flag,embedds_flag,umap_flag]) <= 1, 'One option at most can be True'
    # capitalize the first letter of the model and the first letter after the -
    model = "-".join([word.capitalize() for word in model.split("-")])
    cfg = read_inference_model_config(model,mc_or_sc,trained_on,path_to_models)
    cfg = update_config_with_inference_params(cfg,mc_or_sc,trained_on,path_to_models)
    root_dir = Path(__file__).parents[1].absolute()

    #anything printed while the model is loaded and run is discarded
    with redirect_stdout(None):
        cfg, net = get_model(cfg, root_dir)
        #original_infer_pd might include seqs that are longer than input model. if so, infer_pd contains the trimmed sequences
        infer_pd = pd.Series(sequences, name="Sequences").to_frame()
        predicted_labels, logits, gene_embedds_df,attn_scores_pd,all_data, max_len, net,_  = infer_from_pd(cfg, net, infer_pd, SeqTokenizer,attention_flag)

        if model == 'Seq':
            #only the first half of the embedding columns is kept for the 'Seq' model
            gene_embedds_df = gene_embedds_df.iloc[:,:int(gene_embedds_df.shape[1]/2)]
    if logits_flag:
        cfg['log_logits'] = True
    prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len)
    infer_pd = all_data["infere_rna_seq"]

    if logits_flag:
        #keep only the logits columns and drop the outer level of their two-level header
        logits_cols = [col for col in infer_pd.columns if "Logits" in col]
        logits_df = infer_pd[logits_cols]
        logits_df.columns = pd.MultiIndex.from_tuples(logits_df.columns, names=["Logits", "Sub Class"])
        logits_df.columns = logits_df.columns.droplevel(0)
        return logits_df

    if attention_flag:
        return attn_scores_pd

    if embedds_flag:
        return gene_embedds_df

    #remaining outputs (similarity table, umap table, prediction table) all need the nearest training neighbours
    embedds_path = '/'.join(cfg['inference_settings']["model_path"].split('/')[:-2])+'/embedds'
    results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
    results.get_knn_model()
    lv_threshold = load(results.analysis_path+"/novelty_model_coef")["Threshold"]
    logger.info('computing levenstein distance for the inference set')
    #prepare infer split
    gene_embedds_df.columns = results.embedds_cols[:len(gene_embedds_df.columns)]
    #add index of gene_embedds_df to be a column with name results.seq_col
    gene_embedds_df[results.seq_col] = gene_embedds_df.index
    #set gene_embedds_df as the new infer split
    results.splits_df_dict['infer_df'] = gene_embedds_df

    _,_,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,'infer',num_neighbors=n_sim)

    if similarity_flag:
        sim_df = pd.DataFrame()
        #repeat every query sequence n_sim times so it lines up with its n_sim neighbours
        query_seqs = gene_embedds_df.index.tolist()
        sim_df['Sequence'] = [seq for seq in query_seqs for _ in range(n_sim)]
        sim_df['Explanatory Sequence'] = top_n_seqs
        sim_df['NLD'] = lev_dist
        sim_df['Explanatory Label'] = top_n_labels
        sim_df['Novelty Threshold'] = lv_threshold
        #for every query sequence, order the NLD in an increasing order
        sim_df = sim_df.sort_values(by=['Sequence','NLD'],ascending=[False,True])
        return sim_df

    logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}')
    #the neighbour lists are flat: every consecutive run of n_sim entries belongs to one query sequence
    chunk_starts = range(0,len(lev_dist),n_sim)
    #position of the smallest levenstein distance inside every chunk, computed once and reused below
    closest_pos = [np.argmin(lev_dist[start:start+n_sim]) for start in chunk_starts]
    lv_dist_closest = [min(lev_dist[start:start+n_sim]) for start in chunk_starts]
    top_n_labels_closest = [top_n_labels[start:start+n_sim][pos] for start,pos in zip(chunk_starts,closest_pos)]
    top_n_seqs_closest = [top_n_seqs[start:start+n_sim][pos] for start,pos in zip(chunk_starts,closest_pos)]
    #a sequence is familiar when its closest neighbour is strictly below the novelty threshold
    infer_pd['Is Familiar?'] = [bool(lv<lv_threshold) for lv in lv_dist_closest]

    if umap_flag:
        #compute umap
        logger.info('computing umap for the inference set')
        gene_embedds_df = gene_embedds_df.drop(results.seq_col,axis=1)
        umap = UMAP(n_components=2,random_state=42)
        scaled_embedds = StandardScaler().fit_transform(gene_embedds_df.values)
        gene_embedds_df = pd.DataFrame(umap.fit_transform(scaled_embedds),columns=['UMAP1','UMAP2'])
        gene_embedds_df['Net-Label'] = infer_pd['Net-Label'].values
        gene_embedds_df['Is Familiar?'] = infer_pd['Is Familiar?'].values
        gene_embedds_df['Explanatory Label'] = top_n_labels_closest
        gene_embedds_df['Explanatory Sequence'] = top_n_seqs_closest
        gene_embedds_df['Sequence'] = infer_pd.index
        return gene_embedds_df

    #override threshold
    infer_pd['Novelty Threshold'] = lv_threshold
    infer_pd['NLD'] = lv_dist_closest
    infer_pd['Explanatory Label'] = top_n_labels_closest
    infer_pd['Explanatory Sequence'] = top_n_seqs_closest
    infer_pd = infer_pd.round({"NLD": 2, "Novelty Threshold": 2})
    logger.info(f'num of new hico based on levenstein distance is {np.sum(infer_pd["Is Familiar?"])}')
    return infer_pd.rename_axis("Sequence").reset_index()

def predict_transforna_all_models(sequences: List[str], mc_or_sc:str = 'sub_class',logits_flag: bool = False, attention_flag: bool = False,
        similarity_flag: bool = False, n_sim:int = 3,
        embedds_flag:bool=False, umap_flag:bool = False, trained_on:str="full",path_to_models:str='') -> pd.DataFrame:
    """
    Predicts the labels of the sequences using all the models available in the transforna package.
    If none of the flags are true, it constructs and aggregates the output of the ensemble model.

    Input:
        sequences: list of sequences to predict
        mc_or_sc: models trained on major class or sub class
        logits_flag: whether to return logits
        attention_flag: whether to return attention scores (obtained from the self-attention layer)
        similarity_flag: whether to return explanatory/similar sequences in the training set
        n_sim: number of similar sequences to return
        embedds_flag: whether to return embeddings of the sequences
        umap_flag: whether to return umap embeddings
        trained_on: whether to use the model trained on the full dataset or the ID dataset
        path_to_models: directory that contains the trained model folders
    Output:
        df: dataframe with the per-model outputs stacked row-wise; a 'Model' column names
            the model each row came from
    """
    #take the difference of two datetimes directly; formatting to "%H:%M:%S" and parsing back
    #drops sub-second precision and yields a negative duration when a run crosses midnight
    start_time = datetime.now()

    models = ["Baseline","Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]
    #NOTE(review): a former "remove baseline" branch for similarity_flag/embedds_flag assigned this
    #very same list, so Baseline has always been included for those flags; that behaviour is kept
    if attention_flag:
        #attention scores are only requested from these models
        models = ["Seq", "Seq-Struct", "Seq-Rev"]

    per_model_dfs = []
    for model in models:
        logger.info(model)
        df_ = predict_transforna(sequences, model, mc_or_sc,logits_flag,attention_flag,similarity_flag,n_sim,embedds_flag,umap_flag,trained_on=trained_on,path_to_models = path_to_models)
        df_["Model"] = model
        per_model_dfs.append(df_)
    #concatenate once instead of re-copying the growing frame on every iteration
    df = pd.concat(per_model_dfs, axis=0)

    #aggregate ensemble model if none of the flags are true
    if not any([logits_flag,attention_flag,similarity_flag,embedds_flag,umap_flag]):
        df = aggregate_ensemble_model(df)

    delta_time = datetime.now() - start_time
    logger.info(f"Time taken: {delta_time}")

    return df


if __name__ == "__main__":
    def _str_to_bool(value: str) -> bool:
        """Convert an explicit command-line value such as ``--logits_flag false`` to a bool."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        #argparse turns a ValueError raised by a type callable into a usage error
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("sequences", nargs="+")
    #a bare flag (e.g. --logits_flag) still means True; an explicit value is now parsed
    #instead of being passed on as a non-empty string
    for flag_name in ("logits_flag", "attention_flag", "similarity_flag", "embedds_flag", "umap_flag"):
        parser.add_argument(f"--{flag_name}", nargs="?", const=True, default=False, type=_str_to_bool)
    #without type=int a value given on the command line arrives as str and breaks range(n_sim)
    parser.add_argument("--n_sim", nargs="?", const=3, default=3, type=int)
    parser.add_argument("--mc_or_sc", default="sub_class")
    #trained_on is a string option, so a bare --trained_on falls back to "full" rather than True
    parser.add_argument("--trained_on", nargs="?", const="full", default="full")
    parser.add_argument("--path_to_models", default="")
    #NOTE(review): the returned dataframe is not printed or saved here - confirm whether that is intended
    predict_transforna_all_models(**vars(parser.parse_args()))