FoldMark / protenix /web_service /prediction_visualization.py
Zaixi's picture
Add large file
89c0b51
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
import biotite
import biotite.structure as struc
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol
from IPython.display import HTML
from biotite.structure.io.pdbx import CIFFile, get_structure
from matplotlib.colors import ListedColormap
from matplotlib.ticker import MaxNLocator
class PredictionLoader:
def __init__(self, pred_fpath: str):
self.pred_fpath = pred_fpath
self._load_cif()
self._load_confidence_pred()
def _load_json(self, fpath: str):
try:
with open(fpath, "r") as file:
data = json.load(file)
return data
except FileNotFoundError:
print(f"Error: File '{fpath}' not found.")
except json.JSONDecodeError:
print(f"Error: Failed to decode JSON from file '{fpath}'.")
except Exception as e:
print(f"An unexpected error occurred: {e}")
def _convert_to_numpy(self, data: dict):
for k, v in data.items():
if isinstance(v, list):
# convert values to numpy array
data[k] = np.array(v)
return data
def _load_cif(self):
assert os.path.exists(self.pred_fpath), "prediction file path does not exist."
fpath_all_preds = glob.glob(os.path.join(self.pred_fpath, "*.cif"))
self.cif_paths = fpath_all_preds
self.fnames = [s.split("/")[-1].replace(".cif", "") for s in fpath_all_preds]
self.preds = [
get_structure(pdbx_file=CIFFile.read(fpath), model=1, altloc="all")
for fpath in fpath_all_preds
]
def _load_confidence_pred(self):
fpath_all_confidences = glob.glob(os.path.join(self.pred_fpath, "*.json"))
fpath_full_confidences = [
fpath for fpath in fpath_all_confidences if "full_data_sample" in fpath
]
fpath_summary_confidences = [
fpath for fpath in fpath_all_confidences if "summary_confidence" in fpath
]
self.summary_confidence_data = [
self._convert_to_numpy(self._load_json(fpath))
for fpath in fpath_summary_confidences
]
if fpath_full_confidences:
self.full_confidence_data = [
self._convert_to_numpy(self._load_json(fpath))
for fpath in fpath_full_confidences
]
else:
self.full_confidence_data = None
def plot_contact_maps_from_pred(
preds: list,
fnames: list,
threshold: float = 7,
rep_atom: str = "CA",
) -> None:
"""
Plot contact maps from a directory
Args
preds (list): list of biotite.structure.AtomArrays
fnames (list): list of prediction names
threshold (int): threshold cutoff for the contact map
rep_atom (str): type of representative atoms, commonly options are 'CA' and 'CB'
"""
adjacency_matrices = []
for pred in preds:
rep_atom_coord = pred[pred.atom_name == rep_atom]
cell_list = struc.CellList(rep_atom_coord, cell_size=threshold)
adjacency_matrices.append(cell_list.create_adjacency_matrix(threshold))
cmap = ListedColormap(["white", biotite.colors["dimgreen"]])
fig, axes = plt.subplots(nrows=1, ncols=len(fnames), figsize=(len(fnames) * 3, 6))
for i in range(len(fnames)):
axes[i].matshow(adjacency_matrices[i], cmap=cmap, origin="lower")
axes[i].xaxis.tick_bottom()
axes[i].set_aspect("equal")
axes[i].set_xlabel("Residue Number")
axes[i].set_ylabel("Residue Number")
axes[i].set_title(fnames[i])
fig.tight_layout()
plt.show()
def plot_confidence_measures_from_pred(
full_confidence: list,
summary_confidence: list,
fnames: list,
show_global_confidence: bool,
*args,
**kwargs,
) -> None:
"""
Plot contact maps from a directory
Args
full_confidence (list): list of full confidence metrics from prediction
summary_confidence (list): list of summary confidence metrics from prediction
fnames (list): list of prediction names
"""
atom_plddts = [d["atom_plddt"] for d in full_confidence]
token_pdes = [d["token_pair_pde"] for d in full_confidence]
token_paes = [d["token_pair_pae"] for d in full_confidence]
summary_keys = ["plddt", "gpde", "ptm", "iptm"]
fig, axes = plt.subplots(
nrows=3,
ncols=len(fnames),
figsize=(len(fnames) * 5, 10),
gridspec_kw={"height_ratios": [1.5, 2, 2]},
)
for i in range(3):
for j in range(len(fnames)):
summary_text = ", ".join(
[
f"{k}:{v:.4f}"
for k, v in summary_confidence[j].items()
if k in summary_keys
]
)
if i == 0:
axes[i, j].plot(atom_plddts[j], color="k")
axes[i, j].set_title(fnames[j], fontsize=15, pad=20)
if show_global_confidence:
axes[i, j].text(
0.5,
1.07,
summary_text,
ha="center",
va="center",
transform=axes[i, j].transAxes,
fontsize=10,
)
axes[i, j].set_xlabel("Atom ID", fontsize=12)
axes[i, j].set_ylabel("pLDDT", fontsize=12)
axes[i, j].set_ylim([0, 100])
axes[i, j].spines[["right", "top"]].set_visible(False)
else:
data_to_plot = token_pdes[j] if i == 1 else token_paes[j]
cax = axes[i, j].matshow(data_to_plot, origin="lower")
axes[i, j].xaxis.tick_bottom()
axes[i, j].set_aspect("equal")
axes[i, j].set_xlabel("Scored Residue", fontsize=12)
axes[i, j].set_ylabel("Aligned Residue", fontsize=12)
axes[i, j].xaxis.set_major_locator(
MaxNLocator(3)
) # Max 5 ticks on the x-axis
axes[i, j].yaxis.set_major_locator(
MaxNLocator(3)
) # Max 5 ticks on the y-axis
color_bar = fig.colorbar(
cax, ax=axes[i, j], orientation="vertical", pad=0.1, shrink=0.6
)
cbar_label = (
"Predicted Distance Error" if i == 1 else "Predicted Aligned Error"
)
color_bar.set_label(cbar_label)
fig.tight_layout()
fig.subplots_adjust(hspace=0.4)
plt.show()
def plot_3d(
pred_id=0,
pred_loader=None,
show_sidechains=False,
show_mainchains=False,
color="rainbow",
):
view = py3Dmol.view(
js="https://3dmol.org/build/3Dmol.js",
)
fpath = pred_loader.cif_paths[pred_id]
view.addModelsAsFrames(open(fpath, "r").read(), "cif")
if color == "pLDDT":
assert pred_loader is not None
plddt = pred_loader.full_confidence_data[0]["atom_plddt"]
for i, score in enumerate(plddt):
normalized_color = int(
255 * (score - min(plddt)) / (max(plddt) - min(plddt))
)
color = f"rgb({normalized_color}, 0, {255 - normalized_color})" # Gradient from blue to red
# Apply color to each atom individually based on pLDDT score
view.setStyle(
{"serial": i + 1}, {"cartoon": {"color": color, "min": 50, "max": 90}}
)
elif color == "rainbow":
view.setStyle({"cartoon": {"color": "spectrum"}})
if show_sidechains:
BB = ["C", "O", "N"]
view.addStyle(
{
"and": [
{"resn": ["GLY", "PRO"], "invert": True},
{"atom": BB, "invert": True},
]
},
{"stick": {"colorscheme": f"WhiteCarbon", "radius": 0.3}},
)
view.addStyle(
{"and": [{"resn": "GLY"}, {"atom": "CA"}]},
{"sphere": {"colorscheme": f"WhiteCarbon", "radius": 0.3}},
)
view.addStyle(
{"and": [{"resn": "PRO"}, {"atom": ["C", "O"], "invert": True}]},
{"stick": {"colorscheme": f"WhiteCarbon", "radius": 0.3}},
)
if show_mainchains:
BB = ["C", "O", "N", "CA"]
view.addStyle(
{"atom": BB}, {"stick": {"colorscheme": f"WhiteCarbon", "radius": 0.3}}
)
view.zoomTo()
return HTML(view._make_html())
def plot_best_confidence_measure(prediction_fpath: str, output_dir: str = None):
"""
Plot confidence_measure from a directory
Args
full_confidence (list): list of full confidence metrics from prediction
summary_confidence (list): list of summary confidence metrics from prediction
fnames (list): list of prediction names
"""
if output_dir is None or output_dir == "":
output_dir = prediction_fpath
pred_loader = PredictionLoader(prediction_fpath)
if (
pred_loader.full_confidence_data is None
or len(pred_loader.full_confidence_data) < 1
):
print("no full confidence data found, skip ploting.")
return
full_confidence = pred_loader.full_confidence_data[:1]
summary_confidence = pred_loader.summary_confidence_data[:1]
atom_plddts = [d["atom_plddt"] for d in full_confidence]
token_pdes = [d["token_pair_pde"] for d in full_confidence]
token_paes = [d["token_pair_pae"] for d in full_confidence]
summary_keys = ["plddt", "gpde", "ptm", "iptm"]
fig, axes = plt.subplots(
nrows=1,
ncols=3,
figsize=(20, 10),
gridspec_kw={"width_ratios": [2.5, 2, 2]},
)
for i in range(3):
summary_text = ", ".join(
[
f"{k}: {v:.4f}"
for k, v in summary_confidence[0].items()
if k in summary_keys
]
)
if i == 0:
axes[i].plot(atom_plddts[0], color="k")
axes[i].text(
0.5,
1.07,
summary_text,
ha="center",
va="center",
transform=axes[i].transAxes,
fontsize=10,
)
axes[i].set_xlabel("Atom ID", fontsize=12)
axes[i].set_ylabel("pLDDT", fontsize=12)
axes[i].set_ylim([0, 1])
axes[i].spines[["right", "top"]].set_visible(False)
else:
data_to_plot = token_pdes[0] if i == 1 else token_paes[0]
cax = axes[i].matshow(data_to_plot, origin="lower")
axes[i].xaxis.tick_bottom()
axes[i].set_aspect("equal")
axes[i].set_xlabel("Scored Residue", fontsize=12)
axes[i].set_ylabel("Aligned Residue", fontsize=12)
axes[i].xaxis.set_major_locator(MaxNLocator(3)) # Max 5 ticks on the x-axis
axes[i].yaxis.set_major_locator(MaxNLocator(3)) # Max 5 ticks on the y-axis
color_bar = fig.colorbar(
cax, ax=axes[i], orientation="vertical", pad=0.1, shrink=0.6
)
cbar_label = (
"Predicted Distance Error" if i == 1 else "Predicted Aligned Error"
)
color_bar.set_label(cbar_label)
out_png_fpath = os.path.join(output_dir, "best_sample_confidence.png")
plt.savefig(out_png_fpath)
print(f"save best sample confidence plot successfully to {out_png_fpath}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--prediction_fpath",
type=str,
required=True,
help="The prediction files to parse",
)
parser.add_argument(
"--output_dir",
type=str,
required=False,
default=None,
help="ouput dir to save plot",
)
args = parser.parse_args()
plot_best_confidence_measure(args.prediction_fpath, args.output_dir)
#pred_loader = PredictionLoader(args.prediction_fpath)
#plot_3d(pred_id=0, pred_loader=pred_loader)