|
import os |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from sklearn.linear_model import Ridge, Lasso, LinearRegression |
|
|
|
from utils import load, load_rows_by_numbers |
|
from utils import seqs_to_onehot, get_wt_seq, read_fasta, seq2effect |
|
from predictors.base_predictors import BaseRegressionPredictor, BaseGPPredictor |
|
from predictors.hmm_predictors import HMMPredictor |
|
|
|
|
|
class BaseUniRepPredictor(BaseRegressionPredictor):
    """UniRep representation + regression.

    Features are precomputed UniRep outputs stored under
    ``inference/<dataset_name>/unirep/<rep_name>/``. Sequences are mapped to
    row numbers through the ``seqs.npy`` listing, and the matching rows are
    fetched with ``load_rows_by_numbers``.
    """

    def __init__(self, dataset_name, rep_name, reg_coef=1.0, **kwargs):
        """Set up a Ridge-based predictor and index the stored sequences.

        Args:
            dataset_name: Dataset directory name under ``inference/``.
            rep_name: Sub-directory of ``unirep/`` holding the representation.
            reg_coef: Regularization coefficient handed to the base class.
            **kwargs: Accepted for signature compatibility with callers;
                not forwarded to the base class.
        """
        super(BaseUniRepPredictor, self).__init__(
            dataset_name, reg_coef, Ridge)
        self.load_rep(dataset_name, rep_name)

    def load_rep(self, dataset_name, rep_name):
        """Record the representation paths and build the sequence index.

        Sets ``rep_path``, ``seq_path``, ``seqs`` and ``seq2id`` on the
        instance.
        """
        rep_dir = os.path.join('inference', dataset_name, 'unirep', rep_name)
        # NOTE(review): the trailing '*' is presumably a glob pattern that
        # load_rows_by_numbers resolves to one or more files -- confirm.
        self.rep_path = os.path.join(rep_dir, 'avg_hidden.npy*')
        self.seq_path = os.path.join(rep_dir, 'seqs.npy')

        # Despite the .npy extension the sequence file is read as plain text.
        # ndmin=1 keeps the result 1-D even when the file holds a single
        # sequence; without it loadtxt returns a 0-d array and the len()/zip
        # style indexing below raises TypeError.
        self.seqs = np.loadtxt(
            self.seq_path, dtype=str, delimiter=' ', ndmin=1)
        # If a sequence occurs more than once, the last row number wins.
        self.seq2id = {seq: row for row, seq in enumerate(self.seqs)}

    def seq2feat(self, seqs):
        """Look up representation by sequence.

        Args:
            seqs: Iterable of sequences present in the stored listing.

        Returns:
            Whatever ``load_rows_by_numbers`` returns for the matching rows
            of ``rep_path``.

        Raises:
            KeyError: If a sequence is not in the stored listing.
        """
        ids = [self.seq2id[s] for s in seqs]
        return load_rows_by_numbers(self.rep_path, ids)
|
|
|
|
|
class GUniRepRegressionPredictor(BaseUniRepPredictor):
    """Global UniRep + Ridge regression.

    Thin wrapper over ``BaseUniRepPredictor`` that fixes the representation
    name to ``'global'``.
    """

    def __init__(self, dataset_name, **kwargs):
        """Build the predictor for ``dataset_name``; ``kwargs`` pass through
        to ``BaseUniRepPredictor.__init__``."""
        super().__init__(dataset_name, 'global', **kwargs)
|
|
|
|
|
class EUniRepRegressionPredictor(BaseUniRepPredictor):
    """Evotuned UniRep + Ridge regression.

    Thin wrapper over ``BaseUniRepPredictor`` whose representation name
    defaults to ``'uniref100'``.
    """

    def __init__(self, dataset_name, rep_name='uniref100', **kwargs):
        """Build the predictor for ``dataset_name`` using ``rep_name``;
        ``kwargs`` pass through to ``BaseUniRepPredictor.__init__``."""
        super().__init__(dataset_name, rep_name, **kwargs)
|
|
|
|
|
class UniRepLLPredictor(BaseUniRepPredictor):
    """UniRep log likelihood.

    Uses the negated values stored under ``loss.npy*`` as the feature for
    each sequence instead of the averaged hidden representation.
    """

    def __init__(self, dataset_name, rep_name, reg_coef=1e-8, **kwargs):
        """Initialise the base predictor and record where losses live.

        Args:
            dataset_name: Dataset directory name under ``inference/``.
            rep_name: Sub-directory of ``unirep/`` holding the outputs.
            reg_coef: Regularization coefficient passed to the base class.
            **kwargs: Passed through to ``BaseUniRepPredictor.__init__``.
        """
        super().__init__(dataset_name, rep_name, reg_coef=reg_coef, **kwargs)
        unirep_dir = os.path.join('inference', dataset_name, 'unirep',
                                  rep_name)
        self.loss_path = os.path.join(unirep_dir, 'loss.npy*')

    def seq2feat(self, seqs):
        """Look up log likelihood by sequence.

        Returns the negation of the rows loaded from ``loss_path``.
        Raises ``KeyError`` for a sequence missing from the stored listing.
        """
        row_numbers = list(map(self.seq2id.__getitem__, seqs))
        losses = load_rows_by_numbers(self.loss_path, row_numbers)
        # NOTE(review): negating presumably turns a per-sequence loss into a
        # log-likelihood score -- confirm against how loss.npy is written.
        return -losses

    def predict_unsupervised(self, seqs):
        """Return the flattened per-sequence features, with no fitted model
        involved."""
        features = self.seq2feat(seqs)
        return features.ravel()
|
|
|
|
|
class GUniRepLLPredictor(UniRepLLPredictor):
    """``UniRepLLPredictor`` with the representation name fixed to
    ``'global'``."""

    def __init__(self, dataset_name, **kwargs):
        """Build the predictor for ``dataset_name``; ``kwargs`` pass through
        to ``UniRepLLPredictor.__init__``."""
        super().__init__(dataset_name, 'global', **kwargs)
|
|
|
|
|
class EUniRepLLPredictor(UniRepLLPredictor):
    """``UniRepLLPredictor`` whose representation name defaults to
    ``'uniref100'``."""

    def __init__(self, dataset_name, rep_name='uniref100', **kwargs):
        """Build the predictor for ``dataset_name`` using ``rep_name``;
        ``kwargs`` pass through to ``UniRepLLPredictor.__init__``."""
        super().__init__(dataset_name, rep_name, **kwargs)
|
|