import logging
import math
import os
import warnings
from random import randint
import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import as_strided
from omegaconf import DictConfig, open_dict
from ..utils import energy
from ..utils.file import save
logger = logging.getLogger(__name__)
class SeqTokenizer:
    '''
    Base tokenizer containing the tokenization utilities that data-specific classes inherit from.
    '''
def __init__(self,seqs_dot_bracket_labels: pd.DataFrame, config: DictConfig):
self.seqs_dot_bracket_labels = seqs_dot_bracket_labels.reset_index(drop=True)
#shuffle
if not config["inference"]:
self.seqs_dot_bracket_labels = self.seqs_dot_bracket_labels\
.sample(frac=1)\
.reset_index(drop=True)
#get input of model
self.model_input = config["model_config"].model_input
        # set max length to be <= 2 stds of the distribution of lengths
if config["train_config"].filter_seq_length:
self.get_outlier_length_threshold()
self.limit_seqs_to_range()
else:
self.max_length = self.seqs_dot_bracket_labels['Sequences'].str.len().max()
self.min_length = 0
with open_dict(config):
config["model_config"]["max_length"] = np.int64(self.max_length).item()
config["model_config"]["min_length"] = np.int64(self.min_length).item()
self.window = config["model_config"].window
self.tokens_len = math.ceil(self.max_length / self.window)
if config["model_config"].tokenizer in ["overlap", "overlap_multi_window"]:
self.tokens_len = int(self.max_length - (config["model_config"].window - 1))
self.tokenizer = config["model_config"].tokenizer
self.seq_len_dist = self.seqs_dot_bracket_labels['Sequences'].str.len().value_counts()
#init tokens dict
self.seq_tokens_ids_dict = {}
self.second_input_tokens_ids_dict = {}
#get and set number of labels in config to be used later by the model
config["model_config"].num_classes = len(self.seqs_dot_bracket_labels['Labels'].unique())
self.set_class_attr()
def get_outlier_length_threshold(self):
lengths_arr = self.seqs_dot_bracket_labels['Sequences'].str.len()
mean = np.mean(lengths_arr)
standard_deviation = np.std(lengths_arr)
distance_from_mean = abs(lengths_arr - mean)
in_distribution = distance_from_mean < 2 * standard_deviation
inlier_lengths = np.sort(lengths_arr[in_distribution].unique())
self.max_length = int(np.max(inlier_lengths))
self.min_length = int(np.min(inlier_lengths))
        logger.info(f'Maximum and minimum sequence lengths are set to: {self.max_length} and {self.min_length}')
return
def limit_seqs_to_range(self):
        '''
        Trims sequences longer than max_length and removes sequences shorter than min_length.
        '''
df = self.seqs_dot_bracket_labels
min_to_be_deleted = []
num_longer_seqs = sum(df['Sequences'].str.len()>self.max_length)
if num_longer_seqs:
logger.info(f"Number of sequences to be trimmed: {num_longer_seqs}")
        for idx, seq in enumerate(df['Sequences']):
            if len(seq) > self.max_length:
                # trim to max_length; use .loc to avoid chained-indexing assignment
                df.loc[idx, 'Sequences'] = seq[:self.max_length]
            elif len(seq) < self.min_length:
                # collect indices of sequences to be deleted
                min_to_be_deleted.append(idx)
        # delete sequences shorter than min_length
        if len(min_to_be_deleted):
            df = df.drop(min_to_be_deleted).reset_index(drop=True)
            logger.info(f"Number of sequences shorter than min_length to be removed: {len(min_to_be_deleted)}")
self.seqs_dot_bracket_labels = df
def get_secondary_structure(self,sequences):
secondary = energy.fold_sequences(sequences.tolist())
return secondary['structure_37'].values
    # generator yielding overlapping tokens of a feature sample (sliding window of size `window`, stride 1)
def chunkstring_overlap(self, string, window):
return (
string[0 + i : window + i] for i in range(0, len(string) - window + 1, 1)
)
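    # Illustrative example for chunkstring_overlap (not executed):
    #   list(self.chunkstring_overlap("AACTAGA", 3)) -> ['AAC', 'ACT', 'CTA', 'TAG', 'AGA']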
    # generator yielding non-overlapping tokens of a feature sample (consecutive chunks of size `window`)
def chunkstring_no_overlap(self, string, window):
return (string[0 + i : window + i] for i in range(0, len(string), window))
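    # Illustrative example for chunkstring_no_overlap (not executed):
    #   list(self.chunkstring_no_overlap("AACTAGA", 3)) -> ['AAC', 'TAG', 'A']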
    def tokenize_samples(self, window: int, sequences_to_be_tokenized, inference: bool = False, tokenizer: str = "overlap") -> tuple:
        """
        Tokenizes RNA sequences into chunks of size `window`,
        with or without overlap depending on the chosen tokenizer.
        With overlap:
            example: sequence: AACTAGA, window: 3
            output: AAC,ACT,CTA,TAG,AGA
        Without overlap:
            example: sequence: AACTAGA, window: 3
            output: AAC,TAG,A
        """
        # get feature tokens; check "no_overlap" first, since "overlap" is a substring of "no_overlap"
        if tokenizer == "no_overlap":
            feature_tokens_gen = list(
                self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized
            )
        else:  # "overlap" / "overlap_multi_window"
            feature_tokens_gen = list(
                self.chunkstring_overlap(feature, window)
                for feature in sequences_to_be_tokenized
            )
# get sample tokens and pad them
samples_tokenized = []
sample_token_ids = []
if not self.seq_tokens_ids_dict:
self.seq_tokens_ids_dict = {"pad": 0}
for gen in feature_tokens_gen:
sample_token_id = []
sample_token = list(gen)
sample_len = len(sample_token)
# append paddings
sample_token.extend(
["pad" for _ in range(int(self.tokens_len - sample_len))]
)
# convert tokens to ids
for token in sample_token:
                # if the token doesn't exist in the dict, create a new id for it
if token not in self.seq_tokens_ids_dict:
if not inference:
id = len(self.seq_tokens_ids_dict.keys())
self.seq_tokens_ids_dict[token] = id
                    else:
                        # if a new token is found during inference, replace it with a random known token (treated as noise)
                        logger.warning(f"The sequence token: {token} was not seen previously by the model. Token will be replaced by a random token")
                        random_id = randint(1, len(self.seq_tokens_ids_dict.keys()) - 1)
                        # seq_tokens_ids_dict maps token -> id, so pick the token whose id equals random_id
                        token = list(self.seq_tokens_ids_dict.keys())[random_id]
# append id of token
sample_token_id.append(self.seq_tokens_ids_dict[token])
# append ids of tokenized sample
sample_token_ids.append(np.array(sample_token_id))
sample_token = np.array(sample_token)
samples_tokenized.append(sample_token)
return (np.array(samples_tokenized), np.array(sample_token_ids))
    def tokenize_secondary_structure(self, window: int, sequences_to_be_tokenized, inference: bool = False, tokenizer: str = "overlap") -> tuple:
        """
        Tokenizes secondary structures (dot-bracket strings) into chunks of size `window`,
        with or without overlap depending on the chosen tokenizer.
        With overlap:
            example: structure: ...()..., window: 3
            output: ...,..(,.(),().,)..,...
        Without overlap:
            example: structure: ...()..., window: 3
            output: ...,().,..
        """
samples_tokenized = []
sample_token_ids = []
if not self.second_input_tokens_ids_dict:
self.second_input_tokens_ids_dict = {"pad": 0}
        # get feature tokens; check "no_overlap" first, since "overlap" is a substring of "no_overlap"
        if tokenizer == "no_overlap":
            feature_tokens_gen = list(
                self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized
            )
        else:  # "overlap" / "overlap_multi_window"
            feature_tokens_gen = list(
                self.chunkstring_overlap(feature, window)
                for feature in sequences_to_be_tokenized
            )
# get sample tokens and pad them
for seq_idx, gen in enumerate(feature_tokens_gen):
sample_token_id = []
sample_token = list(gen)
# convert tokens to ids
for token in sample_token:
                # if the token doesn't exist in the dict, create a new id for it
if token not in self.second_input_tokens_ids_dict:
if not inference:
id = len(self.second_input_tokens_ids_dict.keys())
self.second_input_tokens_ids_dict[token] = id
                    else:
                        # if a new token is found during inference, replace it with a random known token (treated as noise)
                        warnings.warn(f"The secondary structure token: {token} was not seen previously by the model. Token will be replaced by a random token")
                        random_id = randint(1, len(self.second_input_tokens_ids_dict.keys()) - 1)
                        # second_input_tokens_ids_dict maps token -> id, so pick the token whose id equals random_id
                        token = list(self.second_input_tokens_ids_dict.keys())[random_id]
# append id of token
sample_token_id.append(self.second_input_tokens_ids_dict[token])
# append ids of tokenized sample
sample_token_ids.append(sample_token_id)
samples_tokenized.append(sample_token)
        # append pads: the second input is padded to the same token length as the primary sequence
        # (an earlier scheme concatenated structures folded at different temperatures, which required a longer second input)
        self.second_input_token_len = self.tokens_len
for seq_idx, token in enumerate(sample_token_ids):
sample_len = len(token)
sample_token_ids[seq_idx].extend(
[self.second_input_tokens_ids_dict["pad"] for _ in range(int(self.second_input_token_len - sample_len))]
)
samples_tokenized[seq_idx].extend(
["pad" for _ in range(int(self.second_input_token_len - sample_len))]
)
sample_token_ids[seq_idx] = np.array(sample_token_ids[seq_idx])
samples_tokenized[seq_idx] = np.array(samples_tokenized[seq_idx])
return (np.array(samples_tokenized), np.array(sample_token_ids))
def set_class_attr(self):
        # set seq, struct (if used by the model input) and labels
self.seq = self.seqs_dot_bracket_labels["Sequences"]
if 'struct' in self.model_input:
self.struct = self.seqs_dot_bracket_labels["Secondary"]
self.labels = self.seqs_dot_bracket_labels['Labels']
def prepare_multi_idx_pd(self,num_coln,pd_name,pd_value):
iterables = [[pd_name], np.arange(num_coln)]
index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
return pd.DataFrame(columns=index, data=pd_value)
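    # Illustrative example for prepare_multi_idx_pd (not executed):
    #   self.prepare_multi_idx_pd(3, "tokens_id", np.array([[1, 2, 0]]))
    # returns a one-row DataFrame whose columns form the MultiIndex
    #   ("tokens_id", 0), ("tokens_id", 1), ("tokens_id", 2)
    # with level names ("type of data", "indices").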
def phase_sequence(self,sample_token_ids):
phase0 = sample_token_ids[:,::2]
phase1 = sample_token_ids[:,1::2]
        # if max_length is odd, phase0 is one entry larger than phase1 along dim=1, so pad phase1 with a zero column
        if phase0.shape != phase1.shape:
phase1 = np.concatenate([phase1,np.zeros(phase1.shape[0])[...,np.newaxis]],axis=1)
sample_token_ids = phase0
return sample_token_ids,phase1
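    # Illustrative example for phase_sequence (not executed): for token ids [[1, 2, 3, 4, 5]]
    # (odd number of columns), phase0 = [[1, 3, 5]] and phase1 = [[2, 4, 0]] after zero-padding.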
def custom_roll(self,arr, n_shifts_per_row):
'''
shifts each row of a numpy array according to n_shifts_per_row
'''
m = np.asarray(n_shifts_per_row)
arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].copy() #need `copy`
strd_0, strd_1 = arr_roll.strides
n = arr.shape[1]
result = as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1))
return result[np.arange(arr.shape[0]), (n-m)%n]
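    # Illustrative example for custom_roll (not executed): positive shifts roll each row to the right,
    #   custom_roll(np.array([[1, 2, 3], [4, 5, 6]]), [1, 2]) -> [[3, 1, 2], [5, 6, 4]]
    # get_tokenized_data passes negative shifts (-n_zeros) to roll left, moving the leading
    # pad ids of each reversed sequence back to the end.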
def save_token_dicts(self):
#save token dicts
save(data = self.second_input_tokens_ids_dict,path = os.getcwd()+'/second_input_tokens_ids_dict')
save(data = self.seq_tokens_ids_dict,path = os.getcwd()+'/seq_tokens_ids_dict')
def get_tokenized_data(self,inference:bool=False):
#tokenize sequences
samples_tokenized,sample_token_ids = self.tokenize_samples(self.window,self.seq,inference)
logger.info(f'Vocab size for primary sequences: {len(self.seq_tokens_ids_dict.keys())}')
logger.info(f'Vocab size for secondary structure: {len(self.second_input_tokens_ids_dict.keys())}')
logger.info(f'Number of sequences used for tokenization: {samples_tokenized.shape[0]}')
        # build the second input from the sequence complement if used
        if "comp" in self.model_input:
            # get the complement of self.seq
self.seq_comp = []
for feature in self.seq:
feature = feature.replace('A','%temp%').replace('T','A')\
.replace('C','%temp2%').replace('G','C')\
.replace('%temp%','T').replace('%temp2%','G')
self.seq_comp.append(feature)
            # temporarily stash the primary-sequence vocab so the complement builds its own
            self.seq_tokens_ids_dict_temp = self.seq_tokens_ids_dict
            self.seq_tokens_ids_dict = None
            # tokenize the complement
_,seq_comp_str_token_ids = self.tokenize_samples(self.window,self.seq_comp,inference)
sec_input_value = seq_comp_str_token_ids
            # assign the complement vocab to the second input and restore the primary-sequence vocab
            self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict
            self.seq_tokens_ids_dict = self.seq_tokens_ids_dict_temp
#tokenize struct if used
if "struct" in self.model_input:
_,sec_str_token_ids = self.tokenize_secondary_structure(self.window,self.struct,inference)
sec_input_value = sec_str_token_ids
#add seq-seq if used
if "seq-seq" in self.model_input:
sample_token_ids,sec_input_value = self.phase_sequence(sample_token_ids)
self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict
#in case of baseline or only "seq", the second input is dummy
#TODO:: refactor transforna to accept models with a single input (baseline and "seq")
# without occupying unnecessary resources
if "seq-rev" in self.model_input or "baseline" in self.model_input or self.model_input == 'seq':
sample_token_ids_rev = sample_token_ids[:,::-1]
n_zeros = np.count_nonzero(sample_token_ids_rev==0, axis=1)
sec_input_value = self.custom_roll(sample_token_ids_rev, -n_zeros)
self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict
        # number of non-pad tokens per sequence
        seqs_length = list(sum(sample_token_ids.T != 0))
labels_df = self.prepare_multi_idx_pd(1,"Labels",self.labels.values)
tokens_id_df = self.prepare_multi_idx_pd(sample_token_ids.shape[1],"tokens_id",sample_token_ids)
tokens_df = self.prepare_multi_idx_pd(samples_tokenized.shape[1],"tokens",samples_tokenized)
sec_input_df = self.prepare_multi_idx_pd(sec_input_value.shape[1],'second_input',sec_input_value)
seqs_length_df = self.prepare_multi_idx_pd(1,"seqs_length",seqs_length)
all_df = labels_df.join(tokens_df).join(tokens_id_df).join(sec_input_df).join(seqs_length_df)
#save token dicts
self.save_token_dicts()
return all_df
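# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; kept commented out so that importing
# this module has no side effects). The toy config below defines only the keys
# that SeqTokenizer actually reads; the real configs shipped with TransfoRNA
# may contain additional fields.
#
# from omegaconf import OmegaConf
# toy_df = pd.DataFrame({
#     "Sequences": ["AACTAGA", "GGCTTAC"],
#     "Labels": ["class_a", "class_b"],
# })
# toy_config = OmegaConf.create({
#     "inference": False,
#     "model_config": {"model_input": "seq", "window": 3,
#                      "tokenizer": "overlap", "num_classes": 0},
#     "train_config": {"filter_seq_length": False},
# })
# tokenizer = SeqTokenizer(toy_df, toy_config)
# all_df = tokenizer.get_tokenized_data()  # also writes the token dicts to the current working directory
# ---------------------------------------------------------------------------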