"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu
Utils for computing evaluation metrics.
"""

import os
import warnings

import numpy as np
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

import modules
import sampling
from core import protein_mpnn as mpnn
from core import residue_constants
from core import utils

def mean(x):
    """Arithmetic mean of a sequence; returns 0 for empty input."""
    if len(x) == 0:
        return 0
    return sum(x) / len(x)

def calculate_seq_identity(seq1, seq2, seq_mask=None):
    """Fraction of positions at which two aatype tensors agree, optionally masked."""
    identity = (seq1 == seq2.to(seq1)).float()
    if seq_mask is not None:
        identity *= seq_mask.to(seq1)
        return identity.sum(-1) / seq_mask.to(seq1).sum(-1).clamp(min=1)
    else:
        return identity.mean(-1)
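
# Usage sketch (hedged): `seq_a`, `seq_b`, and `seq_mask` are hypothetical
# aatype/mask tensors of shape (L,).
#
#     ident = calculate_seq_identity(seq_a, seq_b)            # fraction in [0, 1]
#     ident = calculate_seq_identity(seq_a, seq_b, seq_mask)  # count masked positions only
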
def design_sequence(coords, model=None, num_seqs=1, disallow_aas=["C"]):
    # Returns a list of designed sequences as strings like 'MKRLLDS', not aatypes
    if model is None:
        model = mpnn.get_mpnn_model()
    if isinstance(coords, str):
        # coords is already a path to a PDB file
        temp_pdb = False
        pdb_fn = coords
    else:
        # Write coords to a temporary PDB with all-glycine identities, so
        # ProteinMPNN designs from the backbone alone
        temp_pdb = True
        pdb_fn = f"tmp{np.random.randint(0, 1e8)}.pdb"
        gly_idx = residue_constants.restype_order["G"]
        gly_aatype = (torch.ones(coords.shape[0]) * gly_idx).long()
        utils.write_coords_to_pdb(coords, pdb_fn, batched=False, aatype=gly_aatype)
    with torch.no_grad():
        designed_seqs = mpnn.run_proteinmpnn(
            model=model,
            pdb_path=pdb_fn,
            num_seq_per_target=num_seqs,
            omit_AAs=disallow_aas,
        )
    if temp_pdb:
        os.remove(pdb_fn)
    return designed_seqs
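
# Usage sketch (hedged): `bb_coords` is a hypothetical (n_res, 37, 3) coordinate
# tensor; a path to an existing PDB file also works and skips the temp file.
#
#     seqs = design_sequence(bb_coords, num_seqs=8)    # -> list of 8 sequence strings
#     seqs = design_sequence("input.pdb", num_seqs=1)
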
def get_esmfold_model(device=None):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
    # Run the ESM language-model trunk in fp16 to reduce memory usage
    model.esm = model.esm.half()
    return model

def inference_esmfold(sequence_list, model, tokenizer):
    inputs = tokenizer(
        sequence_list,
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    ).to(model.device)
    outputs = model(**inputs)
    # positions has shape (layers, batch, n_res, n_atoms, 3); take the last layer
    pred_coords = outputs.positions[-1].contiguous()
    # Mean CA pLDDT per sequence (atom index 1 = CA), ignoring padding
    plddts = (outputs.plddt[:, :, 1] * inputs.attention_mask).sum(
        -1
    ) / inputs.attention_mask.sum(-1).clamp(min=1e-3)
    return pred_coords, plddts
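
# Usage sketch (hedged): load ESMFold once and reuse it across calls.
#
#     model = get_esmfold_model()
#     tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
#     coords14, plddts = inference_esmfold(["MKTAYIAKQR"], model, tokenizer)
#     # coords14: (batch, n_res, 14, 3) atom14 coords; plddts: (batch,) mean CA pLDDT
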
def predict_structures(sequences, model="esmfold", tokenizer=None, force_unk_to_X=True):
    # Expects seqs like 'MKRLLDS', not aatypes
    # model can be a loaded model, or a string naming which predictor to load
    if isinstance(sequences, str):
        sequences = [sequences]
    if model == "esmfold":
        model = get_esmfold_model()
    device = model.device
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    aatype = [utils.seq_to_aatype(seq).to(device) for seq in sequences]
    with torch.no_grad():
        if isinstance(model, EsmForProteinFolding):
            pred_coords, plddts = inference_esmfold(sequences, model, tokenizer)
        else:
            raise NotImplementedError(f"Unsupported structure predictor: {model}")
    # Trim padding and convert atom14 coords to atom37, one tensor per sequence
    seq_lens = [len(s) for s in sequences]
    trimmed_coords = [c[: seq_lens[i]] for i, c in enumerate(pred_coords)]
    trimmed_coords_atom37 = [
        utils.atom37_coords_from_atom14(c, aatype[i])
        for i, c in enumerate(trimmed_coords)
    ]
    return trimmed_coords_atom37, plddts
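
# Usage sketch (hedged): fold designed sequences with the default ESMFold
# predictor; coordinates come back in atom37 format, one tensor per sequence.
#
#     coords37_list, plddts = predict_structures(["MKTAYIAKQR", "GSHMLEDS"])
#     # coords37_list[i]: (len(sequences[i]), 37, 3); plddts: (2,)
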
def compute_structure_metric(coords1, coords2, metric="ca_rmsd", atom_mask=None):
    # coords1, coords2: tensors of shape (n_res, n_atoms, 3)
    def _tmscore(a, b, mask=None):
        # TM-score with the standard normalization d0 = 1.24 * (L - 15)^(1/3) - 1.8
        length = len(b)
        dists = (a - b).pow(2).sum(-1)
        d0 = 1.24 * ((length - 15) ** (1 / 3)) - 1.8
        term = 1 / (1 + (dists / (d0**2)))
        if mask is None:
            return term.mean()
        else:
            term = term * mask
            return term.sum() / mask.sum().clamp(min=1)

    # Kabsch-align on CA atoms (atom index 1), then apply the same rigid
    # transform to all atoms of coords1
    aligned_coords1_ca, (R, t) = utils.kabsch_align(coords1[:, 1], coords2[:, 1])
    aligned_coords1 = coords1 - coords1[:, 1:2].mean(0, keepdim=True)
    aligned_coords1 = aligned_coords1 @ R.t() + t
    if metric == "ca_rmsd":
        return (aligned_coords1_ca - coords2[:, 1]).pow(2).sum(-1).sqrt().mean()
    elif metric == "tm_score":
        tm = _tmscore(aligned_coords1_ca, coords2[:, 1])
        # TODO: return 1 - TM-score for now so that lower is better for sorting
        return 1 - tm
    elif metric == "allatom_tm":
        # Align on CA, compute all-atom TM-score
        assert atom_mask is not None
        return _tmscore(aligned_coords1, coords2, mask=atom_mask)
    elif metric == "allatom_lddt":
        assert atom_mask is not None
        lddt = modules.lddt(
            coords1.reshape(-1, 3),
            coords2.reshape(-1, 3),
            atom_mask.reshape(-1, 1),
            per_residue=False,
        )
        return lddt
    else:
        raise NotImplementedError(f"Unknown metric: {metric}")
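
# Usage sketch (hedged): `samp` and `pred` are hypothetical (n_res, 37, 3)
# tensors for the same sequence. Note that the tm_score branch returns
# 1 - TM-score, so lower is better for every metric.
#
#     ca_rmsd = compute_structure_metric(samp, pred, metric="ca_rmsd")
#     one_minus_tm = compute_structure_metric(samp, pred, metric="tm_score")
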
def compute_self_consistency(
comparison_structures, # can be sampled or ground truth
sampled_sequences=None,
mpnn_model=None,
struct_pred_model=None,
tokenizer=None,
num_seqs=1,
return_aux=False,
metric="ca_rmsd",
output_file=None,
):
    # Typically used to evaluate backbone sampling, sequence design, or joint sampling:
    # (optionally MPNN) + fold + TM-score/RMSD
    # Expects seqs like 'MKRLLDS', not aatypes
per_sample_primary_metrics = []
per_sample_secondary_metrics = []
per_sample_plddts = []
per_sample_coords = []
per_sample_seqs = []
aux = {}
for i, coords in enumerate(comparison_structures):
if sampled_sequences is None:
seqs_to_predict = design_sequence(
coords, model=mpnn_model, num_seqs=num_seqs
)
else:
seqs_to_predict = sampled_sequences[i]
pred_coords, plddts = predict_structures(
seqs_to_predict, model=struct_pred_model, tokenizer=tokenizer
)
primary_metric_name = "tm_score" if metric == "tm_score" else "ca_rmsd"
secondary_metric_name = "tm_score" if metric == "both" else None
primary_metrics = [
compute_structure_metric(coords.to(pred), pred, metric=primary_metric_name)
for pred in pred_coords
]
if secondary_metric_name:
secondary_metrics = [
compute_structure_metric(
coords.to(pred), pred, metric=secondary_metric_name
)
for pred in pred_coords
]
aux.setdefault(secondary_metric_name, []).extend(secondary_metrics)
else:
secondary_metrics = primary_metrics
aux.setdefault("pred", []).extend(pred_coords)
seqs_to_predict_arr = seqs_to_predict
if isinstance(seqs_to_predict_arr, str):
seqs_to_predict_arr = [seqs_to_predict_arr]
aux.setdefault("seqs", []).extend(seqs_to_predict_arr)
aux.setdefault("plddt", []).extend(plddts)
aux.setdefault("rmsd", []).extend(primary_metrics)
        # Report only the best design (lowest primary metric) among the MPNN replicates
        all_designs = list(
            zip(
                primary_metrics,
                plddts,
                secondary_metrics,
                pred_coords,
                seqs_to_predict_arr,
            )
        )
        best_rmsd_design = min(all_designs, key=lambda x: x[0])
per_sample_primary_metrics.append(best_rmsd_design[0].detach().cpu())
per_sample_plddts.append(best_rmsd_design[1].detach().cpu())
per_sample_secondary_metrics.append(best_rmsd_design[2].detach().cpu())
per_sample_coords.append(best_rmsd_design[3])
per_sample_seqs.append(best_rmsd_design[4])
best_idx = np.argmin(per_sample_primary_metrics)
metrics = {
"sc_rmsd_best": per_sample_primary_metrics[best_idx],
"sc_plddt_best": per_sample_plddts[best_idx],
"sc_rmsd_mean": mean(per_sample_primary_metrics),
"sc_plddt_mean": mean(per_sample_plddts),
}
if metric == "both":
metrics["sc_tmscore_best"] = per_sample_secondary_metrics[best_idx]
metrics["sc_tmscore_mean"] = mean(per_sample_secondary_metrics)
if output_file:
pred_coords = per_sample_coords
designed_seqs = per_sample_seqs
if torch.isnan(pred_coords[best_idx]).sum() == 0:
designed_seq = utils.seq_to_aatype(designed_seqs[best_idx])
utils.write_coords_to_pdb(
pred_coords[best_idx],
output_file,
batched=False,
aatype=designed_seq,
)
if return_aux:
return metrics, best_idx, aux
else:
return metrics, best_idx
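
# Usage sketch (hedged): score a batch of sampled backbones by designing
# `num_seqs` MPNN sequences each, refolding them, and keeping the best-RMSD
# design per sample.
#
#     metrics, best_idx = compute_self_consistency(
#         sampled_backbones,   # hypothetical list of (n_res, 37, 3) tensors
#         num_seqs=8,
#         metric="both",       # report both scRMSD and scTM
#     )
#     # metrics["sc_rmsd_best"], metrics["sc_plddt_mean"], ...
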
def compute_secondary_structure_content(coords_batch):
    # Fraction of residues assigned beta strand / alpha helix by 3-state DSSP,
    # pooled over the whole batch
    dssp_sample = []
    for c in coords_batch:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            dssp_str = utils.get_3state_dssp(coords=c)
        if dssp_str:
            dssp_sample.append(dssp_str)
    dssp_sample = "".join(dssp_sample)
    metrics = {}
    metrics["sample_pct_beta"] = mean([c == "E" for c in dssp_sample])
    metrics["sample_pct_alpha"] = mean([c == "H" for c in dssp_sample])
    return metrics
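
# Usage sketch (hedged): DSSP strings are pooled across the batch, so the
# returned fractions describe the sample set as a whole, not each structure.
#
#     ss = compute_secondary_structure_content(sampled_backbones)
#     # -> {"sample_pct_beta": ..., "sample_pct_alpha": ...}
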
def compute_bond_length_metric(
    cropped_coords_list, cropped_aatypes_list, atom_mask=None
):
    # RMS error of observed bond lengths vs. idealized values, averaged first
    # per residue type and then over residue types
    bond_length_dict = utils.batched_fullatom_bond_lengths_from_coords(
        cropped_coords_list, cropped_aatypes_list, atom_mask=atom_mask
    )
    all_errors = {}
    for aa1, d in bond_length_dict.items():
        aa3 = residue_constants.restype_1to3[aa1]
        per_bond_errors = []
        for bond, lengths in d.items():
            a1, a2 = bond.split("-")
            # Look up the idealized length for this bond (in either atom order)
            ideal_val = None
            for ref_bond in residue_constants.standard_residue_bonds[aa3]:
                if (ref_bond.atom1_name == a1 and ref_bond.atom2_name == a2) or (
                    ref_bond.atom1_name == a2 and ref_bond.atom2_name == a1
                ):
                    ideal_val = ref_bond.length
                    break
            if ideal_val is None:
                continue
            error = (np.array(lengths) - ideal_val) ** 2
            per_bond_errors.append(error.mean() ** 0.5)
        if len(per_bond_errors) > 0:  # often no Cys
            all_errors[aa1] = np.mean(per_bond_errors)
    return np.mean(list(all_errors.values()))
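
# Usage sketch (hedged): `coords_list`, `aatypes_list`, and `atom_mask` are
# hypothetical per-sample outputs from an all-atom sampler.
#
#     bond_rmse = compute_bond_length_metric(coords_list, aatypes_list, atom_mask)
#     # scalar: bond-length RMSE vs. idealized values, averaged over residue types
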
def evaluate_backbone_generation(
model,
n_samples=1,
mpnn_model=None,
struct_pred_model=None,
tokenizer=None,
sample_length_range=(50, 512),
):
sampling_config = sampling.default_backbone_sampling_config()
trimmed_coords, seq_mask = sampling.draw_backbone_samples(
model,
n_samples=n_samples,
sample_length_range=sample_length_range,
**vars(sampling_config),
)
sc_metrics, best_idx, aux = compute_self_consistency(
trimmed_coords,
mpnn_model=mpnn_model,
struct_pred_model=struct_pred_model,
tokenizer=tokenizer,
return_aux=True,
)
dssp_metrics = compute_secondary_structure_content(trimmed_coords)
all_metrics = {**sc_metrics, **dssp_metrics}
all_metrics = {f"bb_{k}": v for k, v in all_metrics.items()}
return all_metrics, (trimmed_coords, seq_mask, best_idx, aux["pred"], aux["seqs"])
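
# Usage sketch (hedged): a minimal backbone eval, reusing preloaded helper
# models so repeated calls don't re-download weights.
#
#     mpnn_model = mpnn.get_mpnn_model()
#     esmfold = get_esmfold_model()
#     tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
#     bb_metrics, bb_aux = evaluate_backbone_generation(
#         model,               # a trained protpardelle backbone model
#         n_samples=4,
#         mpnn_model=mpnn_model,
#         struct_pred_model=esmfold,
#         tokenizer=tokenizer,
#     )
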
def evaluate_allatom_generation(
model,
n_samples,
two_stage_sampling=True,
struct_pred_model=None,
tokenizer=None,
sample_length_range=(50, 512),
):
# Convert allatom model to codesign model by loading miniMPNN
model.task = "codesign"
model.load_minimpnn()
model.eval()
sampling_config = sampling.default_allatom_sampling_config()
ret = sampling.draw_allatom_samples(
model,
n_samples=n_samples,
two_stage_sampling=two_stage_sampling,
**vars(sampling_config),
)
(
cropped_samp_coords,
cropped_samp_aatypes,
samp_atom_mask,
stage1_coords,
seq_mask,
) = ret
    # Compute self consistency
    if struct_pred_model is None:
        struct_pred_model = get_esmfold_model()
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
designed_seqs = [utils.aatype_to_seq(a) for a in cropped_samp_aatypes]
sc_metrics, best_idx, sc_aux = compute_self_consistency(
comparison_structures=cropped_samp_coords,
sampled_sequences=designed_seqs,
struct_pred_model=struct_pred_model,
tokenizer=tokenizer,
return_aux=True,
)
aa_metrics_out = {f"aa_{k}": v for k, v in sc_metrics.items()}
    # Compute secondary structure content on backbone atoms only
    # (atom37 indices 0, 1, 2, 4 = N, CA, C, O)
    cropped_bb_coords = [c[..., [0, 1, 2, 4], :] for c in cropped_samp_coords]
    dssp_metrics = compute_secondary_structure_content(cropped_bb_coords)
aa_metrics_out = {**aa_metrics_out, **dssp_metrics}
# Compute bond length RMSE
if two_stage_sampling: # compute on original sample
bond_rmse_coords = stage1_coords
else:
bond_rmse_coords = cropped_samp_coords
bond_rmse = compute_bond_length_metric(
bond_rmse_coords, cropped_samp_aatypes, samp_atom_mask
)
aa_metrics_out["aa_bond_rmse"] = bond_rmse
# Convert codesign model back to allatom model and return metrics
model.task = "allatom"
model.remove_minimpnn()
aa_aux_out = (
cropped_samp_coords,
cropped_samp_aatypes,
samp_atom_mask,
sc_aux["pred"],
best_idx,
)
return aa_metrics_out, aa_aux_out
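
# Usage sketch (hedged): all-atom eval temporarily converts the model to
# codesign mode (loads miniMPNN) and restores the allatom task before returning.
#
#     aa_metrics, aa_aux = evaluate_allatom_generation(
#         model,               # a trained protpardelle all-atom model
#         n_samples=4,
#         struct_pred_model=esmfold,   # optional; loaded on demand if None
#         tokenizer=tokenizer,
#     )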