|
import os |
|
from Bio import SeqIO |
|
import numpy as np |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
from matplotlib.ticker import MultipleLocator, FixedLocator |
|
import seaborn as sns |
|
|
|
from utils.io_utils import load, read_fasta |
|
from utils.data_utils import seq2effect |
|
|
|
|
|
from couplings_model import CouplingsModel |
|
|
|
|
|
def add_unirep_model(df, model_names, path, name):
    """Attach precomputed UniRep scores to ``df`` as column ``name``.

    ``df`` is sorted by ``seq`` and the array loaded from ``path`` is then
    assigned positionally to the sorted rows, so that array must already be
    ordered by sequence (TODO confirm against whatever writes ``path``).
    The loaded values are scaled by ``-len(seq)`` of the first sorted
    sequence; presumably this turns a per-residue loss into a total
    log-likelihood and assumes all sequences share one length.

    Args:
        df: Table with a ``seq`` column.
        model_names: List of score-column names; ``name`` is appended in place.
        path: File passed to ``load``.
        name: Column name for the scores.

    Returns:
        ``(sorted_df, model_names)``.
    """
    ordered = df.sort_values('seq')
    length = len(ordered['seq'].iloc[0])
    scores = load(path)
    ordered[name] = -length * scores
    model_names.append(name)
    return ordered, model_names
|
|
|
|
|
def add_ev_model(df, model_names, path, name, dataset, include_indep=False):
    """Score every sequence in ``df`` with a couplings model loaded from ``path``.

    Scores are written in place to column ``name`` (and ``f'{name}_indep'``
    when ``include_indep`` is true), in the same row order as ``df``.

    Args:
        df: Table with a ``seq`` column; modified in place.
        model_names: List of score-column names; new names are appended in place.
        path: Model file passed to ``CouplingsModel``.
        name: Column name for the full-model scores.
        dataset: Dataset directory name. No longer read here; kept so
            existing callers keep working.
        include_indep: Also score with ``couplings_model.to_independent_model()``.

    Returns:
        ``(df, model_names)``.
    """
    couplings_model = CouplingsModel(path)
    df[name] = seq2effect(df.seq.values, couplings_model)
    model_names.append(name)

    if include_indep:
        indep_model = couplings_model.to_independent_model()
        # Same (seqs, model) call shape as the full-model call above; the
        # wild-type sequence is not an argument of that form.
        # NOTE(review): assumes seq2effect's signature is (seqs, model, ...)
        # as the primary call implies; confirm against utils.data_utils.
        df[f'{name}_indep'] = seq2effect(df.seq.values, indep_model)
        model_names.append(f'{name}_indep')
    return df, model_names
|
|
|
|
|
def add_hmm_model(df, model_names, path, name, dataset):
    """Attach HMM scores to ``df`` as column ``name``.

    The CSV at ``path`` must contain ``target`` and ``score_full`` columns.
    ``target`` is matched against record ids in
    ``../data/<dataset>/seqs.fasta`` to recover each hit's sequence, and
    each score is then matched to ``df`` rows by sequence.

    Args:
        df: Table with a ``seq`` column.
        model_names: List of score-column names; ``name`` is appended in place.
        path: CSV file of HMM hits.
        name: Column name for the scores.
        dataset: Dataset directory under ``../data``.

    Returns:
        ``(df_sorted_by_seq, model_names)``. Rows whose sequence has no
        matched HMM hit get NaN rather than another sequence's score.
    """
    df = df.sort_values('seq')

    fasta_path = os.path.join('../data', dataset, 'seqs.fasta')
    ids, seqs = [], []
    for rec in SeqIO.parse(fasta_path, 'fasta'):
        ids.append(str(rec.id))
        seqs.append(str(rec.seq))
    id2seq = pd.Series(index=ids, data=seqs, name='seq')

    hmm_ll = pd.read_csv(path)[['target', 'score_full']]
    hmm_ll = hmm_ll.join(id2seq, on='target', how='left')
    # Hits whose target id is absent from the FASTA have no sequence and
    # cannot be matched to any row of df.
    hmm_ll = hmm_ll.dropna(subset=['seq'])
    # Several ids may share a sequence; keep the first score per sequence so
    # the lookup index below is unique.
    seq2score = (hmm_ll.drop_duplicates(subset='seq')
                 .set_index('seq')['score_full'])

    # Look scores up by sequence value instead of assigning a sorted array by
    # position: positional assignment can shift scores onto the wrong rows
    # when either side has a sequence the other lacks, and fails with a
    # length mismatch when df contains duplicate sequences.
    df[name] = df['seq'].map(seq2score)
    model_names.append(name)
    return df, model_names
|
|
|
|
|
# Human-readable axis labels, keyed by the metric names used as column
# prefixes in the results tables.
metric_display_name = dict(
    ndcg='NDCG',
    topk_mean='Top 96 mean',
    spearman='Spearman correlation',
)
|
|
|
|
|
def retrieve_metric(df, metric_name, n_mut=None, predictor=None):
    """Select one metric column from a results table.

    Args:
        df: Table with ``predictor`` and ``n_train`` columns plus metric
            columns named ``<metric_name>`` or ``<metric_name>_<k>mut``.
        metric_name: Base metric name, e.g. ``'ndcg'`` or ``'spearman'``.
        n_mut: If given, read ``f'{metric_name}_{n_mut}mut'`` instead of the
            plain ``metric_name`` column.
        predictor: A predictor name, or an iterable of names, to keep;
            ``None`` keeps every row.

    Returns:
        A new DataFrame with columns ``predictor``, ``n_train`` and ``val``
        (the selected metric), keeping the original row index. ``df`` is
        not modified.
    """
    tmp = df
    if predictor is not None:
        if isinstance(predictor, str):
            predictor = [predictor]
        # Materialize once: a one-shot iterable (e.g. a generator) would be
        # exhausted by a per-row membership test, and isin is vectorized.
        tmp = tmp.loc[tmp['predictor'].isin(list(predictor))]
    column = metric_name if n_mut is None else f'{metric_name}_{n_mut}mut'
    return tmp[['predictor', 'n_train', column]].rename(columns={column: 'val'})
|
|
|
|
|
def metric_lineplot(df, predictors, metric, predictor_names, dataset_name,
                    max_n_mut, savename='figure', legend=None, mutcounts=None,
                    **kwargs):
    """Plot a metric against training-data size, one panel per mutation order.

    Panel 0 uses the plain ``metric`` column ("mutants of all orders");
    panel ``k`` for ``k`` in ``1..max_n_mut`` uses ``<metric>_<k>mut`` (via
    ``retrieve_metric``). The figure is saved to ``../figs/<savename>.png``
    and then shown.

    Args:
        df: Results table accepted by ``retrieve_metric``.
        predictors: Predictor names to plot; also fixes hue/style order.
        metric: Metric key; must be a key of ``metric_display_name``.
        predictor_names: Not used by this function (kept for callers).
        dataset_name: Label drawn to the left of the first panel.
        max_n_mut: Highest mutation order that gets its own panel.
        savename: Output file stem under ``../figs``.
        legend: Optional dict with ``'handles'``, ``'labels'`` and ``'loc'``
            (used as ``bbox_to_anchor``) for a figure-level legend.
        mutcounts: Optional sequence indexed ``0..max_n_mut``; element ``i``
            is annotated on panel ``i`` as the data size.
        **kwargs: Forwarded to ``seaborn.lineplot``.
    """
    # squeeze=False keeps ``axes`` indexable even when only one panel exists.
    fig, axes = plt.subplots(1, max_n_mut + 1,
                             figsize=((max_n_mut + 1) * 3, 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes[0]

    nmut_to_title = {
        1: 'Single mutants',
        2: 'Double mutants',
        3: 'Triple mutants',
        4: 'Quadruple mutants',
    }
    # (n_mut, title) per panel; n_mut=None selects the all-orders column.
    panels = [(None, 'mutants of all orders')]
    panels += [(n_mut, nmut_to_title.get(n_mut, f'{n_mut}th-order Mutants'))
               for n_mut in range(1, max_n_mut + 1)]

    for ax, (n_mut, title) in zip(axes, panels):
        tmp = retrieve_metric(df, metric, n_mut=n_mut, predictor=predictors)
        sns.lineplot(data=tmp, x='n_train', y='val',
                     hue='predictor', style='predictor', ax=ax,
                     hue_order=predictors, style_order=predictors, **kwargs)
        ax.set_title(title)
        ax.set_ylabel(metric_display_name[metric])
        ax.set_xlabel('Training data size')

    if mutcounts is not None:
        for i, ax in enumerate(axes):
            ax.annotate(f'Data size: {int(mutcounts[i])}',
                        xy=(0.29, 0.03), xycoords='axes fraction',
                        fontsize=9)

    lgd = None
    if legend is not None:
        lgd = fig.legend(legend['handles'], legend['labels'],
                         bbox_to_anchor=legend['loc'], loc='upper left',
                         ncol=1, fontsize=11, frameon=False)

    # The x-axes are shared (sharex=True), so these tickers apply to every panel.
    last_ax = axes[-1]
    last_ax.xaxis.set_minor_locator(MultipleLocator(24))
    last_ax.xaxis.set_major_locator(MultipleLocator(48))
    last_ax.xaxis.set_major_formatter('{x:.0f}')

    # Drop seaborn's per-panel legends; get_legend() returns None when none
    # was drawn (e.g. legend=False passed through kwargs).
    for ax in axes:
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

    pad = 8
    first_ax = axes[0]
    # Only ``fontsize`` is passed: matplotlib treats ``size`` as its alias,
    # and recent releases may reject receiving both.
    annot = first_ax.annotate(dataset_name, xy=(0, 0.5),
                              xytext=(-first_ax.yaxis.labelpad - pad, 0),
                              xycoords=first_ax.yaxis.label,
                              textcoords='offset points',
                              ha='right', va='center', rotation=90,
                              fontsize=14)
    first_ax.annotate('Test on:', xy=(-0.05, 0.492), xycoords=first_ax.title,
                      textcoords='offset points',
                      ha='right', va='center', rotation=0, fontsize=13)

    fig.subplots_adjust(top=0.80, wspace=0.1)
    save_kwargs = dict(format='png', dpi=600, bbox_inches='tight',
                       pad_inches=0)
    if lgd is not None:
        # The figure legend sits outside the axes; include it (and the row
        # label) when computing the tight bounding box.
        save_kwargs['bbox_extra_artists'] = (annot, lgd)
    fig.savefig(os.path.join('../figs', savename + '.png'), **save_kwargs)
    plt.show()
|
|
|
|
|
|