Spaces:
Sleeping
Sleeping
""" | |
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu
Utils for computing evaluation metrics.
""" | |
import argparse | |
import os | |
import warnings | |
from typing import Tuple | |
from Bio.Align import substitution_matrices | |
import numpy as np | |
import torch | |
from transformers import AutoTokenizer, EsmForProteinFolding | |
from torchtyping import TensorType | |
from core import residue_constants | |
from core import utils | |
from core import protein_mpnn as mpnn | |
import modules | |
import sampling | |
def mean(x):
    """Return the arithmetic mean of the items in ``x``.

    ``x`` must support ``len()`` and ``sum()``. An empty ``x`` yields ``0``
    instead of raising ``ZeroDivisionError``.
    """
    count = len(x)
    if count:
        return sum(x) / count
    return 0
def calculate_seq_identity(seq1, seq2, seq_mask=None):
    """Return the fraction of matching positions between two sequence tensors.

    ``seq2`` (and ``seq_mask`` when given) are cast to ``seq1``'s dtype and
    device before use. The reduction runs over the last dimension.

    Without a mask, the result is the plain mean of the per-position matches.
    With a mask, matches are multiplied by the mask and divided by the mask
    sum, which is clamped to at least 1 so a fully masked row gives 0.
    """
    matches = (seq1 == seq2.to(seq1)).float()
    if seq_mask is None:
        return matches.mean(-1)

    mask = seq_mask.to(seq1)
    matches *= mask
    num_valid = mask.sum(-1).clamp(min=1)
    return matches.sum(-1) / num_valid
def design_sequence(coords, model=None, num_seqs=1, disallow_aas=("C",)):
    """Design sequences for one structure with ProteinMPNN.

    Args:
        coords: Either a path (``str``) to an existing PDB file, or a
            coordinate tensor for a single structure. A tensor is written to a
            temporary PDB file in the current working directory with every
            residue typed as glycine; ``coords.shape[0]`` is used as the
            number of residues.
        model: ProteinMPNN model. Loaded via ``mpnn.get_mpnn_model()`` when
            None.
        num_seqs: Number of sequences requested per structure.
        disallow_aas: One-letter amino acid codes forwarded to ProteinMPNN as
            ``omit_AAs`` (as a fresh list). ``None`` is forwarded unchanged.

    Returns:
        The value returned by ``mpnn.run_proteinmpnn`` -- a list of strs
        like 'MKRLLDS', not aatypes.
    """
    # Function-scope import: tempfile is not among the module-level imports.
    import tempfile

    if model is None:
        model = mpnn.get_mpnn_model()

    # Copy so the callee can never mutate the caller's (or the default) value.
    omit_aas = list(disallow_aas) if disallow_aas is not None else None

    temp_pdb = not isinstance(coords, str)
    if temp_pdb:
        # mkstemp guarantees a unique name without touching the global NumPy
        # RNG, so identically seeded parallel jobs cannot collide on the file.
        fd, pdb_fn = tempfile.mkstemp(prefix="tmp", suffix=".pdb", dir=os.getcwd())
        os.close(fd)
    else:
        pdb_fn = coords

    try:
        if temp_pdb:
            gly_idx = residue_constants.restype_order["G"]
            gly_aatype = (torch.ones(coords.shape[0]) * gly_idx).long()
            utils.write_coords_to_pdb(
                coords, pdb_fn, batched=False, aatype=gly_aatype
            )
        with torch.no_grad():
            designed_seqs = mpnn.run_proteinmpnn(
                model=model,
                pdb_path=pdb_fn,
                num_seq_per_target=num_seqs,
                omit_AAs=omit_aas,
            )
    finally:
        # Always clean up the temporary file, even if writing/design failed.
        if temp_pdb:
            try:
                os.remove(pdb_fn)
            except FileNotFoundError:
                pass
    return designed_seqs
def get_esmfold_model(device=None):
    """Load the pretrained "facebook/esmfold_v1" model onto ``device``.

    When ``device`` is None, "cuda" is chosen if CUDA is available and "cpu"
    otherwise. The model's ``esm`` submodule is converted to half precision
    before the model is returned.
    """
    if device is None:
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
    esmfold = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
    esmfold = esmfold.to(device)
    esmfold.esm = esmfold.esm.half()
    return esmfold
def inference_esmfold(sequence_list, model, tokenizer):
    """Run ESMFold on a batch of sequences.

    The sequences are tokenized with padding and without special tokens,
    moved to ``model.device``, and passed through ``model``.

    Returns:
        A ``(pred_coords, plddts)`` tuple. ``pred_coords`` is the last entry
        of ``outputs.positions`` along its first dimension. ``plddts`` holds
        one value per sequence: the attention-mask-weighted mean of
        ``outputs.plddt[:, :, 1]`` (index 1 is presumably the Ca atom slot --
        TODO confirm).
    """
    tokenized = tokenizer(
        sequence_list,
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    )
    tokenized = tokenized.to(model.device)
    outputs = model(**tokenized)

    # positions is shape (l, b, n, a, c); keep only the final entry of dim 0
    pred_coords = outputs.positions[-1].contiguous()

    attn_mask = tokenized.attention_mask
    masked_plddt_sum = (outputs.plddt[:, :, 1] * attn_mask).sum(-1)
    # Clamp the denominator so an all-zero mask row cannot divide by zero.
    num_valid = attn_mask.sum(-1).clamp(min=1e-3)
    plddts = masked_plddt_sum / num_valid
    return pred_coords, plddts
def predict_structures(sequences, model="esmfold", tokenizer=None, force_unk_to_X=True):
    """Predict structures for one or more amino acid sequences.

    Args:
        sequences: A sequence string like 'MKRLLDS' or a list of such strings
            (not aatypes). A single string is wrapped into a one-item list.
        model: An ``EsmForProteinFolding`` instance, the string "esmfold", or
            None. For "esmfold" and None the model is loaded with
            ``get_esmfold_model()``.
        tokenizer: Tokenizer for the model. Loaded from "facebook/esmfold_v1"
            when None.
        force_unk_to_X: Currently unused; kept for interface compatibility.

    Returns:
        ``(coords, plddts)`` where ``coords`` is a list with one entry per
        sequence (trimmed to the sequence length and converted with
        ``utils.atom37_coords_from_atom14``) and ``plddts`` is the per-sequence
        pLDDT output of ``inference_esmfold``.

    Raises:
        NotImplementedError: If ``model`` is not a supported predictor.
    """
    if isinstance(sequences, str):
        sequences = [sequences]

    # None is what compute_self_consistency forwards by default, so treat it
    # the same as the "esmfold" string instead of failing on `None.device`.
    if model is None or (isinstance(model, str) and model == "esmfold"):
        model = get_esmfold_model()
    if not isinstance(model, EsmForProteinFolding):
        raise NotImplementedError(
            f"Unsupported structure prediction model: {model!r}"
        )

    device = model.device
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    aatype = [utils.seq_to_aatype(seq).to(device) for seq in sequences]

    with torch.no_grad():
        pred_coords, plddts = inference_esmfold(sequences, model, tokenizer)

    # Drop padded positions: each prediction is cut to its own sequence length.
    trimmed_coords = [c[: len(seq)] for c, seq in zip(pred_coords, sequences)]
    trimmed_coords_atom37 = [
        utils.atom37_coords_from_atom14(c, aa)
        for c, aa in zip(trimmed_coords, aatype)
    ]
    return trimmed_coords_atom37, plddts
def compute_structure_metric(coords1, coords2, metric="ca_rmsd", atom_mask=None):
    """Compare two structures of equal length with the requested metric.

    Args:
        coords1: Coordinates of the first structure, tensor[l][a][3].
        coords2: Coordinates of the second structure, same layout.
        metric: One of
            "ca_rmsd"      -- mean distance between atoms at index 1 (used
                              here as Ca) after Kabsch alignment;
            "tm_score"     -- ``1 - TM`` on the aligned index-1 atoms, so that
                              smaller is better like the other distance;
            "allatom_tm"   -- masked TM-style score over all atoms after
                              aligning on the index-1 atoms (returned as-is,
                              not as ``1 - TM``);
            "allatom_lddt" -- ``modules.lddt`` over all atoms (no alignment).
        atom_mask: Required for "allatom_tm" and "allatom_lddt".

    Returns:
        The metric value as produced by the tensor operations above.

    Raises:
        ValueError: If ``atom_mask`` is missing for an all-atom metric.
        NotImplementedError: If ``metric`` is not one of the names above.
    """

    def _tmscore(a, b, mask=None):
        length = len(b)
        dists = (a - b).pow(2).sum(-1)  # squared distances
        # For short chains the raw formula is negative (or complex for
        # length < 15); floor d0 at 0.5 as the reference TM-score program
        # does. Lengths >= 22 are unaffected by the floor.
        d0 = 1.24 * (max(length - 15, 0) ** (1 / 3)) - 1.8
        d0 = max(d0, 0.5)
        term = 1 / (1 + (dists / (d0**2)))
        if mask is None:
            return term.mean()
        term = term * mask
        return term.sum() / mask.sum().clamp(min=1)

    if metric not in ("ca_rmsd", "tm_score", "allatom_tm", "allatom_lddt"):
        raise NotImplementedError(f"Unknown structure metric: {metric!r}")
    if metric in ("allatom_tm", "allatom_lddt") and atom_mask is None:
        raise ValueError(f"atom_mask is required for metric {metric!r}")

    if metric == "allatom_lddt":
        # lDDT is computed on the raw coordinates; no superposition needed.
        return modules.lddt(
            coords1.reshape(-1, 3),
            coords2.reshape(-1, 3),
            atom_mask.reshape(-1, 1),
            per_residue=False,
        )

    ca1, ca2 = coords1[:, 1], coords2[:, 1]
    aligned_coords1_ca, (R, t) = utils.kabsch_align(ca1, ca2)

    if metric == "ca_rmsd":
        # NOTE(review): this is the mean per-residue distance, not
        # sqrt(mean(squared distance)). Kept unchanged because existing
        # results and thresholds were produced with this definition.
        return (aligned_coords1_ca - ca2).pow(2).sum(-1).sqrt().mean()
    if metric == "tm_score":
        # TODO: return 1 - tm score for now so sorts work properly
        return 1 - _tmscore(aligned_coords1_ca, ca2)

    # metric == "allatom_tm": align on Ca, compute allatom TM.
    # Apply the Ca-derived transform (center, rotate, translate) to all atoms.
    aligned_coords1 = coords1 - coords1[:, 1:2].mean(0, keepdim=True)
    aligned_coords1 = aligned_coords1 @ R.t() + t
    return _tmscore(aligned_coords1, coords2, mask=atom_mask)
def compute_self_consistency(
    comparison_structures,  # can be sampled or ground truth
    sampled_sequences=None,
    mpnn_model=None,
    struct_pred_model=None,
    tokenizer=None,
    num_seqs=1,
    return_aux=False,
    metric="ca_rmsd",
    output_file=None,
):
    """Self-consistency evaluation: (maybe design) + fold + structure metric.

    Typically used for eval of backbone sampling, sequence design or joint
    sampling. For every structure the sequences are either designed with
    ``design_sequence`` (when ``sampled_sequences`` is None) or taken from
    ``sampled_sequences[i]``, folded with ``predict_structures`` and compared
    against the structure. Only the candidate with the lowest primary metric
    is kept per structure.

    Args:
        comparison_structures: Iterable of coordinate tensors.
        sampled_sequences: Optional sequences like 'MKRLLDS' (not aatypes),
            indexed in step with ``comparison_structures``.
        mpnn_model, struct_pred_model, tokenizer: Forwarded to the design and
            prediction helpers.
        num_seqs: Number of designs requested per structure.
        return_aux: Also return the dict of all per-candidate outputs.
        metric: "tm_score" uses tm_score as primary metric; "both" uses
            ca_rmsd as primary and tm_score as secondary; anything else uses
            ca_rmsd only. Values are whatever ``compute_structure_metric``
            returns for those names.
        output_file: If set, the prediction of the overall best structure is
            written there as PDB, unless it contains NaNs.

    Returns:
        ``(metrics, best_idx)`` or ``(metrics, best_idx, aux)``. The metric
        keys are always named ``sc_rmsd_*`` / ``sc_plddt_*`` (plus
        ``sc_tmscore_*`` for "both"), regardless of the primary metric.
    """
    # The metric names depend only on `metric`, so resolve them once.
    primary_name = "tm_score" if metric == "tm_score" else "ca_rmsd"
    secondary_name = "tm_score" if metric == "both" else None

    best_primary = []
    best_secondary = []
    best_plddts = []
    best_coords = []
    best_seqs = []
    aux = {}

    for idx, ref_coords in enumerate(comparison_structures):
        if sampled_sequences is not None:
            seqs = sampled_sequences[idx]
        else:
            seqs = design_sequence(ref_coords, model=mpnn_model, num_seqs=num_seqs)

        preds, plddts = predict_structures(
            seqs, model=struct_pred_model, tokenizer=tokenizer
        )

        primary = [
            compute_structure_metric(ref_coords.to(p), p, metric=primary_name)
            for p in preds
        ]
        if secondary_name:
            secondary = [
                compute_structure_metric(ref_coords.to(p), p, metric=secondary_name)
                for p in preds
            ]
            aux.setdefault(secondary_name, []).extend(secondary)
        else:
            secondary = primary

        seq_list = [seqs] if isinstance(seqs, str) else seqs
        aux.setdefault("pred", []).extend(preds)
        aux.setdefault("seqs", []).extend(seq_list)
        aux.setdefault("plddt", []).extend(plddts)
        aux.setdefault("rmsd", []).extend(primary)

        # Report best rmsd design only among MPNN reps
        candidates = list(zip(primary, plddts, secondary, preds, seq_list))
        top_metric, top_plddt, top_secondary, top_pred, top_seq = min(
            candidates, key=lambda cand: cand[0]
        )
        best_primary.append(top_metric.detach().cpu())
        best_plddts.append(top_plddt.detach().cpu())
        best_secondary.append(top_secondary.detach().cpu())
        best_coords.append(top_pred)
        best_seqs.append(top_seq)

    best_idx = np.argmin(best_primary)
    metrics = {
        "sc_rmsd_best": best_primary[best_idx],
        "sc_plddt_best": best_plddts[best_idx],
        "sc_rmsd_mean": mean(best_primary),
        "sc_plddt_mean": mean(best_plddts),
    }
    if metric == "both":
        metrics["sc_tmscore_best"] = best_secondary[best_idx]
        metrics["sc_tmscore_mean"] = mean(best_secondary)

    # Only write the best prediction when it is free of NaNs.
    if output_file and torch.isnan(best_coords[best_idx]).sum() == 0:
        best_aatype = utils.seq_to_aatype(best_seqs[best_idx])
        utils.write_coords_to_pdb(
            best_coords[best_idx],
            output_file,
            batched=False,
            aatype=best_aatype,
        )

    if return_aux:
        return metrics, best_idx, aux
    return metrics, best_idx
def compute_secondary_structure_content(coords_batch):
    """Summarize the 3-state DSSP content of a batch of structures.

    ``utils.get_3state_dssp`` is called once per structure with warnings
    suppressed; results that are None or empty are dropped. The remaining
    strings are concatenated and the share of "E" and "H" characters is
    reported.

    Returns:
        Dict with "sample_pct_beta" (share of "E") and "sample_pct_alpha"
        (share of "H"). Despite the key names the values are fractions in
        [0, 1], and both are 0 when no DSSP string was obtained.
    """
    dssp_strings = []
    for coords in coords_batch:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            dssp_str = utils.get_3state_dssp(coords=coords)
        if dssp_str is not None and len(dssp_str) > 0:
            dssp_strings.append(dssp_str)

    all_states = "".join(dssp_strings)
    return {
        "sample_pct_beta": mean([state == "E" for state in all_states]),
        "sample_pct_alpha": mean([state == "H" for state in all_states]),
    }
def compute_bond_length_metric(
    cropped_coords_list, cropped_aatypes_list, atom_mask=None
):
    """Return the bond-length RMSE against ideal values, averaged per residue type.

    Bond lengths come from ``utils.batched_fullatom_bond_lengths_from_coords``
    as a mapping ``{one-letter residue: {"A1-A2": lengths}}``. For each bond
    the RMSE to the ideal length listed in
    ``residue_constants.standard_residue_bonds`` is computed; these are
    averaged per residue type and then across residue types.

    Bonds with no entry in the reference table are skipped with a warning,
    and bonds with no measurements are skipped silently, so neither can turn
    the whole metric into an error or NaN.

    Returns:
        The mean of the per-residue-type errors, or NaN when no bond could be
        evaluated at all.
    """
    bond_length_dict = utils.batched_fullatom_bond_lengths_from_coords(
        cropped_coords_list, cropped_aatypes_list, atom_mask=atom_mask
    )
    all_errors = {}
    for aa1, bonds in bond_length_dict.items():
        aa3 = residue_constants.restype_1to3[aa1]
        reference_bonds = residue_constants.standard_residue_bonds[aa3]
        per_bond_errors = []
        for bond_name, lengths in bonds.items():
            a1, a2 = bond_name.split("-")
            # Match the reference bond regardless of atom order; the first
            # match wins.
            ideal_val = next(
                (
                    ref.length
                    for ref in reference_bonds
                    if {ref.atom1_name, ref.atom2_name} == {a1, a2}
                ),
                None,
            )
            if ideal_val is None:
                warnings.warn(
                    f"No ideal length for bond {bond_name} in {aa3}; skipping."
                )
                continue
            measured = np.array(lengths)
            if measured.size == 0:
                continue
            error = (measured - ideal_val) ** 2
            per_bond_errors.append(error.mean() ** 0.5)
        if len(per_bond_errors) > 0:  # often no Cys
            all_errors[aa1] = np.mean(per_bond_errors)
    if not all_errors:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float("nan")
    return np.mean(list(all_errors.values()))
def evaluate_backbone_generation(
    model,
    n_samples=1,
    mpnn_model=None,
    struct_pred_model=None,
    tokenizer=None,
    sample_length_range=(50, 512),
):
    """Sample backbones from ``model`` and score them.

    Samples are drawn with ``sampling.draw_backbone_samples`` using the
    default backbone sampling config, then evaluated with
    ``compute_self_consistency`` and ``compute_secondary_structure_content``.

    Returns:
        ``(metrics, aux)`` where ``metrics`` merges both metric dicts with
        every key prefixed by "bb_", and ``aux`` is the tuple
        ``(trimmed_coords, seq_mask, best_idx, predicted coords, sequences)``.
    """
    sampling_kwargs = vars(sampling.default_backbone_sampling_config())
    trimmed_coords, seq_mask = sampling.draw_backbone_samples(
        model,
        n_samples=n_samples,
        sample_length_range=sample_length_range,
        **sampling_kwargs,
    )

    sc_metrics, best_idx, aux = compute_self_consistency(
        trimmed_coords,
        mpnn_model=mpnn_model,
        struct_pred_model=struct_pred_model,
        tokenizer=tokenizer,
        return_aux=True,
    )
    dssp_metrics = compute_secondary_structure_content(trimmed_coords)

    # Self-consistency entries first, then DSSP entries (which win on clashes).
    all_metrics = {}
    for source in (sc_metrics, dssp_metrics):
        for key, value in source.items():
            all_metrics[f"bb_{key}"] = value

    aux_out = (trimmed_coords, seq_mask, best_idx, aux["pred"], aux["seqs"])
    return all_metrics, aux_out
def evaluate_allatom_generation(
    model,
    n_samples,
    two_stage_sampling=True,
    struct_pred_model=None,
    tokenizer=None,
    sample_length_range=(50, 512),
):
    """Sample all-atom structures from ``model`` and score them.

    The model is temporarily switched to the "codesign" task with miniMPNN
    loaded, and is switched back to "allatom" (miniMPNN removed) before
    returning -- also when evaluation raises. ``model.eval()`` is called and
    not reverted.

    Args:
        model: All-atom model providing ``task``, ``load_minimpnn`` and
            ``remove_minimpnn``.
        n_samples: Number of samples to draw.
        two_stage_sampling: Forwarded to the sampler; when true the bond
            length RMSE is computed on the stage-1 coordinates.
        struct_pred_model: Structure predictor; loaded with
            ``get_esmfold_model()`` when None.
        tokenizer: Tokenizer; loaded from "facebook/esmfold_v1" when None.
        sample_length_range: Currently not forwarded to the sampler.

    Returns:
        ``(metrics, aux)``. ``metrics`` holds the self-consistency metrics
        prefixed with "aa_", the unprefixed DSSP metrics and "aa_bond_rmse".
        ``aux`` is ``(sampled coords, sampled aatypes, atom mask, predicted
        coords, best_idx)``.
    """
    # Convert allatom model to codesign model by loading miniMPNN
    model.task = "codesign"
    model.load_minimpnn()
    try:
        model.eval()
        sampling_config = sampling.default_allatom_sampling_config()
        # NOTE(review): sample_length_range is accepted but not passed on
        # here, unlike in evaluate_backbone_generation -- confirm whether
        # draw_allatom_samples should receive it.
        (
            cropped_samp_coords,
            cropped_samp_aatypes,
            samp_atom_mask,
            stage1_coords,
            seq_mask,
        ) = sampling.draw_allatom_samples(
            model,
            n_samples=n_samples,
            two_stage_sampling=two_stage_sampling,
            **vars(sampling_config),
        )

        # Compute self consistency
        if struct_pred_model is None:
            # Reuse the shared loader; it also picks the device, which the
            # previous inline loading code referenced without defining.
            struct_pred_model = get_esmfold_model()
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
        designed_seqs = [utils.aatype_to_seq(a) for a in cropped_samp_aatypes]
        sc_metrics, best_idx, sc_aux = compute_self_consistency(
            comparison_structures=cropped_samp_coords,
            sampled_sequences=designed_seqs,
            struct_pred_model=struct_pred_model,
            tokenizer=tokenizer,
            return_aux=True,
        )
        aa_metrics_out = {f"aa_{k}": v for k, v in sc_metrics.items()}

        # Compute secondary structure content on atoms 0, 1, 2 and 4
        # (presumably N, CA, C, O in the atom37 ordering -- TODO confirm).
        cropped_bb_coords = [c[..., [0, 1, 2, 4], :] for c in cropped_samp_coords]
        dssp_metrics = compute_secondary_structure_content(cropped_bb_coords)
        aa_metrics_out = {**aa_metrics_out, **dssp_metrics}

        # Compute bond length RMSE
        if two_stage_sampling:  # compute on original sample
            bond_rmse_coords = stage1_coords
        else:
            bond_rmse_coords = cropped_samp_coords
        bond_rmse = compute_bond_length_metric(
            bond_rmse_coords, cropped_samp_aatypes, samp_atom_mask
        )
        aa_metrics_out["aa_bond_rmse"] = bond_rmse
    finally:
        # Convert codesign model back to allatom model, even on failure, so
        # the caller's model is never left in the codesign configuration.
        model.task = "allatom"
        model.remove_minimpnn()

    aa_aux_out = (
        cropped_samp_coords,
        cropped_samp_aatypes,
        samp_atom_mask,
        sc_aux["pred"],
        best_idx,
    )
    return aa_metrics_out, aa_aux_out