import json, time, os, sys, glob import gradio as gr sys.path.append("/home/user/app/ProteinMPNN/vanilla_proteinmpnn") sys.path.append("/home/duerr/phd/08_Code/ProteinMPNN/ProteinMPNN/vanilla_proteinmpnn") import matplotlib.pyplot as plt import shutil import warnings import numpy as np import torch from torch import optim from torch.utils.data import DataLoader from torch.utils.data.dataset import random_split, Subset import copy import torch.nn as nn import torch.nn.functional as F import random import os import os.path import plotly.express as px import urllib import jax.numpy as jnp import tensorflow as tf if "/home/user/app/af_backprop" not in sys.path: sys.path.append("/home/user/app/af_backprop") # local only if "/home/duerr/phd/08_Code/ProteinMPNN/af_backprop" not in sys.path: sys.path.append("/home/duerr/phd/08_Code/ProteinMPNN/af_backprop") from utils import * # import libraries import colabfold as cf from alphafold.common import protein from alphafold.data import pipeline from alphafold.model import data, config from alphafold.model import model as afmodel from alphafold.common import residue_constants import plotly.graph_objects as go import ray import re import numpy as np import jax tf.config.set_visible_devices([], "GPU") def chain_break(idx_res, Ls, length=200): # Minkyung's code # add big enough number to residue index to indicate chain breaks L_prev = 0 for L_i in Ls[:-1]: idx_res[L_prev + L_i :] += length L_prev += L_i return idx_res def clear_mem(): backend = jax.lib.xla_bridge.get_backend() for buf in backend.live_buffers(): buf.delete() def setup_af(seq, model_name="model_5_ptm"): clear_mem() # setup model cfg = config.model_config("model_5_ptm") cfg.model.num_recycle = 0 cfg.data.common.num_recycle = 0 cfg.data.eval.max_msa_clusters = 1 cfg.data.common.max_extra_msa = 1 cfg.data.eval.masked_msa_replace_fraction = 0 cfg.model.global_config.subbatch_size = None if os.path.exists("/home/duerr"): datadir = "/home/duerr/phd/08_Code/ProteinMPNN" else: datadir = "/home/user/app/" model_params = data.get_model_haiku_params(model_name=model_name, data_dir=datadir) model_runner = afmodel.RunModel(cfg, model_params, is_training=False) Ls = [len(s) for s in seq.split("/")] seq = re.sub("[^A-Z]", "", seq.upper()) length = len(seq) feature_dict = { **pipeline.make_sequence_features( sequence=seq, description="none", num_res=length ), **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0] * length]]), } feature_dict["residue_index"] = chain_break(feature_dict["residue_index"], Ls) inputs = model_runner.process_features(feature_dict, random_seed=0) def runner(seq, opt): # update sequence inputs = opt["inputs"] inputs.update(opt["prev"]) update_seq(seq, inputs) update_aatype(inputs["target_feat"][..., 1:], inputs) # mask prediction mask = seq.sum(-1) inputs["seq_mask"] = inputs["seq_mask"].at[:].set(mask) inputs["msa_mask"] = inputs["msa_mask"].at[:].set(mask) inputs["residue_index"] = jnp.where(mask == 1, inputs["residue_index"], 0) # get prediction key = jax.random.PRNGKey(0) outputs = model_runner.apply(opt["params"], key, inputs) prev = { "init_msa_first_row": outputs["representations"]["msa_first_row"][None], "init_pair": outputs["representations"]["pair"][None], "init_pos": outputs["structure_module"]["final_atom_positions"][None], } aux = { "final_atom_positions": outputs["structure_module"]["final_atom_positions"], "final_atom_mask": outputs["structure_module"]["final_atom_mask"], "plddt": get_plddt(outputs), "pae": get_pae(outputs), "inputs": inputs, "prev": prev, } return aux return jax.jit(runner), {"inputs": inputs, "params": model_params} def make_tied_positions_for_homomers(pdb_dict_list): my_dict = {} for result in pdb_dict_list: all_chain_list = sorted( [item[-1:] for item in list(result) if item[:9] == "seq_chain"] ) # A, B, C, ... tied_positions_list = [] chain_length = len(result[f"seq_chain_{all_chain_list[0]}"]) for i in range(1, chain_length + 1): temp_dict = {} for j, chain in enumerate(all_chain_list): temp_dict[chain] = [i] # needs to be a list tied_positions_list.append(temp_dict) my_dict[result["name"]] = tied_positions_list return my_dict def align_structures(pdb1, pdb2, lenRes): """Take two structure and superimpose pdb1 on pdb2""" import Bio.PDB import subprocess pdb_parser = Bio.PDB.PDBParser(QUIET=True) # Get the structures ref_structure = pdb_parser.get_structure("samle", pdb1) sample_structure = pdb_parser.get_structure("reference", pdb2) aligner = Bio.PDB.CEAligner() aligner.set_reference(ref_structure) aligner.align(sample_structure) io = Bio.PDB.PDBIO() io.set_structure(ref_structure) io.save(f"reference.pdb") # Doing this to get around biopython CEALIGN bug subprocess.call("pymol -c -Q -r cealign.pml", shell=True) return aligner.rms, "reference.pdb", "out_aligned.pdb" def save_pdb(outs, filename, LEN): """save pdb coordinates""" p = { "residue_index": outs["inputs"]["residue_index"][0][:LEN], "aatype": outs["inputs"]["aatype"].argmax(-1)[0][:LEN], "atom_positions": outs["final_atom_positions"][:LEN], "atom_mask": outs["final_atom_mask"][:LEN], } b_factors = 100.0 * outs["plddt"][:LEN, None] * p["atom_mask"] p = protein.Protein(**p, b_factors=b_factors) pdb_lines = protein.to_pdb(p) with open(filename, "w") as f: f.write(pdb_lines) @ray.remote(num_gpus=1, max_calls=1) def run_alphafold(sequence, num_recycles): recycles = num_recycles RUNNER, OPT = setup_af(sequence) SEQ = re.sub("[^A-Z]", "", sequence.upper()) MAX_LEN = len(SEQ) LEN = len(SEQ) x = np.array([residue_constants.restype_order.get(aa, -1) for aa in SEQ]) x = np.pad(x, [0, MAX_LEN - LEN], constant_values=-1) x = jax.nn.one_hot(x, 20) OPT["prev"] = { "init_msa_first_row": np.zeros([1, MAX_LEN, 256]), "init_pair": np.zeros([1, MAX_LEN, MAX_LEN, 128]), "init_pos": np.zeros([1, MAX_LEN, 37, 3]), } positions = [] plddts = [] for r in range(recycles + 1): outs = RUNNER(x, OPT) outs = jax.tree_map(lambda x: np.asarray(x), outs) positions.append(outs["prev"]["init_pos"][0, :LEN]) plddts.append(outs["plddt"][:LEN]) OPT["prev"] = outs["prev"] if recycles > 0: print(r, plddts[-1].mean()) if os.path.exists("/home/duerr/phd/08_Code/ProteinMPNN"): save_pdb(outs, "/home/duerr/phd/08_Code/ProteinMPNN/out.pdb", LEN) else: save_pdb(outs, "/home/user/app/out.pdb", LEN) return plddts, outs["pae"], LEN if os.path.exists("/home/duerr/phd/08_Code/ProteinMPNN"): path_to_model_weights = "/home/duerr/phd/08_Code/ProteinMPNN/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights" else: path_to_model_weights = ( "/home/user/app/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights" ) def setup_proteinmpnn(model_name="v_48_020", backbone_noise=0.00): from protein_mpnn_utils import ( loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB, ) from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") # ProteinMPNN model name: v_48_002, v_48_010, v_48_020, v_48_030, v_32_002, v_32_010; v_32_020, v_32_030; v_48_010=version with 48 edges 0.10A noise # Standard deviation of Gaussian noise to add to backbone atoms hidden_dim = 128 num_layers = 3 model_folder_path = path_to_model_weights if model_folder_path[-1] != "/": model_folder_path = model_folder_path + "/" checkpoint_path = model_folder_path + f"{model_name}.pt" checkpoint = torch.load(checkpoint_path, map_location=device) noise_level_print = checkpoint["noise_level"] model = ProteinMPNN( num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint["num_edges"], ) model.to(device) model.load_state_dict(checkpoint["model_state_dict"]) model.eval() return model, device def get_pdb(pdb_code="", filepath=""): if pdb_code is None or pdb_code == "": try: return filepath.name except AttributeError as e: return None else: os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb") return f"{pdb_code}.pdb" def update( inp, file, designed_chain, fixed_chain, homomer, num_seqs, sampling_temp, model_name, backbone_noise, ): from protein_mpnn_utils import ( loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB, ) from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN pdb_path = get_pdb(pdb_code=inp, filepath=file) if pdb_path == None: return "Error processing PDB" model, device = setup_proteinmpnn( model_name=model_name, backbone_noise=backbone_noise ) if designed_chain == "": designed_chain_list = [] else: designed_chain_list = re.sub("[^A-Za-z]+", ",", designed_chain).split(",") if fixed_chain == "": fixed_chain_list = [] else: fixed_chain_list = re.sub("[^A-Za-z]+", ",", fixed_chain).split(",") chain_list = list(set(designed_chain_list + fixed_chain_list)) num_seq_per_target = num_seqs save_score = 0 # 0 for False, 1 for True; save score=-log_prob to npy files save_probs = ( 0 # 0 for False, 1 for True; save MPNN predicted probabilites per position ) score_only = 0 # 0 for False, 1 for True; score input backbone-sequence pairs conditional_probs_only = 0 # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone) conditional_probs_only_backbone = 0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone) batch_size = 1 # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory max_length = 20000 # Max sequence length out_folder = "." # Path to a folder to output sequences, e.g. /home/out/ jsonl_path = "" # Path to a folder with parsed pdb into jsonl omit_AAs = "X" # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine. pssm_multi = 0.0 # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions pssm_threshold = 0.0 # A value between -inf + inf to restric per position AAs pssm_log_odds_flag = 0 # 0 for False, 1 for True pssm_bias_flag = 0 # 0 for False, 1 for True folder_for_outputs = out_folder NUM_BATCHES = num_seq_per_target // batch_size BATCH_COPIES = batch_size temperatures = [sampling_temp] omit_AAs_list = omit_AAs alphabet = "ACDEFGHIKLMNPQRSTVWYX" omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32) chain_id_dict = None fixed_positions_dict = None pssm_dict = None omit_AA_dict = None bias_AA_dict = None bias_by_res_dict = None bias_AAs_np = np.zeros(len(alphabet)) ############################################################### pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list) dataset_valid = StructureDatasetPDB( pdb_dict_list, truncate=None, max_length=max_length ) if homomer: tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list) else: tied_positions_dict = None chain_id_dict = {} chain_id_dict[pdb_dict_list[0]["name"]] = (designed_chain_list, fixed_chain_list) with torch.no_grad(): for ix, prot in enumerate(dataset_valid): score_list = [] all_probs_list = [] all_log_probs_list = [] S_sample_list = [] batch_clones = [copy.deepcopy(prot) for i in range(BATCH_COPIES)] ( X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta, ) = tied_featurize( batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict, ) pssm_log_odds_mask = ( pssm_log_odds_all > pssm_threshold ).float() # 1.0 for true, 0.0 for false name_ = batch_clones[0]["name"] randn_1 = torch.randn(chain_M.shape, device=X.device) log_probs = model( X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1, ) mask_for_loss = mask * chain_M * chain_M_pos scores = _scores(S, log_probs, mask_for_loss) native_score = scores.cpu().data.numpy() message = "" seq_list = [] for temp in temperatures: for j in range(NUM_BATCHES): randn_2 = torch.randn(chain_M.shape, device=X.device) if tied_positions_dict == None: sample_dict = model.sample( X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all, ) S_sample = sample_dict["S"] else: sample_dict = model.tied_sample( X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all, ) # Compute scores S_sample = sample_dict["S"] log_probs = model( X, S_sample, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict["decoding_order"], ) mask_for_loss = mask * chain_M * chain_M_pos scores = _scores(S_sample, log_probs, mask_for_loss) scores = scores.cpu().data.numpy() all_probs_list.append(sample_dict["probs"].cpu().data.numpy()) all_log_probs_list.append(log_probs.cpu().data.numpy()) S_sample_list.append(S_sample.cpu().data.numpy()) for b_ix in range(BATCH_COPIES): masked_chain_length_list = masked_chain_length_list_list[b_ix] masked_list = masked_list_list[b_ix] seq_recovery_rate = torch.sum( torch.sum( torch.nn.functional.one_hot(S[b_ix], 21) * torch.nn.functional.one_hot(S_sample[b_ix], 21), axis=-1, ) * mask_for_loss[b_ix] ) / torch.sum(mask_for_loss[b_ix]) seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix]) score = scores[b_ix] score_list.append(score) native_seq = _S_to_seq(S[b_ix], chain_M[b_ix]) if b_ix == 0 and j == 0 and temp == temperatures[0]: start = 0 end = 0 list_of_AAs = [] for mask_l in masked_chain_length_list: end += mask_l list_of_AAs.append(native_seq[start:end]) start = end native_seq = "".join( list(np.array(list_of_AAs)[np.argsort(masked_list)]) ) l0 = 0 for mc_length in list( np.array(masked_chain_length_list)[ np.argsort(masked_list) ] )[:-1]: l0 += mc_length native_seq = native_seq[:l0] + "/" + native_seq[l0:] l0 += 1 sorted_masked_chain_letters = np.argsort( masked_list_list[0] ) print_masked_chains = [ masked_list_list[0][i] for i in sorted_masked_chain_letters ] sorted_visible_chain_letters = np.argsort( visible_list_list[0] ) print_visible_chains = [ visible_list_list[0][i] for i in sorted_visible_chain_letters ] native_score_print = np.format_float_positional( np.float32(native_score.mean()), unique=False, precision=4, ) line = ">{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n".format( name_, native_score_print, print_visible_chains, print_masked_chains, model_name, native_seq, ) message += f"{line}\n" start = 0 end = 0 list_of_AAs = [] for mask_l in masked_chain_length_list: end += mask_l list_of_AAs.append(seq[start:end]) start = end seq = "".join( list(np.array(list_of_AAs)[np.argsort(masked_list)]) ) # add non designed chains to predicted sequence l0 = 0 for mc_length in list( np.array(masked_chain_length_list)[np.argsort(masked_list)] )[:-1]: l0 += mc_length seq = seq[:l0] + "/" + seq[l0:] l0 += 1 score_print = np.format_float_positional( np.float32(score), unique=False, precision=4 ) seq_rec_print = np.format_float_positional( np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4, ) chain_s = "" if len(visible_list_list[0]) > 0: chain_M_bool = chain_M.bool() not_designed = _S_to_seq(S[b_ix], ~chain_M_bool[b_ix]) labels = ( chain_encoding_all[b_ix][~chain_M_bool[b_ix]] .detach() .cpu() .numpy() ) for c in set(labels): chain_s += "/" nd_mask = labels == c for i, x in enumerate(not_designed): if nd_mask[i]: chain_s += x line = ( ">T={}, sample={}, score={}, seq_recovery={}\n{}\n".format( temp, b_ix, score_print, seq_rec_print, seq ) ) seq_list.append(seq + chain_s) message += f"{line}\n" # somehow sequences still contain X, remove again for i, x in enumerate(seq_list): for aa in omit_AAs: seq_list[i] = x.replace(aa, "") all_probs_concat = np.concatenate(all_probs_list) all_log_probs_concat = np.concatenate(all_log_probs_list) np.savetxt("all_probs_concat.csv", all_probs_concat.mean(0).T, delimiter=",") np.savetxt( "all_log_probs_concat.csv", np.exp(all_log_probs_concat).mean(0).T, delimiter=",", ) S_sample_concat = np.concatenate(S_sample_list) fig = px.imshow( np.exp(all_log_probs_concat).mean(0).T, labels=dict(x="positions", y="amino acids", color="probability"), y=list(alphabet), template="simple_white", ) fig.update_xaxes(side="top") fig_tadjusted = px.imshow( all_probs_concat.mean(0).T, labels=dict(x="positions", y="amino acids", color="probability"), y=list(alphabet), template="simple_white", ) fig_tadjusted.update_xaxes(side="top") return ( message, fig, fig_tadjusted, gr.File.update(value="all_log_probs_concat.csv", visible=True), gr.File.update(value="all_probs_concat.csv", visible=True), pdb_path, gr.Dropdown.update(choices=seq_list), ) def update_AF(startsequence, pdb, num_recycles): # # run alphafold using ray # plddts, pae, num_res = run_alphafold( # startsequence, num_recycles # ) if len(startsequence) > 700: return ( """
AF2 code is experimental and relies on @sokrypton's trick to speed up compile/module runtime. Results might differ from DeepMind's published results.
Predictions are made using model_5_ptm
and without MSA based on the selected single sequence (designed_chain
+ fixed_chain
).