gzhong's picture
Upload folder using huggingface_hub
7718235 verified
import os
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from utils import load, load_rows_by_numbers
from utils import seqs_to_onehot, get_wt_seq, read_fasta, seq2effect
from predictors.base_predictors import BaseRegressionPredictor, BaseGPPredictor
from predictors.hmm_predictors import HMMPredictor
class BaseUniRepPredictor(BaseRegressionPredictor):
    """UniRep representation + regression.

    Features are precomputed UniRep representations stored under
    ``inference/<dataset_name>/unirep/<rep_name>/`` and are looked up by
    sequence; the base class is initialised with ``Ridge`` as the regressor.

    Args:
        dataset_name: Name of the dataset directory under ``inference``.
        rep_name: Name of the representation sub-directory under ``unirep``.
        reg_coef: Regularisation coefficient handed to the base predictor.
        **kwargs: Accepted for signature compatibility with sibling
            predictors; not forwarded to the base class.
    """

    def __init__(self, dataset_name, rep_name, reg_coef=1.0, **kwargs):
        # NOTE: kwargs are intentionally not forwarded; the base class is
        # called with exactly (dataset_name, reg_coef, Ridge).
        super().__init__(dataset_name, reg_coef, Ridge)
        self.load_rep(dataset_name, rep_name)

    def load_rep(self, dataset_name, rep_name):
        """Record the representation paths and build the sequence index.

        Sets ``rep_path``, ``seq_path``, ``seqs`` and ``seq2id`` on the
        instance. The representation itself is not loaded here; rows are
        fetched on demand in ``seq2feat``.
        """
        # The trailing '*' is kept verbatim; presumably
        # load_rows_by_numbers expands it to match sharded files -- confirm
        # against utils.
        self.rep_path = os.path.join('inference', dataset_name,
                                     'unirep', rep_name, 'avg_hidden.npy*')
        self.seq_path = os.path.join('inference', dataset_name,
                                     'unirep', rep_name, 'seqs.npy')
        # ndmin=1 keeps the result one-dimensional even when the file holds
        # a single sequence; without it loadtxt returns a 0-d array that
        # cannot be iterated when building the index below.
        self.seqs = np.loadtxt(self.seq_path, dtype=str, delimiter=' ',
                               ndmin=1)
        # Map each sequence to its row number in the representation file.
        # If a sequence occurs more than once, the last row wins.
        self.seq2id = {seq: row for row, seq in enumerate(self.seqs)}

    def seq2feat(self, seqs):
        """Look up representation by sequence.

        Args:
            seqs: Iterable of sequences, each of which must be present in
                the loaded sequence file.

        Returns:
            Whatever ``load_rows_by_numbers`` returns for the matching rows
            of ``rep_path``, in the order of ``seqs``.

        Raises:
            KeyError: If a sequence is not in the sequence index.
        """
        ids = [self.seq2id[s] for s in seqs]
        return load_rows_by_numbers(self.rep_path, ids)
class GUniRepRegressionPredictor(BaseUniRepPredictor):
    """Global UniRep + Ridge regression.

    Fixes the representation name to ``'global'``; all other keyword
    arguments are passed through to ``BaseUniRepPredictor``.
    """

    def __init__(self, dataset_name, **kwargs):
        super().__init__(dataset_name, 'global', **kwargs)
class EUniRepRegressionPredictor(BaseUniRepPredictor):
    """Evotuned UniRep + Ridge regression.

    Uses the representation named by ``rep_name`` (``'uniref100'`` by
    default); remaining keyword arguments are passed through to
    ``BaseUniRepPredictor``.
    """

    def __init__(self, dataset_name, rep_name='uniref100', **kwargs):
        super().__init__(dataset_name, rep_name, **kwargs)
class UniRepLLPredictor(BaseUniRepPredictor):
    """UniRep log likelihood.

    The feature for a sequence is the negated value stored for it in the
    loss file(s) under ``inference/<dataset_name>/unirep/<rep_name>/``.
    """

    def __init__(self, dataset_name, rep_name, reg_coef=1e-8, **kwargs):
        super().__init__(dataset_name, rep_name, reg_coef=reg_coef, **kwargs)
        # Set after the parent initialiser, which builds the sequence index
        # used by seq2feat.
        self.loss_path = os.path.join(
            'inference', dataset_name, 'unirep', rep_name, 'loss.npy*')

    def seq2feat(self, seqs):
        """Look up log likelihood by sequence.

        Returns the negation of the rows that ``load_rows_by_numbers``
        fetches from ``loss_path`` for the given sequences.
        """
        row_ids = [self.seq2id[seq] for seq in seqs]
        losses = load_rows_by_numbers(self.loss_path, row_ids)
        return -losses

    def predict_unsupervised(self, seqs):
        """Return the flattened log-likelihood features as predictions."""
        scores = self.seq2feat(seqs)
        return scores.ravel()
class GUniRepLLPredictor(UniRepLLPredictor):
    """UniRep log likelihood using the ``'global'`` representation name.

    All keyword arguments are passed through to ``UniRepLLPredictor``.
    """

    def __init__(self, dataset_name, **kwargs):
        super().__init__(dataset_name, 'global', **kwargs)
class EUniRepLLPredictor(UniRepLLPredictor):
    """UniRep log likelihood using a named representation.

    ``rep_name`` defaults to ``'uniref100'``; remaining keyword arguments
    are passed through to ``UniRepLLPredictor``.
    """

    def __init__(self, dataset_name, rep_name='uniref100', **kwargs):
        super().__init__(dataset_name, rep_name, **kwargs)