|
|
|
|
|
from . import logger |
|
import matplotlib as mpl |
|
mpl.use('Agg') |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import re |
|
import os |
|
import warnings |
|
import json |
|
from os.path import join |
|
from collections import OrderedDict |
|
|
|
# Module identifier passed as ``module_name`` to the logger.log_file calls in this file.
_name = "postProc"
|
|
|
|
|
def params():
    """
    Apply the shared matplotlib rcParams used by every plot in this module.

    The settings cover the axes face/edge colours, label, text and tick colours,
    and the padding and width of the major ticks.

    :return: None
    """
    # Kept as an ordered sequence so the values are assigned in a fixed order.
    settings = (
        ('axes.facecolor', 'f5f5f5'),
        ('axes.edgecolor', '0.45'),
        ('axes.axisbelow', True),
        ('axes.labelcolor', '0.45'),
        ('text.color', '0.45'),
        ('xtick.color', '0.45'),
        ('ytick.color', '0.45'),
        ('xtick.major.pad', 4),
        ('ytick.major.pad', 5),
        ('xtick.major.width', 1),
        ('ytick.major.width', 1),
    )
    for option, value in settings:
        plt.rcParams[option] = value
|
|
|
|
|
def get_scores(out_file):
    """
    Parse an Aggrescan3D csv file into per-chain plot data and score statistics.

    Every line made of five comma separated fields is read; the first such line is
    treated as the header and dropped. Of the remaining fields the 2nd is used as the
    chain ID, the 4th + 3rd form the residue label and the 5th is the score.
    Residues are numbered consecutively (starting at 1) in file order. Residues whose
    absolute score is not greater than 1e-10 are left out of the plot data, but they
    still consume a residue number and still count towards the statistics.

    Chains are reported in the order they first appear in the file, followed by the
    aggregate "All" entry.

    :param out_file: path to a csv file generated by Aggrescan3D
    :return: tuple of two ordered dictionaries:
             {chainID: [(residue number, agg3d score, residue label), ...]} and
             {chainID: {"min_value": val, "max_value": val,
                        "total_value": val, "avg_value": val}}
    :raises ValueError: if the file holds no data rows or a score is not a number
    """
    pattern = re.compile(r"^(.*),(.*),(.*),(.*),(.*)$", re.M)
    with open(out_file, 'r') as f:
        rows = pattern.findall(f.read().replace("\r", ""))[1:]

    if not rows:
        # Without this the failure would surface later as an opaque min() error.
        raise ValueError("No score rows found in %s" % out_file)

    # First-seen order keeps the output deterministic (a set would not).
    chain_ids = list(OrderedDict.fromkeys(row[1] for row in rows))
    if "All" not in chain_ids:
        chain_ids.append("All")

    dat = OrderedDict((chain_id, []) for chain_id in chain_ids)
    scores = OrderedDict((chain_id, []) for chain_id in chain_ids)

    for res_number, row in enumerate(rows, 1):
        chain = row[1]
        label = row[3] + row[2]
        agg_score = float(row[4])
        scores[chain].append(agg_score)
        scores["All"].append(agg_score)
        # Scores that are effectively zero are not plotted.
        if abs(agg_score) > 1e-10:
            entry = (res_number, agg_score, label)
            dat[chain].append(entry)
            dat["All"].append(entry)

    stats = OrderedDict()
    for chain_id in chain_ids:
        chain_scores = scores[chain_id]
        total = np.sum(chain_scores)
        stats[chain_id] = {
            "min_value": min(chain_scores),
            "max_value": max(chain_scores),
            "total_value": total,
            "avg_value": np.round(total / len(chain_scores), decimals=4),
        }
    return dat, stats
|
|
|
|
|
def make_plots(data=None, work_dir="", get_figure=False):
    """
    Creates png and svg plots of Aggrescan3D scores, one pair of files per chain.

    The aggregate "All" entry is not plotted. Chains without any plottable entry are
    skipped (and logged). Each figure is closed once it has been saved, unless it is
    handed back to the caller via get_figure.

    :param data: dictionary - {chainID: [(residue number, agg3d score, residue label), ...]}
    :param work_dir: directory where the plots will be saved as <chainID>.png / <chainID>.svg
    :param get_figure: if set to True, the figure of the first plotted chain is returned
                       right after it is saved and no further chains are processed
    :return: None or plt.figure object
    """
    warnings.simplefilter("ignore")
    for chain in list(data.keys()):
        if chain == "All":
            continue
        dat = data[chain]
        if not dat:
            # min()/max() below would fail on an empty series; there is nothing to draw anyway.
            logger.log_file(module_name=_name, msg="No scores to plot for chain %s, skipping" % chain)
            continue

        params()
        fig = plt.figure(figsize=(10, 6.6))
        hand_over = False
        try:
            xs = np.array([entry[0] for entry in dat])
            ys = np.array([entry[1] for entry in dat])
            labels = np.array([entry[2] for entry in dat])

            plt.xlabel("Residue")
            plt.ylabel("Score")
            plt.axhline(linewidth=1, color='0.45', linestyle='--')
            # Label every 10th plotted residue, starting from the second one.
            plt.xticks(xs[1::10], labels[1::10], rotation=35, fontsize='small')
            plt.title("A3D profile | chain " + chain)
            plt.axis(ymin=-4, ymax=4, xmin=min(xs) - 2, xmax=max(xs) + 2)
            plt.plot(xs, ys, linewidth=1.5, alpha=0.75, marker='o', mec='None')
            plt.grid(alpha=0.5, color='0.9', linewidth=1, linestyle='--')

            # Only residues with a positive score get a text annotation.
            for res_x, res_y, res_label in zip(xs, ys, labels):
                if float(res_y) > 0.0:
                    plt.annotate(res_label, xy=(res_x, res_y), xytext=(1, 1), alpha=0.5, fontsize='small',
                                 gid="label_" + str(res_x), textcoords='offset points')
            logger.log_file(module_name=_name, msg="Saving plots as %s.png and %s.svg" % (chain, chain))
            plt.savefig(os.path.join(work_dir, "%s.png" % chain), format="png")
            plt.savefig(os.path.join(work_dir, "%s.svg" % chain), format="svg")
            hand_over = get_figure
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            if not hand_over:
                plt.close(fig)
        if get_figure:
            return fig
|
|
|
|
|
def _read_profile(csv_path):
    """Read one score csv (header skipped) and return [x, y, labels] for plotting."""
    labels, y = [], []
    with open(csv_path, 'r') as f:
        f.readline()
        for line in f:
            a = line.strip().split(',')
            labels.append(("Chain %s" % a[1], a[2] + a[3]))
            y.append(float(a[-1]))
    return [list(range(len(y))), y, labels]


def make_auto_mut_plot(work_dir=""):
    """
    Create collective plots of mutants and the wild type, this is mostly a copy paste from server plot into mpl

    Mutant names are taken from the first column of Mutations_summary.csv (header skipped),
    the wild type profile from A3D.csv and each mutant profile from <mutant name>.csv, all
    located in work_dir. Mutants whose names start with the same character are drawn on one
    plot together with the wild type; the plot file name is built from the name of the last
    mutant of the group with its first two characters removed, followed by "_mutants".
    #TODO actually use fig axes object rather than plt like that

    :param work_dir: directory holding the csv files, plots are saved there as well
    :return: None
    """
    warnings.simplefilter("ignore")
    with open(join(work_dir, "Mutations_summary.csv"), "r") as f:
        f.readline()
        mutants = [line.split(",")[0] for line in f]

    wild_type = _read_profile(join(work_dir, "A3D.csv"))

    while mutants:
        # NOTE(review): presumably the first character encodes the mutated residue - confirm.
        mutated = mutants[0][0]
        one_r_mutated = [mutant for mutant in mutants if mutant[0] == mutated]
        mutants = [mutant for mutant in mutants if mutant[0] != mutated]

        data = OrderedDict()
        data["Wild_type"] = wild_type
        for mutant in one_r_mutated:
            data[mutant] = _read_profile(join(work_dir, mutant + ".csv"))
        _plot(data, work_dir, filename="%s_mutants" % one_r_mutated[-1][2:])
|
|
|
|
|
def _plot(data, work_dir, filename):
    """
    Draw several score profiles on one figure and save it as png and svg.

    :param data: ordered dictionary - {series name: [x values, y values, labels]}, the series
                 name is used as the legend entry
    :param work_dir: directory where the plots will be saved
    :param filename: name of the output files without the extension
    :return: None
    """
    params()
    fig = plt.figure(figsize=(10, 6.6))
    try:
        plt.ylabel("Score")
        plt.axhline(linewidth=1, color='0.45', linestyle='--')
        plt.title("A3D mutations profile")
        plt.grid(alpha=0.5, color='0.9', linewidth=1, linestyle='--')
        x, labels = [], []
        for key, value in data.items():
            x, y, labels = value
            plt.plot(x, y, label=key, linewidth=1.5, alpha=0.75, marker='o', mec='None')
        if data:
            # Each xticks call replaces the previous ticks, so only the last series matters.
            plt.xticks(x[1::10], labels[1::10], rotation=35, fontsize='small')
        logger.log_file(module_name=_name, msg="Saving auto mutation plots to %s (svg and png)" % filename)
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.savefig(os.path.join(work_dir, filename + ".png"), format="png", bbox_inches='tight')
        plt.savefig(os.path.join(work_dir, filename + ".svg"), format="svg", bbox_inches='tight')
    finally:
        # Close the figure even when plotting or saving fails, otherwise it stays registered in pyplot.
        plt.close(fig)
|
|
|
|
|
def append_scores(a3d="", in_pdb="", out_pdb="", work_dir=""):
    """
    Replaces the tail of the ATOM records in a pdb file with the Aggrescan3D score

    For every data row of the csv file, everything after the 60th character of each ATOM
    line with a matching residue name, chain ID and residue number is replaced by the
    score formatted as "%6.2f". Other lines are left untouched.

    :param a3d: filepath to a csv aggrescan-formatted file with scores
    :param in_pdb: filepath to a input pdb file
    :param out_pdb: filepath, relative to work_dir, to which the output will be written
    :param work_dir: Output directory
    :return: None
    :raises KeyError: if the csv contains a one-letter residue code that is not in the lookup table
    """
    rec = re.compile(r"^(.*),(.*),(.*),(.*),(.*)$", re.M)
    amino_a_dict = {'A': 'ALA', 'R': 'ARG', 'N': 'ASN', 'D': 'ASP', 'C': 'CYS', 'E': 'GLU',
                    'Q': 'GLN', 'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 'L': 'LEU', 'K': 'LYS',
                    'M': 'MET', 'F': 'PHE', 'P': 'PRO', 'S': 'SER', 'T': 'THR', 'W': 'TRP',
                    'Y': 'TYR', 'V': 'VAL', 'X': 'UNK'}

    with open(a3d, "r") as agg_out_fh, open(in_pdb, "r") as p:
        block = p.read()
        # The first matching line is the csv header.
        d = rec.findall(agg_out_fh.read().replace("\r", ""))[1:]

    for r in d:
        amino_acid = amino_a_dict[r[3]]
        agg_score = "%6.2f" % (float(r[4]))
        res_details = "%3s %1s%4s" % (amino_acid, r[1], r[2])
        # The details come from file content, so they are escaped to be matched literally;
        # the escaped text still has a fixed width, as the look-behind requires.
        line_tail = r'(?<=^ATOM.{13}' + re.escape(res_details) + r'.{34})(.*)$'
        block = re.sub(line_tail, agg_score, block, flags=re.M)
    logger.to_file(filename=os.path.join(work_dir, out_pdb), content=block)
|
|
|
|
|
def save_stats(data="", work_dir="", output="statistics"):
    """
    Writes the statistics of the Aggrescan3D scores calculations to a file

    :param data: string - JSON built from the statistics returned by get_scores, formatted like:
                 {chainID:{"min_value":val,"max_value":val,"total_value":val,"avg_value":val}}
    :param work_dir: Output directory
    :param output: name of the output file, created inside work_dir
    :return: None
    """
    stats_path = os.path.join(work_dir, output)
    logger.to_file(filename=stats_path, content=data)
|
|
|
|
|
def prepare_output(work_dir="", final=True, model_name="", scores_to_pdb=False, get_data=False):
    """
    Calls get_scores, save_stats, make_plots and append_scores, see details there

    The scores are always read from A3D.csv in work_dir and their statistics saved to
    <model_name>_stats in work_dir.

    :param work_dir: Working dir of Aggrescan3D run
    :param final: [bool] if True data is plotted and output.pdb generated to work_dir
                  (from folded.pdb found in work_dir)
    :param model_name: [string] filename of the currently analyzed pdb file (without the .pdb part)
    :param scores_to_pdb: [bool] Decide if a3d score should be put in the file's bfactor place;
                          <model_name>.pdb is read and the result written to work_dir
    :param get_data: [bool] if True the plot data is returned along with the statistics
    :return: dictionary - {chainID: {"min_value" : val, "max_value" : val,
                                     "total_value" : val, "avg_value" : val}}
             or, when get_data is True, a tuple (plot data, that dictionary)
    """
    a3d_csv = os.path.join(work_dir, "A3D.csv")
    data, stats = get_scores(a3d_csv)
    save_stats(data=json.dumps(stats), output=model_name + "_stats", work_dir=work_dir)
    if scores_to_pdb:
        append_scores(a3d=a3d_csv, in_pdb=model_name + ".pdb",
                      out_pdb=model_name + ".pdb", work_dir=work_dir)
    if final:
        make_plots(data=data, work_dir=work_dir)
        # append_scores joins out_pdb with work_dir itself; passing an already joined path
        # would duplicate a relative work_dir in the output location.
        append_scores(a3d=a3d_csv, in_pdb=os.path.join(work_dir, "folded.pdb"),
                      out_pdb="output.pdb", work_dir=work_dir)
    if get_data:
        return data, stats
    return stats
|
|