gzhong's picture
Upload folder using huggingface_hub
7718235 verified
import os
from Bio import SeqIO
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.ticker import MultipleLocator, FixedLocator
import seaborn as sns
from utils.io_utils import load, read_fasta
from utils.data_utils import seq2effect
# EVMutation imports
from couplings_model import CouplingsModel
def add_unirep_model(df, model_names, path, name):
    """Attach UniRep scores loaded from ``path`` as column ``name``.

    The frame is first sorted by its ``seq`` column (the original note says
    the UniRep inference output is in sorted-sequence order), so the loaded
    values are assigned row-by-row to the sorted frame.

    Args:
        df: DataFrame with a ``seq`` column. Not modified; a sorted copy is
            returned.
        model_names: list of predictor names; ``name`` is appended in place.
        path: location handed to ``load``.
            NOTE(review): presumably one value per row of ``df`` -- confirm
            against ``utils.io_utils.load``.
        name: name of the new score column.

    Returns:
        Tuple ``(sorted_df, model_names)``.
    """
    ordered = df.sort_values('seq')
    # The loaded values are scaled by minus the length of the first sequence.
    scale = -len(ordered['seq'].values[0])
    ordered[name] = scale * load(path)
    model_names.append(name)
    return ordered, model_names
def add_ev_model(df, model_names, path, name, dataset, include_indep=False):
    """Add EVmutation effect predictions to ``df`` as new column(s).

    A ``CouplingsModel`` is built from ``path`` and every sequence in
    ``df.seq`` is scored with ``seq2effect``. When ``include_indep`` is true,
    the independent (site-wise) model derived from the same couplings model
    is scored as well and stored under ``f'{name}_indep'``.

    Args:
        df: DataFrame with a ``seq`` column; new columns are added in place.
        model_names: list of predictor names; extended in place.
        path: passed to ``CouplingsModel``.
        name: column name for the couplings-model scores.
        dataset: unused; retained so existing callers keep working.
        include_indep: also add the independent-model scores.

    Returns:
        Tuple ``(df, model_names)``.
    """
    couplings_model = CouplingsModel(path)
    df[name] = seq2effect(df.seq.values, couplings_model)
    model_names.append(name)

    if include_indep:
        indep_name = f'{name}_indep'
        indep_model = couplings_model.to_independent_model()
        # NOTE(review): this call previously passed an extra wild-type
        # sequence, i.e. seq2effect(seqs, wt, indep_model), which disagrees
        # with the (seqs, model) call used for the couplings model above.
        # It now uses the same (seqs, model) form -- confirm against
        # seq2effect's signature in utils.data_utils.
        df[indep_name] = seq2effect(df.seq.values, indep_model)
        model_names.append(indep_name)

    return df, model_names
def add_hmm_model(df, model_names, path, name, dataset):
    """Attach HMM scores from a CSV to ``df`` as column ``name``.

    The CSV at ``path`` provides ``target`` (a sequence id) and
    ``score_full``. Ids are translated to sequences using
    ``../data/<dataset>/seqs.fasta``, and each row of ``df`` receives the
    score of its own sequence.

    Args:
        df: DataFrame with a ``seq`` column. Not modified; a copy sorted by
            ``seq`` is returned.
        model_names: list of predictor names; ``name`` is appended in place.
        path: CSV file with ``target`` and ``score_full`` columns.
        name: name of the new score column.
        dataset: sub-directory of ``../data`` that contains ``seqs.fasta``.

    Returns:
        Tuple ``(sorted_df, model_names)``.

    Raises:
        ValueError: if some sequence in ``df`` has no entry in the score
            table.
    """
    df = df.sort_values('seq')

    fasta_path = os.path.join('../data', dataset, 'seqs.fasta')
    ids, seqs = [], []
    for rec in SeqIO.parse(fasta_path, 'fasta'):
        ids.append(str(rec.id))
        seqs.append(str(rec.seq))
    id2seq = pd.Series(index=ids, data=seqs, name='seq')

    hmm_ll = pd.read_csv(path)[['target', 'score_full']]
    hmm_ll = hmm_ll.join(id2seq, on='target', how='left')
    # Drop targets that have no FASTA record, then keep the first score seen
    # for each distinct sequence so the lookup index is unique.
    hmm_ll = hmm_ll.dropna(subset=['seq']).drop_duplicates(subset='seq')
    score_by_seq = hmm_ll.set_index('seq')['score_full']

    # Match scores to rows by sequence rather than by position: relying on
    # two independent sorts lining up breaks as soon as df holds a duplicate
    # sequence or the two tables cover different sequence sets.
    unmatched = ~df['seq'].isin(score_by_seq.index)
    if unmatched.any():
        raise ValueError(
            f'{int(unmatched.sum())} sequence(s) in df have no HMM score '
            f'in {path}')
    df[name] = df['seq'].map(score_by_seq)

    model_names.append(name)
    return df, model_names
# Axis-label text for each metric key; metric_lineplot looks the label up by
# the metric name it is given.
metric_display_name = dict(
    ndcg='NDCG',
    topk_mean='Top 96 mean',
    spearman='Spearman correlation',
)
def retrieve_metric(df, metric_name, n_mut=None, predictor=None):
    """Return one metric column in long form, renamed to ``val``.

    Args:
        df: DataFrame with ``predictor`` and ``n_train`` columns plus the
            requested metric column.
        metric_name: base metric column name.
        n_mut: when given, the column ``f'{metric_name}_{n_mut}mut'`` is
            read instead of ``metric_name``.
        predictor: a single predictor name or an iterable of names; only
            rows whose ``predictor`` value is among them are kept. ``None``
            keeps every row.

    Returns:
        A new DataFrame with columns ``predictor``, ``n_train`` and ``val``.

    Raises:
        KeyError: if the requested metric column is missing from ``df``.
    """
    selected = df
    if predictor is not None:
        wanted = [predictor] if isinstance(predictor, str) else list(predictor)
        # Vectorised membership test instead of a per-row Python lambda.
        selected = selected.loc[selected['predictor'].isin(wanted)]

    column = metric_name if n_mut is None else f'{metric_name}_{n_mut}mut'
    subset = selected[['predictor', 'n_train', column]]
    return subset.rename(columns={column: 'val'})
def metric_lineplot(df, predictors, metric, predictor_names, dataset_name,
        max_n_mut, savename='figure', legend=None, mutcounts=None, **kwargs):
    """Plot a metric against training-set size, one panel per mutation order.

    Panel 0 shows the metric over mutants of all orders (column ``metric``);
    panel ``k`` for ``1 <= k <= max_n_mut`` shows column
    ``f'{metric}_{k}mut'``. Every panel is a seaborn line plot of the metric
    value against ``n_train`` with one line per predictor. The figure is
    saved to ``../figs/<savename>.png`` and then shown.

    Args:
        df: results table read through ``retrieve_metric`` (needs
            ``predictor``, ``n_train`` and the metric columns).
        predictors: predictor names to draw; also fixes hue/style order.
        metric: key of ``metric_display_name`` and base metric column name.
        predictor_names: unused; kept for backward compatibility.
        dataset_name: text drawn vertically to the left of the first panel.
        max_n_mut: highest mutation order that gets its own panel.
        savename: file stem of the PNG written under ``../figs/``.
        legend: optional dict with keys ``'handles'``, ``'labels'`` and
            ``'loc'`` (used as ``bbox_to_anchor``) for a figure-level legend.
        mutcounts: optional container indexed by panel number; entry ``i``
            is printed as "Data size: N" inside panel ``i``.
        **kwargs: forwarded to ``sns.lineplot``.
    """
    n_panels = max_n_mut + 1
    # squeeze=False always yields a 2-D array of Axes, so row 0 can be
    # indexed even when there is a single panel (max_n_mut == 0).
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(n_panels * 3, 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes[0]

    nmut_to_title = {
        1: 'Single mutants',
        2: 'Double mutants',
        3: 'Triple mutants',
        4: 'Quadruple mutants',
    }
    ylabel = metric_display_name[metric]

    for n_mut, ax in enumerate(axes):
        # Panel 0 pools all mutation orders: n_mut=None selects the plain
        # metric column in retrieve_metric.
        order = None if n_mut == 0 else n_mut
        data = retrieve_metric(df, metric, n_mut=order, predictor=predictors)
        sns.lineplot(data=data, x='n_train', y='val',
                     hue='predictor', style='predictor', ax=ax,
                     hue_order=predictors, style_order=predictors, **kwargs)
        if n_mut == 0:
            ax.set_title('mutants of all orders')
        else:
            # Orders without a named title fall back to the generic form.
            ax.set_title(nmut_to_title.get(n_mut, f'{n_mut}th-order Mutants'))
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Training data size')

    if mutcounts is not None:
        for i, ax in enumerate(axes):
            ax.annotate(f'Data size: {int(mutcounts[i])}',
                        xy=(0.29, 0.03), xycoords='axes fraction',
                        fontsize=9)

    lgd = None
    if legend is not None:
        lgd = fig.legend(legend['handles'], legend['labels'],
                         bbox_to_anchor=legend['loc'],
                         loc='upper left', ncol=1, fontsize=11, frameon=False)

    # Tick spacing is configured on the last panel; the panels were created
    # with sharex=True.
    last_ax = axes[-1]
    last_ax.xaxis.set_minor_locator(MultipleLocator(24))
    last_ax.xaxis.set_major_locator(MultipleLocator(48))
    last_ax.xaxis.set_major_formatter('{x:.0f}')

    # Remove the per-panel legends; a panel may have none (for example when
    # the caller forwards legend=False to seaborn).
    for ax in axes:
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

    pad = 8
    first_ax = axes[0]
    # Only `fontsize` is passed: `size` is an alias of the same text
    # property, and recent matplotlib rejects receiving both.
    annot = first_ax.annotate(
        dataset_name, xy=(0, 0.5),
        xytext=(-first_ax.yaxis.labelpad - pad, 0),
        xycoords=first_ax.yaxis.label, textcoords='offset points',
        ha='right', va='center', rotation=90, fontsize=14)
    first_ax.annotate(
        'Test on:', xy=(-0.05, 0.492),
        xycoords=first_ax.title, textcoords='offset points',
        ha='right', va='center', rotation=0, fontsize=13)

    plt.subplots_adjust(top=0.80, wspace=0.1)

    save_kwargs = dict(format='png', dpi=600,
                       bbox_inches='tight', pad_inches=0)
    if lgd is not None:
        # Name the artists outside the axes so the tight bounding box
        # accounts for them.
        save_kwargs['bbox_extra_artists'] = (annot, lgd)
    plt.savefig('../figs/' + savename + '.png', **save_kwargs)
    plt.show()