import logging
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import (compute_class_weight,
compute_sample_weight)
from skorch.dataset import Dataset
from skorch.helper import predefined_split
from ..utils.energy import fold_sequences
from ..utils.file import load, save
from ..utils.utils import (revert_seq_tokenization,
update_config_with_dataset_params_benchmark,
update_config_with_dataset_params_tcga)
from anndata import AnnData
logger = logging.getLogger(__name__)
class DataSplitter:
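    """Split tokenized RNA data into train/valid/test (plus artificial,
    no-annotation and, depending on the setting, out-of-distribution) partitions
    and convert them into tensors suitable for training.

    Illustrative usage (a sketch; the tokenizer and the OmegaConf-style configs
    object are constructed elsewhere in the repo):

        splitter = DataSplitter(tokenizer,configs)
        all_data = splitter.prepare_data_tcga()
        train_data,train_labels = all_data["train_data"],all_data["train_labels"]
    """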
def __init__(self,tokenizer,configs):
self.tokenizer = tokenizer
self.configs = configs
self.seed = configs.seed
self.trained_on = configs.trained_on
self.device = configs["train_config"].device
self.splits_df_dict = {}
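        #classes with fewer samples than this are duplicated when trained_on == 'full',
        #otherwise they are moved to the out-of-distribution (ood) split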
self.min_num_samples_per_class = 10
def convert_to_tensor(self,in_arr,convert_type):
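        """Convert an array-like to a tensor on the configured device:
        long for int labels, float for everything else."""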
tensor_dtype = torch.long if convert_type == int else torch.float
return torch.tensor(
np.array(in_arr, dtype=convert_type),
dtype=tensor_dtype,
).to(device=self.device)
def get_features_per_split(self):
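        """Gather the model input columns of every split into a float tensor,
        keyed as '<split>_data' (e.g. 'train_data')."""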
model_input_cols = ['tokens_id','second_input','seqs_length']
features_dict = {}
for split_df in self.splits_df_dict.keys():
split_data = self.convert_to_tensor(self.splits_df_dict[split_df][model_input_cols].values,convert_type=float)
split = '_'.join(split_df.split('_')[:-1])
features_dict[f'{split}_data'] = split_data
return features_dict
def append_sample_weights(self,splits_features_dict):
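        """Append a per-sample weight column to each split's feature tensor, in place.
        Train/valid/test get balanced weights; all other splits get a weight of one."""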
for split_df in self.splits_df_dict.keys():
if split_df in ['train_df','valid_df','test_df']:
split_weights = self.convert_to_tensor(compute_sample_weight('balanced',self.splits_df_dict[split_df]['Labels'][0]),convert_type=float)
else:
split_weights = self.convert_to_tensor(np.ones(self.splits_df_dict[split_df].shape[0]),convert_type=float)
split = '_'.join(split_df.split('_')[:-1])
splits_features_dict[f'{split}_data'] = torch.cat([splits_features_dict[f'{split}_data'],split_weights[:,None]],dim=1)
return
def get_labels_per_split(self):
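        """Return the raw labels and their numeric encodings per split, as one merged dict.
        Splits other than train/valid/test get dummy zero labels, since the encoder
        only knows the classes seen during training. Balanced class weights are also
        computed and stored in the model config."""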
#encode labels
enc = LabelEncoder()
enc.fit(self.splits_df_dict["train_df"]['Labels'])
#save mapping dict to config
self.configs["model_config"].class_mappings = enc.classes_.tolist()
labels_dict = {}
labels_numeric_dict = {}
for split_df in self.splits_df_dict.keys():
split = '_'.join(split_df.split('_')[:-1])
split_labels = self.splits_df_dict[split_df]['Labels']
if split_df in ['train_df','valid_df','test_df']:
split_labels_numeric = self.convert_to_tensor(enc.transform(split_labels), convert_type=int)
else:
split_labels_numeric = self.convert_to_tensor(np.zeros((split_labels.shape[0])), convert_type=int)
labels_dict[f'{split}_labels'] = split_labels
labels_numeric_dict[f'{split}_labels_numeric'] = split_labels_numeric
#compute class weight
class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(labels_dict['train_labels']),y=labels_dict['train_labels'][0].values)
        #OmegaConf does not support float64 as a datatype, so the weights are stored as strings
        #and converted back to floats in the criterion
self.configs['model_config'].class_weights = [str(x) for x in list(class_weights)]
return labels_dict | labels_numeric_dict
def get_seqs_per_split(self):
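        """Revert the tokenized sequences of every split back to plain RNA strings."""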
rna_seq_dict = {}
for split_df in self.splits_df_dict.keys():
split = '_'.join(split_df.split('_')[:-1])
rna_seq_dict[f'{split}_rna_seq'] = revert_seq_tokenization(self.splits_df_dict[split_df]["tokens"],self.configs)
return rna_seq_dict
def duplicate_fewer_classes(self,df):
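        """Oversample classes with fewer than min_num_samples_per_class members
        by repeating one representative per class until the minimum is reached."""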
#get quantity of each class and append it as a column
df["Quantity",'0'] = df["Labels"].groupby([0])[0].transform("count")
frequent_samples_df = df[df["Quantity",'0'] >= self.min_num_samples_per_class].reset_index(drop=True)
fewer_samples_df = df[df["Quantity",'0'] < self.min_num_samples_per_class].reset_index(drop=True)
unique_fewer_samples_df = fewer_samples_df.drop_duplicates(subset=[('Labels',0)], keep="last")
unique_fewer_samples_df['Quantity','0'] -= self.min_num_samples_per_class
unique_fewer_samples_df['Quantity','0'] = unique_fewer_samples_df['Quantity','0'].abs()
repeated_fewer_samples_df = unique_fewer_samples_df.loc[unique_fewer_samples_df.index.repeat(unique_fewer_samples_df.Quantity['0'])]
repeated_fewer_samples_df = repeated_fewer_samples_df.reset_index(drop=True)
        df = pd.concat([frequent_samples_df,repeated_fewer_samples_df,fewer_samples_df]).reset_index(drop=True)
df.drop(columns = ['Quantity'],inplace=True)
return df
def remove_fewer_samples(self,data_df):
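        """Split off all samples whose class has fewer than min_num_samples_per_class members.
        For sub_class targets the rare classes are counted directly on data_df; for
        major_class targets they are derived from the sub-class counts of the training
        dataset, so that sub_class and major_class models see the same major classes.
        Returns (fewer_samples_per_class_df, data_df)."""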
if 'sub_class' in self.configs['model_config']['clf_target']:
counts = data_df['Labels'].value_counts()
fewer_class_ids = counts[counts < self.min_num_samples_per_class].index
fewer_class_labels = [i[0] for i in fewer_class_ids]
elif 'major_class' in self.configs['model_config']['clf_target']:
            #ensure that the major classes are the same as the ones used when training for sub_class.
            #This keeps performance comparisons valid; otherwise major_class models
            #would be trained on more major classes than sub_class models
            tcga_df = load(self.configs['train_config'].dataset_path_train)
            if isinstance(tcga_df,AnnData):
                tcga_df = tcga_df.var
            #only keep high-confidence (hico) annotations
            tcga_df = tcga_df[tcga_df['hico'] == True]
            #get subclass_names with at least self.min_num_samples_per_class samples
counts = tcga_df['subclass_name'].value_counts()
all_subclasses = tcga_df['subclass_name'].unique()
selected_subclasses = counts[counts >= self.min_num_samples_per_class].index
#convert subclass_name to major_class
subclass_to_major_class_dict = load(self.configs['train_config'].mapping_dict_path)
all_major_classes = list(set([subclass_to_major_class_dict[sub_class] for sub_class in all_subclasses]))
selected_major_classes = list(set([subclass_to_major_class_dict[sub_class] for sub_class in selected_subclasses]))
fewer_class_labels = list(set(all_major_classes) - set(selected_major_classes))
        #split off samples whose label belongs to the fewer classes
fewer_samples_per_class_df = data_df.loc[data_df['Labels'].isin(fewer_class_labels).values, :]
fewer_ids = data_df.index.isin(fewer_samples_per_class_df.index)
data_df = data_df[~fewer_ids]
return fewer_samples_per_class_df,data_df
def split_tcga(self,data_df):
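        """Populate self.splits_df_dict with train/valid/test dataframes plus the
        artificial, no_annotation and (unless trained on the full data) ood splits."""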
        #set aside artificial sequences (random, recombined, artificial_affix)
artificial_df = data_df.loc[data_df['Labels'][0].isin(['random','recombined','artificial_affix'])]
art_ids = data_df.index.isin(artificial_df.index)
data_df = data_df[~art_ids]
data_df = data_df.reset_index(drop=True)
        #set aside sequences with no annotation
        no_annotation_df = data_df.loc[data_df['Labels'].isnull().values]
        n_a_ids = data_df.index.isin(no_annotation_df.index)
        data_df = data_df[~n_a_ids].reset_index(drop=True)
        no_annotation_df = no_annotation_df.reset_index(drop=True)
if self.trained_on == 'full':
            #duplicate rare classes, otherwise train_test_split fails to stratify classes with too few samples
data_df = self.duplicate_fewer_classes(data_df)
ood_dict = {}
else:
ood_df,data_df = self.remove_fewer_samples(data_df)
ood_dict = {"ood_df":ood_df}
#split data
train_df,valid_test_df = train_test_split(data_df,stratify=data_df["Labels"],train_size=0.8,random_state=self.seed)
if self.trained_on == 'id':
valid_df,test_df = train_test_split(valid_test_df,stratify=valid_test_df["Labels"],train_size=0.5,random_state=self.seed)
else:
            #all sequences should be used for training; however, unseen samples are still needed
            #for novelty prediction, otherwise the NLD on the test set would be zero.
            #Hold out one sample from each class as test_df
test_df = valid_test_df.drop_duplicates(subset=[('Labels',0)], keep="last")
test_ids = valid_test_df.index.isin(test_df.index)
valid_df = valid_test_df[~test_ids].reset_index(drop=True)
            train_df = pd.concat([train_df,valid_df]).reset_index(drop=True)
        self.splits_df_dict = {"train_df":train_df,"valid_df":valid_df,"test_df":test_df,"artificial_df":artificial_df,"no_annotation_df":no_annotation_df} | ood_dict
def prepare_data_tcga(self):
"""
This function recieves tokenizer and prepares the data in a format suitable for training
It also set default parameters in the config that cannot be known until preprocessing step
is done.
"""
all_data_df = self.tokenizer.get_tokenized_data()
#split data
self.split_tcga(all_data_df)
num_samples = self.splits_df_dict['train_df'].shape[0]
        train_label_counts = self.splits_df_dict['train_df'].Labels.value_counts()
        num_classes = len(train_label_counts[train_label_counts > 0])
#log
logger.info(f'Training with {num_classes} classes and {num_samples} samples')
#get features, labels, and seqs per split
splits_features_dict = self.get_features_per_split()
self.append_sample_weights(splits_features_dict)
splits_labels_dict = self.get_labels_per_split()
splits_seqs_dict = self.get_seqs_per_split()
#prepare validation set for skorch
valid_ds = Dataset(splits_features_dict["valid_data"],splits_labels_dict["valid_labels_numeric"])
valid_ds = predefined_split(valid_ds)
#combine all dicts
all_data = splits_features_dict | splits_labels_dict | splits_seqs_dict | \
{"valid_ds":valid_ds}
###update self.configs
update_config_with_dataset_params_tcga(self.tokenizer,all_data_df,self.configs)
self.configs["model_config"].num_classes = len(all_data['train_labels'][0].unique())
self.configs["train_config"].batch_per_epoch = int(all_data["train_data"].shape[0]\
/self.configs["train_config"].batch_size)
return all_data