File size: 3,408 Bytes
0b11a42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging
from typing import Dict

from anndata import AnnData
from omegaconf import DictConfig, OmegaConf

from ..callbacks.metrics import accuracy_score
from ..novelty_prediction.id_vs_ood_entropy_clf import compute_entropies
from ..novelty_prediction.id_vs_ood_nld_clf import compute_nlds
from ..processing.augmentation import DataAugmenter
from ..processing.seq_tokenizer import SeqTokenizer
from ..processing.splitter import *
from ..processing.splitter import DataSplitter
from ..score.score import (compute_score_benchmark, compute_score_tcga,
                           infere_additional_test_data)
from ..utils.file import load, save
from ..utils.utils import (instantiate_predictor, prepare_data_benchmark,
                           set_seed_and_device, sync_skorch_with_config)

logger = logging.getLogger(__name__)

def compute_cv(cfg:DictConfig,path:str,output_dir:str):
    """Train several seeded replicates via ``train`` and save a summary table.

    Replicate ``i`` is trained with ``cfg["seed"]`` set to ``i``; note that the
    passed config object is mutated in place. For each replicate the test score
    returned by ``train`` is recorded together with the summed ``'dur'`` of all
    history epochs up to and including the last one whose ``'val_acc_best'``
    entry is truthy.

    Args:
        cfg: run configuration; must provide ``"num_replicates"`` plus every
            key that ``train`` reads.
        path: run directory; the summary is saved to ``path + '/summary_pd'``.
        output_dir: forwarded unchanged to ``train``.

    Returns:
        None. The summary (a DataFrame indexed ``0..num_replicates-1`` with
        columns ``'B. Acc'`` and ``'Dur'``) is persisted through ``save``.
    """
    n_replicates = cfg["num_replicates"]
    summary_pd = pd.DataFrame(
        index=np.arange(n_replicates),
        columns=['B. Acc', 'Dur'],
    )

    for replicate in range(n_replicates):
        logger.info(f"Currently training replicate {replicate}")
        cfg["seed"] = replicate
        test_score, net = train(cfg, path=path, output_dir=output_dir)

        # Last epoch flagged as the best validation accuracy so far; the
        # "convergence" duration is the training time spent up to that epoch.
        best_flags = net.history[:, 'val_acc_best']
        last_best_epoch = np.nonzero(best_flags)[0][-1]
        epoch_durations = net.history[:, 'dur']
        time_to_converge = sum(epoch_durations[:last_best_epoch + 1])

        summary_pd.iloc[replicate] = [test_score, time_to_converge]

    summary_file = path + '/summary_pd'
    save(path=summary_file, data=summary_pd)

def train(cfg:Dict= None,path:str = None,output_dir:str = None):
    """Train a single model replicate and score it on held-out data.

    Steps:
      1. seed RNGs / select the device via ``set_seed_and_device``;
      2. load the training dataset and build ``all_data``. Tasks ``premirna``
         and ``sncrna`` tokenize the training set and load a separate test set
         (``prepare_data_benchmark``). Every other task is augmented with
         ``DataAugmenter`` and split with ``DataSplitter.prepare_data_tcga``;
      3. sync the skorch config, instantiate the predictor and fit it;
      4. save the resolved config to ``path + '/meta/hp_settings.yaml'``;
      5. compute test scores (plus NLD/entropy outputs for ``tcga``).

    Args:
        cfg: run configuration. Despite the ``None`` default it is required.
            It is indexed with ``"tensorboard"``, ``"seed"``,
            ``"device_number"``, ``"train_config"``, ``"task"`` and
            ``"model"``, and passed to ``OmegaConf.to_container``, so it is
            presumably a ``DictConfig``.
        path: run directory used for the predictor, metadata and scores.
        output_dir: directory passed to ``compute_nlds`` / ``compute_entropies``
            for the ``tcga`` task.

    Returns:
        ``(test_score, net)``: the score returned by ``compute_score_tcga`` or
        ``compute_score_benchmark``, and the fitted predictor.
    """
    # Read the flag once: ``cfg`` is handed to helpers below that may mutate
    # it, and the writer must be closed iff it was imported here.
    use_tensorboard = cfg['tensorboard']
    if use_tensorboard:
        from ..callbacks.tbWriter import writer

    try:
        set_seed_and_device(cfg["seed"],cfg["device_number"])

        dataset = load(cfg["train_config"].dataset_path_train)
        if isinstance(dataset,AnnData):
            # AnnData input: work on its ``.var`` table.
            dataset = dataset.var
        else:
            # Other inputs are indexed by their 'sequence' column.
            dataset.set_index('sequence',inplace=True)

        if cfg["task"] in ["premirna","sncrna"]:
            # Benchmark tasks: the test set is a separate file.
            tokenizer = SeqTokenizer(dataset,cfg)
            test_ad = load(cfg["train_config"].dataset_path_test)
            all_data = prepare_data_benchmark(tokenizer,test_ad,cfg)
        else:
            df = DataAugmenter(dataset,cfg).get_augmented_df()
            tokenizer = SeqTokenizer(df,cfg)
            all_data = DataSplitter(tokenizer,cfg).prepare_data_tcga()

        # Sync skorch config with params in the train and model configs,
        # then build the predictor.
        sync_skorch_with_config(cfg["model"]["skorch_model"],cfg)
        net = instantiate_predictor(cfg["model"]["skorch_model"], cfg,path)

        # NOTE(review): the validation set is passed positionally; per the
        # original comment it is discarded when train_split is None —
        # presumably handled by the predictor's ``fit``; confirm.
        net.fit(all_data["train_data"],all_data["train_labels_numeric"],all_data["valid_ds"])

        # Log train and model hyper-parameters to the current run dir.
        save(data=OmegaConf.to_container(cfg, resolve=True),path=path+'/meta/hp_settings.yaml')

        if cfg['task'] == 'tcga':
            test_score = compute_score_tcga(net, all_data,path,cfg)
            compute_nlds(output_dir)
            compute_entropies(output_dir)
        else:
            test_score = compute_score_benchmark(net, path,all_data,accuracy_score,cfg)

        # Per the original author, this key is only produced for premirna.
        if "additional_testset" in all_data:
            infere_additional_test_data(net,all_data["additional_testset"])
    finally:
        # Close the writer on both success and failure so it is not leaked
        # when training or scoring raises.
        if use_tensorboard:
            writer.close()

    return test_score,net