# zsp / data.py
# Massimo G. Totaro
# table fix
# 4f4563b
# raw
# history blame
# 8.52 kB
from math import ceil
import matplotlib.pyplot as plt
import pandas as pd
from re import match
import seaborn as sns
from model import Model
class Data:
    """Container for input and output data.

    An instance parses a source sequence and a substitution specification,
    hands itself to the shared ``Model`` for scoring (``calculate``) and then
    formats the scores as a CSV file plus either a PNG heatmap ('TMS' mode)
    or a styled table ('DMS' and 'MUT' modes).
    """

    # Initialise empty model as static class member for efficiency
    model = Model()

    # Substitution targets generated at every scanned position.
    _AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
    # Substitutions per scanned position: every target except the source residue.
    _SUBS_PER_RESIDUE = len(_AMINO_ACIDS) - 1

    def parse_seq(self, src: str) -> None:
        """Parse input sequence.

        Upper-cases ``src``, removes all whitespace and stores the result in
        ``self.seq``.

        Raises:
            RuntimeError: if the sequence is empty or contains a character
                that is not in ``self.model.alphabet``.
        """
        # split()/join drops every whitespace character (spaces, tabs, CR/LF),
        # so sequences pasted with Windows line endings are accepted too.
        self.seq = ''.join(src.upper().split())
        if not self.seq:
            raise RuntimeError("Empty sequence")
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def _residue_at(self, resi: int) -> str:
        """Return the residue at 1-based position ``resi``, validating the range."""
        # Without this check position 0 would silently wrap to the last residue.
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(
                f"Residue position {resi} outside sequence range 1-{len(self.seq)}")
        return self.seq[resi - 1]

    def _scan_position(self, resi: int, src: str) -> None:
        """Queue every substitution of residue ``src`` at position ``resi``."""
        for new in self._AMINO_ACIDS.replace(src, ''):
            self.sub.append(f"{src}{resi}{new}")
            self.resi.append(resi)

    def parse_sub(self, trg: str) -> None:
        """Parse input substitutions.

        Sets ``self.mode``, ``self.trg`` (upper-cased whitespace-separated
        tokens), ``self.resi`` (1-based position of every substitution) and
        ``self.sub`` (single-column DataFrame of ``X#Y`` strings).

        Modes, checked in this order:
            * 'MUT' - one token with the same length as the sequence:
              every differing position becomes a substitution.
            * 'DMS' - all tokens are numbers: scan each listed position.
            * 'MUT' - all tokens have the form ``X#Y``.
            * 'TMS' - anything else, including empty input: scan every
              position of the sequence.

        Raises:
            RuntimeError: if a position lies outside the sequence or the
                source residue of an ``X#Y`` token does not match it.
        """
        self.mode = None
        self.sub = []
        self.resi = []
        self.trg = trg.strip().upper().split()

        # Identify running mode
        if (len(self.trg) == 1
                and len(self.trg[0]) == len(self.seq)
                and match(r'^\w+$', self.trg[0])):
            # If single string of same length as sequence, seq vs seq mode
            self.mode = 'MUT'
            for resi, (src, new) in enumerate(zip(self.seq, self.trg[0]), 1):
                if src != new:
                    self.sub.append(f"{src}{resi}{new}")
                    self.resi.append(resi)
        # The patterns end in '$' so the whole token must match, and the
        # explicit `self.trg and` stops empty input from passing `all()`
        # vacuously.
        elif self.trg and all(match(r'\d+$', x) for x in self.trg):
            # If all strings are numbers, deep mutational scanning mode
            self.mode = 'DMS'
            # dict.fromkeys drops repeated positions while keeping input
            # order; duplicates would otherwise crowd out real rows later.
            for resi in dict.fromkeys(map(int, self.trg)):
                self._scan_position(resi, self._residue_at(resi))
        elif self.trg and all(match(r'[A-Z]\d+[A-Z]$', x) for x in self.trg):
            # If all strings are of the form X#Y, single substitution mode
            self.mode = 'MUT'
            self.sub = self.trg
            for token in self.trg:
                src, resi = token[0], int(token[1:-1])
                found = self._residue_at(resi)
                if found != src:
                    raise RuntimeError(
                        f"Unrecognised input substitution {found}{resi} /= {src}{resi}")
                self.resi.append(resi)
        else:
            self.mode = 'TMS'
            for resi, src in enumerate(self.seq, 1):
                self._scan_position(resi, src)

        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def __init__(self, src: str, trg: str, model_name: str = 'facebook/esm2_t33_650M_UR50D', scoring_strategy: str = 'masked-marginals', out_file='out'):
        """Initialise data.

        Args:
            src: source sequence, see ``parse_seq``.
            trg: substitution specification, see ``parse_sub``.
            model_name: name handed to ``Model`` when a new model is needed.
            scoring_strategy: stored as-is on the instance.
            out_file: path stem; '.png' and '.csv' are appended for outputs.
        """
        # if model has changed, load new model
        # Assign on the class, not the instance, so the loaded model is really
        # shared by later instances.
        if Data.model.model_name != model_name:
            Data.model = Model(model_name)
        # Pin the model this instance was created for, so a later swap of the
        # class-level model cannot change what this instance runs.
        self.model = Data.model
        # Always set: the output column below is named after the model.
        self.model_name = model_name
        self.parse_seq(src)
        self.offset = 0
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.token_probs = None
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_img = f'{out_file}.png'
        self.out_csv = f'{out_file}.csv'

    def parse_output(self) -> None:
        """Format output data for visualisation.

        'TMS' mode renders a heatmap to ``self.out_img`` and writes the
        normalised matrix to CSV. 'DMS' and 'MUT' modes sort the table,
        replace ``self.out_img`` with a pandas Styler and write the table to
        CSV without index or header.

        Raises:
            RuntimeError: if ``self.mode`` is none of 'TMS', 'DMS', 'MUT'.
        """
        if self.mode == 'TMS':
            self.process_tms_mode()
            self.out.to_csv(self.out_csv, float_format='%.2f')
            return

        if self.mode == 'DMS':
            self.sort_by_residue_and_score()
        elif self.mode == 'MUT':
            self.sort_by_score()
        else:
            raise RuntimeError(f"Unrecognised mode {self.mode}")
        self.out.columns = [str(i) for i in range(self.out.shape[1])]
        self.out_img = (self.out.style
                        .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                        .hide(axis=0)
                        .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8))
        self.out.to_csv(self.out_csv, float_format='%.2f', index=False, header=False)

    def sort_by_score(self) -> None:
        """Sort ``self.out`` by the model score column, highest first."""
        self.out = self.out.sort_values(self.model_name, ascending=False)

    def sort_by_residue_and_score(self) -> None:
        """Lay out the per-position rankings side by side.

        Rows are sorted by position then descending score, capped at 19 rows
        per position, and each position's rows become their own pair of
        columns (substitution, score), renumbered from 0.
        """
        per_resi = self._SUBS_PER_RESIDUE
        ranked = (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                  .sort_values(['resi', self.model_name], ascending=[True, False])
                  .groupby(['resi'])
                  .head(per_resi)
                  .drop(['resi'], axis=1))
        nblocks = ranked.shape[0] // per_resi
        blocks = [ranked.iloc[per_resi*x:per_resi*(x+1)].reset_index(drop=True)
                  for x in range(nblocks)]
        self.out = pd.concat(blocks, axis=1).set_axis(range(nblocks * 2), axis='columns')

    def process_tms_mode(self) -> None:
        """Build the normalised score matrix and render it as a heatmap."""
        self.out = self.assign_resi_and_group()
        self.out = self.concat_and_set_axis()
        # Scale so that the largest absolute score becomes 1.
        self.out /= self.out.abs().max().max()
        divs = self.calculate_divs()
        ncols = min(divs, key=lambda x: abs(x-60))
        nrows = ceil(self.out.shape[1]/ncols)
        ncols = self.adjust_ncols(ncols, nrows)
        self.plot_heatmap(ncols, nrows)

    def assign_resi_and_group(self):
        """Return ``self.out`` with a ``resi`` column, capped at 19 rows per position."""
        return (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                .groupby(['resi'])
                .head(self._SUBS_PER_RESIDUE))

    def concat_and_set_axis(self):
        """Return the score matrix: one row per target residue, one column per position.

        Each consecutive run of 19 rows is treated as one position, gets a
        zero-score row for its unchanged residue (``create_dataframe``) and is
        sorted by substitution string so rows line up with the row labels.
        Columns are labelled with residue and 1-based position, e.g. ``M1``.
        """
        per_resi = self._SUBS_PER_RESIDUE
        columns = [(self.out.iloc[per_resi*x:per_resi*(x+1)]
                    .pipe(self.create_dataframe)
                    .sort_values(['0'], ascending=[True])
                    .drop(['resi', '0'], axis=1)
                    .set_axis(list(self._AMINO_ACIDS))
                    .astype(float))
                   for x in range(self.out.shape[0] // per_resi)]
        return (pd.concat(columns, axis=1)
                .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis='columns'))

    def create_dataframe(self, df):
        """Return ``df`` with a leading row for the unchanged residue.

        The row's label repeats the source residue as target (e.g. ``M1M``,
        derived from the first row of ``df``) and its other two values are 0.
        """
        first = df.iloc[0, 0]
        unchanged = pd.Series([first[:-1] + first[0], 0, 0], index=df.columns)
        return pd.concat([unchanged.to_frame().T, df], axis=0, ignore_index=True)

    def calculate_divs(self):
        """Return the divisors of the column count lying in 30..60, or ``[60]`` if none."""
        total = self.out.shape[1]
        return [x for x in range(30, min(60, total) + 1) if total % x == 0] or [60]

    def adjust_ncols(self, ncols, nrows):
        """Shrink ``ncols`` (not below 45) so the last heatmap row is not nearly empty.

        Returns the adjusted number of columns per heatmap row.
        """
        total = self.out.shape[1]
        start = ncols
        while total/ncols < nrows and ncols > 45 and ncols*nrows >= total:
            ncols -= 1
        # The loop stops one step past the last acceptable value, so step back,
        # but only if it moved at all: an exact divisor must be returned
        # unchanged to keep the rows even.
        return ncols + 1 if ncols < start else ncols

    def plot_heatmap(self, ncols, nrows) -> None:
        """Render the heatmap (one or several rows) and save it to ``self.out_img``."""
        if nrows < 2:
            self.plot_single_heatmap()
        else:
            self.plot_multiple_heatmaps(ncols, nrows)
        plt.savefig(self.out_img, format='png', dpi=300)
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close()

    def plot_single_heatmap(self) -> None:
        """Draw the whole matrix as a single heatmap; cells equal to 0 are marked '·'."""
        fig = plt.figure(figsize=(12, 6))
        sns.heatmap(self.out
                    , cmap='RdBu'
                    , cbar=False
                    , square=True
                    , xticklabels=1
                    , yticklabels=1
                    , center=0
                    , annot=self.out.map(lambda x: ' ' if x != 0 else '·')
                    , fmt='s'
                    , annot_kws={'size': 'xx-large'})
        fig.tight_layout()

    def plot_multiple_heatmaps(self, ncols, nrows) -> None:
        """Draw the matrix as ``nrows`` stacked heatmaps of up to ``ncols`` columns each."""
        fig, ax = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
        for i in range(nrows):
            tmp = self.out.iloc[:, i*ncols:(i+1)*ncols]
            label = tmp.map(lambda x: ' ' if x != 0 else '·')
            sns.heatmap(tmp
                        , ax=ax[i]
                        , cmap='RdBu'
                        , cbar=False
                        , square=True
                        , xticklabels=1
                        , yticklabels=1
                        , center=0
                        , annot=label
                        , fmt='s'
                        , annot_kws={'size': 'xx-large'})
            ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
            ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
        fig.tight_layout()

    def calculate(self):
        """Run model and parse output; returns ``self`` so calls can be chained."""
        self.model.run_model(self)
        self.parse_output()
        return self

    def csv(self):
        """Return the path of the CSV output file."""
        return self.out_csv

    def image(self):
        """Return the visual output.

        This is the PNG path set in ``__init__``, unless ``parse_output`` ran
        in 'DMS' or 'MUT' mode, in which case it is the styled table.
        """
        return self.out_img