|
import os |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from sklearn.linear_model import Ridge, Lasso, LinearRegression |
|
|
|
from utils import seqs_to_onehot, read_fasta, load_rows_by_numbers |
|
from predictors.base_predictors import BaseRegressionPredictor, BaseGPPredictor |
|
|
|
|
|
class ESMPredictor(BaseRegressionPredictor): |
|
"""ESM likelihood as features in regression.""" |
|
|
|
def __init__(self, dataset_name, rep_name, reg_coef=1e-8, path_prefix='', |
|
**kwargs): |
|
super(ESMPredictor, self).__init__(dataset_name, reg_coef, Ridge) |
|
|
|
seqs_path = path_prefix + os.path.join('data', dataset_name, 'seqs.fasta') |
|
seqs = read_fasta(seqs_path) |
|
id2seq = pd.Series(index=np.arange(len(seqs)), data=seqs, name='seq') |
|
|
|
esm_data_path = path_prefix + os.path.join('inference', dataset_name, |
|
'esm', rep_name, 'pll.csv') |
|
ll = pd.read_csv(esm_data_path, index_col=0) |
|
ll['id'] = ll.index.to_series().apply( |
|
lambda x: int(x.replace('id_', ''))) |
|
ll = ll.join(id2seq, on='id', how='left') |
|
self.seq2score_dict = dict(zip(ll.seq, ll.pll)) |
|
|
|
def seq2score(self, seqs): |
|
scores = np.array([self.seq2score_dict.get(s, 0.0) for s in seqs]) |
|
return scores |
|
|
|
def seq2feat(self, seqs): |
|
return self.seq2score(seqs)[:, None] |
|
|
|
def predict_unsupervised(self, seqs): |
|
return self.seq2score(seqs) |
|
|
|
|
|
class GlobalESMPredictor(ESMPredictor): |
|
def __init__(self, dataset_name, **kwargs): |
|
super(GlobalESMPredictor, self).__init__( |
|
dataset_name, 'global', **kwargs) |
|
|
|
|
|
class EvotunedESMPredictor(ESMPredictor): |
|
def __init__(self, dataset_name, rep_name='uniref100', **kwargs): |
|
super(EvotunedESMPredictor, self).__init__( |
|
dataset_name, rep_name, **kwargs) |
|
|
|
|
|
class ESMRegressionPredictor(BaseRegressionPredictor): |
|
"""Regression on ESM representation.""" |
|
|
|
def __init__(self, dataset_name, rep_name, reg_coef=1.0, **kwargs): |
|
super(ESMRegressionPredictor, self).__init__(dataset_name, reg_coef, |
|
Ridge, **kwargs) |
|
self.load_rep(dataset_name, rep_name) |
|
|
|
def load_rep(self, dataset_name, rep_name): |
|
self.rep_path = os.path.join('inference', dataset_name, 'esm', |
|
rep_name, 'rep.npy*') |
|
self.seq_path = os.path.join('data', dataset_name, 'seqs.fasta') |
|
self.seqs = read_fasta(self.seq_path) |
|
self.seq2id = dict(zip(self.seqs, range(len(self.seqs)))) |
|
|
|
def seq2feat(self, seqs): |
|
"""Look up representation by sequences.""" |
|
ids = [self.seq2id[s] for s in seqs] |
|
return load_rows_by_numbers(self.rep_path, ids) |
|
|
|
|
|
class EvotunedESMRegressionPredictor(ESMRegressionPredictor): |
|
def __init__(self, dataset_name, **kwargs): |
|
super(EvotunedESMRegressionPredictor, self).__init__(dataset_name, |
|
'uniref100', **kwargs) |
|
|
|
|
|
class GlobalESMRegressionPredictor(ESMRegressionPredictor): |
|
def __init__(self, dataset_name, **kwargs): |
|
super(GlobalESMRegressionPredictor, self).__init__(dataset_name, |
|
'global', **kwargs) |
|
|
|
|
|
|