import gradio as gr
import re
import urllib
import urllib.request  # `import urllib` alone does not guarantee the `request` submodule is loaded
import tempfile
from output_helpers import viewer_html, output_html, load_js, get_js
import json
import os
import shlex
import subprocess
from datetime import datetime

from einops import repeat
import torch

from core import data
from core import utils
import models
import sampling

from Bio.PDB import PDBParser, cealign
from Bio.PDB.PDBIO import PDBIO


# Remote sources for the conditional-mode structure picker.
PDB_URL_TEMPLATE = "https://files.rcsb.org/download/{code}.pdb1"
# NOTE(review): AlphaFold DB may have moved past model_v2; confirm v2 files are
# still served, or bump the version suffix.
AFDB_URL_TEMPLATE = "https://alphafold.ebi.ac.uk/files/AF-{accession}-F1-model_v2.pdb"
# Radio values that select AlphaFold DB. "AF2 EBI DB" is what the UI offers;
# "AFDB2" is the value the handler originally tested for, kept for compatibility.
AFDB_SOURCES = ("AF2 EBI DB", "AFDB2")
# UniProt accession format; AlphaFold DB files are addressed by accession.
UNIPROT_ACCESSION_RE = re.compile(
    r"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"
)


# Local copies of helpers from draw_samples (the import of
# draw_and_save_samples / parse_resample_idx_string was left commented out).
def draw_and_save_samples(
    model,
    samples_per_len=8,
    lengths=range(50, 512),
    save_dir="./",
    mode="backbone",
    **sampling_kwargs,
):
    """Draw `samples_per_len` samples at each length in `lengths` and save them as PDBs.

    Args:
        model: loaded Protpardelle model (must provide make_seq_mask_for_sampling).
        samples_per_len: number of samples drawn per length.
        lengths: iterable of sequence lengths (ints or 0-d tensors).
        save_dir: directory the sampler writes PDB files into.
        mode: "backbone" or "allatom"; selects the sampling routine.
        **sampling_kwargs: forwarded unchanged to the sampling routine.

    Returns:
        list[str]: expected output paths, "<save_dir>/len<LLL>_samp<i>.pdb".

    Raises:
        ValueError: if `mode` is not "backbone" or "allatom".
    """
    if mode == "backbone":
        draw_fn = sampling.draw_backbone_samples
        path_suffix = "_samp"
        extra_kwargs = {"return_sampling_runtime": True}
    elif mode == "allatom":
        draw_fn = sampling.draw_allatom_samples
        path_suffix = ""
        extra_kwargs = {}
    else:
        raise ValueError(f"Unknown sampling mode: {mode!r}")

    sample_files = []
    for length in lengths:
        prot_lens = torch.ones(samples_per_len).long() * length
        seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
        prefix = f"{save_dir}/len{length:03d}"
        # pdb_save_path differs between the two samplers, but the returned
        # file names follow the same "<prefix>_samp<i>.pdb" pattern.
        draw_fn(
            model,
            seq_mask=seq_mask,
            pdb_save_path=prefix + path_suffix,
            return_aux=True,
            **extra_kwargs,
            **sampling_kwargs,
        )
        sample_files += [f"{prefix}_samp{i}.pdb" for i in range(samples_per_len)]
    return sample_files


def parse_idx_string(idx_str):
    """Parse a comma-separated index spec such as "0,2-5,7" into a list of ints.

    A dash span "a-b" expands to range(a, b), so `b` is exclusive
    ("2-5" -> 2, 3, 4). Empty entries (e.g. from a trailing comma) are skipped.
    """
    idxs = []
    for span in idx_str.split(","):
        span = span.strip()
        if not span:
            continue
        if "-" in span:
            start, stop = span.split("-")
            idxs.extend(range(int(start), int(stop)))
        else:
            idxs.append(int(span))
    return idxs


def changemode(m):
    """Toggle control visibility for the selected sampling mode.

    Returns visibility updates for (unconditional group, conditional group,
    unconditional run button, conditional run button, sizes accordion).
    """
    uncond = m == "unconditional"
    return (
        gr.update(visible=uncond),
        gr.update(visible=not uncond),
        gr.update(visible=uncond),
        gr.update(visible=not uncond),
        gr.update(visible=uncond),
    )


def fileselection(val):
    """Show the code textbox + load button, or the file picker when "upload" is chosen.

    Returns visibility updates for (code textbox, file picker, load button).
    """
    upload = val == "upload"
    return (
        gr.update(visible=not upload),
        gr.update(visible=upload),
        gr.update(visible=not upload),
    )


def _ui_error(exc):
    """Wrap `exc` in a gr.Error carrying a plain-string message."""
    return gr.Error(f"{type(exc).__name__}: {exc}")


def _download_structure(url, label):
    """Download `url` into a new temporary .pdb file and return its path.

    The file is created with delete=False so it outlives this call; later
    handlers read it by path. Download failures are raised as gr.Error.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdb") as tmp:
        path = tmp.name
    try:
        urllib.request.urlretrieve(url, path)
    except OSError as e:  # URLError / HTTPError are OSError subclasses
        os.unlink(path)
        raise gr.Error(f"Could not download {label}: {e}") from e
    return path


def update_structuresel(pdb, radio_val):
    """Resolve the structure chosen in the conditional-mode picker to a local PDB path.

    Args:
        pdb: a 4-character PDB code ("PDB"), a UniProt accession ("AF2 EBI DB",
            or legacy "AFDB2"), or the uploaded file object ("upload"; None when
            the upload is cleared).
        radio_val: value of the structure-source radio.

    Returns:
        (structure accordion update, path_to_file update, viewer HTML update).
        On invalid input the accordion stays open and the path is unchanged.
    """
    keep_open = gr.update(open=True)
    hidden_viewer = gr.update(value="", visible=False)
    # NOTE(review): the viewer HTML is empty in this copy (the original used an
    # empty f-string); presumably a template such as viewer_html was meant here.
    loaded_viewer = gr.update(value="", visible=True)

    if radio_val == "PDB":
        code = (pdb or "").strip()
        if len(code) != 4:
            return keep_open, gr.update(), hidden_viewer
        path = _download_structure(
            PDB_URL_TEMPLATE.format(code=code.lower()), f"PDB entry {code}"
        )
        return gr.update(open=False), gr.update(value=path), loaded_viewer

    if radio_val in AFDB_SOURCES:
        accession = (pdb or "").strip().upper()
        if UNIPROT_ACCESSION_RE.fullmatch(accession) is None:
            return keep_open, gr.update(), gr.update(value="regex not matched", visible=True)
        path = _download_structure(
            AFDB_URL_TEMPLATE.format(accession=accession),
            f"AlphaFold DB entry {accession}",
        )
        return gr.update(open=False), gr.update(value=path), loaded_viewer

    # "upload": `pdb` is the gr.File value, which is None when the file is cleared.
    if pdb is None:
        return keep_open, gr.update(), hidden_viewer
    return gr.update(open=False), gr.update(value=f"{pdb.name}"), loaded_viewer


class dotdict(dict):
    """dot.notation access to dictionary attributes (missing keys read as None)"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def _build_conditioning(input_pdb, resample_idx_str, samples_per_len, device):
    """Load `input_pdb` and build the sampler's conditioning kwargs.

    Residues listed in `resample_idx_str` (or all residues, if it is empty) have
    their atom conditioning mask zeroed; every other residue gets
    gt_cond_seq_mask = 1.

    Returns:
        (input_pdb_len, cond_kwargs): cond_kwargs holds gt_coords,
        gt_cond_atom_mask, gt_aatype and gt_cond_seq_mask, each repeated along a
        new leading axis of size `samples_per_len` and moved to `device`.
    """
    input_feats = utils.load_feats_from_pdb(input_pdb, protein_only=True)
    input_pdb_len = input_feats["aatype"].shape[0]
    if resample_idx_str:
        print(
            f"Warning: when sampling conditionally, the input pdb length ({input_pdb_len} residues) is used automatically for the sampling lengths."
        )
        resample_idxs = parse_idx_string(resample_idx_str)
    else:
        resample_idxs = list(range(input_pdb_len))
    resample_set = set(resample_idxs)  # O(1) membership for the complement below
    cond_idxs = [i for i in range(input_pdb_len) if i not in resample_set]

    def to_batch_size(x):
        return repeat(x, "... -> b ...", b=samples_per_len).to(device)

    # For unconditional model, center coords on whole structure
    centered_coords = data.apply_random_se3(
        input_feats["atom_positions"],
        atom_mask=input_feats["atom_mask"],
        translation_scale=0.0,
    )

    cond_kwargs = {}
    cond_kwargs["gt_coords"] = to_batch_size(centered_coords)
    cond_kwargs["gt_cond_atom_mask"] = to_batch_size(input_feats["atom_mask"])
    cond_kwargs["gt_cond_atom_mask"][:, resample_idxs] = 0
    cond_kwargs["gt_aatype"] = to_batch_size(input_feats["aatype"])
    cond_kwargs["gt_cond_seq_mask"] = torch.zeros_like(cond_kwargs["gt_aatype"])
    cond_kwargs["gt_cond_seq_mask"][:, cond_idxs] = 1
    return input_pdb_len, cond_kwargs


def _load_model(args, device):
    """Build a Protpardelle model for args.type, load its weights and set eval mode.

    Weights come from <args.model_checkpoint>/<type>_state_dict.pth (config
    <type>.yml) when model_checkpoint is set, otherwise from the training run in
    args.modeldir at epoch args.modelepoch. The allatom model additionally
    loads the miniMPNN weights from args.mpnnpath.

    Returns:
        (model, checkpoint_path)
    """
    model_type = args.type
    if args.model_checkpoint:
        checkpoint = f"{args.model_checkpoint}/{model_type}_state_dict.pth"
        cfg_path = f"{args.model_checkpoint}/{model_type}.yml"
    else:
        checkpoint = f"{args.modeldir}/checkpoints/epoch{args.modelepoch}_training_state.pth"
        cfg_path = f"{args.modeldir}/configs/{model_type}.yml"
    cfg = utils.load_config(cfg_path)
    weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
    model = models.Protpardelle(cfg, device=device)
    model.load_state_dict(weights)
    if model_type == "allatom":
        model.load_minimpnn(args.mpnnpath)
    model.to(device)
    model.eval()
    model.device = device
    return model, checkpoint


def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    """Run Protpardelle sampling and return the paths of the generated PDB files.

    Builds the same argument set as the draw_samples command line from the UI
    values, optionally conditions on an input structure, loads the requested
    model and draws `perlen` samples per sequence length.

    Args:
        path_to_file: PDB file to condition on (conditional mode only).
        m: "conditional" or "unconditional".
        resample_idx: index spec whose first and last characters are stripped
            before parsing (presumably brackets from the viewer JS — confirm).
            These residues are resampled; the rest are conditioned on.
        modeltype: "backbone" or "allatom".
        minlen, maxlen: inclusive length range (unconditional mode only).
        steplen: stride between sampled lengths.
        perlen: number of samples per length.

    Returns:
        list[str]: PDB paths under ./samples/<timestamp>/.

    Raises:
        ValueError: unknown `modeltype`, `minlen > maxlen`, or no lengths and
            no input PDB to derive them from.
    """
    if modeltype not in ("backbone", "allatom"):
        raise ValueError(f"Unknown model type: {modeltype!r}")

    # Set up params, arguments, sampling config ####################
    # Keys mirror the draw_samples command-line flags.
    args = dotdict(
        model_checkpoint="checkpoints",  # dir with denoiser weights and configs
        mpnnpath="checkpoints/minimpnn_state_dict.pth",  # miniMPNN weights
        modeldir=None,  # training run dir; only used if model_checkpoint is unset
        modelepoch=None,  # epoch to load from modeldir
        type=modeltype,
        parampath=None,  # JSON file of sampling params (alternative to param/paramval)
        perlen=int(perlen),
        steplen=int(steplen),
        num_lens=None,  # number of random lengths, used only when steplen is None
        targetdir=".",
    )
    if m == "conditional":
        args.update(
            param=None,
            paramval=None,
            minlen=None,
            maxlen=None,
            input_pdb=path_to_file,
            resample_idxs=(resample_idx or "")[1:-1],
        )
    else:
        if int(minlen) > int(maxlen):
            raise ValueError(f"minlen ({int(minlen)}) must not exceed maxlen ({int(maxlen)})")
        args.update(
            param="n_steps",
            paramval="100",
            minlen=int(minlen),
            maxlen=int(maxlen) + 1,  # make the UI's maxlen inclusive
            resample_idxs=None,
        )

    seed = 0
    samples_per_len = args.perlen
    min_len = args.minlen
    max_len = args.maxlen
    len_step_size = args.steplen
    device = "cuda:0"

    # setting default sampling config
    if args.type == "backbone":
        sampling_config = sampling.default_backbone_sampling_config()
    else:
        sampling_config = sampling.default_allatom_sampling_config()
    # Copy so the per-request updates below never write into the config object.
    sampling_kwargs = dict(vars(sampling_config))

    # Parse conditioning inputs
    input_pdb_len = None
    if args.input_pdb:
        input_pdb_len, cond_kwargs = _build_conditioning(
            args.input_pdb, args.resample_idxs, samples_per_len, device
        )
        sampling_kwargs.update(cond_kwargs)
        print("input_pdb_len", input_pdb_len)

    # Determine lengths to sample at
    if min_len is not None and max_len is not None:
        if len_step_size is not None:
            sampling_lengths = range(min_len, max_len, len_step_size)
        else:
            sampling_lengths = list(torch.randint(min_len, max_len, size=(args.num_lens,)))
    elif input_pdb_len is not None:
        sampling_lengths = [input_pdb_len]
    else:
        raise ValueError("Need to provide a set of protein lengths or an input pdb.")

    base_dir = args.targetdir
    date_string = datetime.now().strftime("%y-%m-%d-%H-%M-%S")

    # Update sampling config with arguments
    if args.param:
        var_param, var_value = args.param, args.paramval
        sampling_kwargs[var_param] = (
            None
            if var_value == "None"
            else int(var_value)
            if var_param == "n_steps"
            else float(var_value)
        )
    elif args.parampath:
        with open(args.parampath) as f:
            sampling_kwargs.update(json.load(f))

    # this is only used for the readme, keep s_min and s_max as params instead of struct_noise_schedule
    sampling_kwargs_readme = list(sampling_kwargs.items())

    print("Base directory:", base_dir)
    save_dir = f"{base_dir}/samples/{date_string}"
    save_init_dir = f"{base_dir}/samples_inits/{date_string}"
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_init_dir, exist_ok=True)
    print("Samples saved to:", save_dir)

    torch.manual_seed(seed)
    model, checkpoint = _load_model(args, device)

    with open(save_dir + "/run_parameters.txt", "w") as f:
        f.write(f"Sampling run for {date_string}\n")
        f.write(f"Random seed {seed}\n")
        f.write(f"Model checkpoint: {checkpoint}\n")
        f.write(
            f"{samples_per_len} samples per length from {min_len}:{max_len}:{len_step_size}\n"
        )
        f.write("Sampling params:\n")
        for k, v in sampling_kwargs_readme:
            f.write(f"{k}\t{v}\n")

    # Draw samples
    return draw_and_save_samples(
        model,
        samples_per_len=samples_per_len,
        lengths=sampling_lengths,
        save_dir=save_dir,
        mode=args.type,
        **sampling_kwargs,
    )


def api_predict(pdb_content, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    """Named API route: like `predict`, but takes PDB text and returns raw results.

    In conditional mode `pdb_content` is written to a temporary .pdb file that
    is used as the conditioning structure.

    Returns:
        str: JSON list of [path, pdb_text] pairs, one per design.
    """
    path_to_file = None
    if m == "conditional":
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdb") as tmp:
            tmp.write(pdb_content.encode())
            path_to_file = tmp.name
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise _ui_error(e) from e

    # load each design as string
    results = []
    for d in designs:
        with open(d, "r") as f:
            results.append((d, f.read()))
    return json.dumps(results)


def predict(pdb_radio, path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    """Gradio handler: sample designs and show them in the output HTML.

    In conditional mode each design is CE-aligned onto the input structure and
    saved beside it as <name>_al.pdb, recording the RMSD. `pdb_radio` is
    accepted for wiring compatibility and is not used.

    Returns:
        (update collapsing the input accordion, update for the output HTML).
    """
    print("running predict")
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise _ui_error(e) from e

    parser = PDBParser()
    aligner = cealign.CEAligner()
    io = PDBIO()
    aligned_designs = []
    metrics = []
    if m == "conditional":
        ref = parser.get_structure("ref", path_to_file)
        aligner.set_reference(ref)
        for d in designs:
            design = parser.get_structure("design", d)
            aligner.align(design)
            metrics.append({"rms": f"{aligner.rms:.1f}", "len": len(list(design[0].get_residues()))})
            aligned_path = os.path.splitext(d)[0] + "_al.pdb"
            io.set_structure(design)
            io.save(aligned_path)
            aligned_designs.append(aligned_path)
    else:
        for d in designs:
            design = parser.get_structure("design", d)
            metrics.append({"len": len(list(design[0].get_residues()))})
        aligned_designs = designs

    # NOTE(review): the output HTML is empty in this copy (the original was an
    # empty f-string); presumably a template (cf. output_html) rendered
    # aligned_designs and metrics here — confirm and restore.
    output_view = ""
    return gr.update(open=False), gr.update(value=output_view, visible=True)


protpardelleDemo = gr.Blocks()

with protpardelleDemo:
    gr.Markdown("# Protpardelle")
    gr.Markdown("""
An all-atom protein generative model
Alexander E. Chu, Lucy Cheng, Gina El Nesr, Minkai Xu, Po-Ssu Huang
doi: https://doi.org/10.1101/2023.05.24.542194""")

    with gr.Accordion(label="Input options", open=True) as input_accordion:
        model = gr.Dropdown(["backbone", "allatom"], value="allatom", label="What to sample?")
        m = gr.Radio(["unconditional", "conditional"], value="unconditional", label="Choose a Mode")

        # unconditional
        with gr.Group(visible=True) as uncond:
            gr.Markdown("Unconditional Sampling")

        # conditional
        with gr.Group(visible=False) as cond:
            # Separate name so it no longer shadows the outer "Input options" accordion.
            with gr.Accordion(label="Structure to condition on", open=True) as structure_accordion:
                pdb_radio = gr.Radio(["PDB", "AF2 EBI DB", "upload"], value="PDB", label="source of the structure")
                pdbcode = gr.Textbox(label="Uniprot code to be retrieved Alphafold2 Database", visible=True)
                pdbfile = gr.File(label="PDB File", visible=False)
                btn_load = gr.Button("Load PDB")
                pdb_radio.change(fileselection, inputs=pdb_radio, outputs=[pdbcode, pdbfile, btn_load])
            pdb_html = gr.HTML("", visible=False)
            path_to_file = gr.Textbox(label="Path to file", visible=False)
            # NOTE(review): labelled as indices to condition on, but protpardelle()
            # treats this value as the indices to *resample* — confirm get_js output.
            resample_idxs = gr.Textbox(label="Cond Idxs", interactive=False, info="Zero indexed list of indices to condition on, select in sequence viewer above")
            btn_load.click(update_structuresel, inputs=[pdbcode, pdb_radio], outputs=[structure_accordion, path_to_file, pdb_html])
            pdbfile.change(update_structuresel, inputs=[pdbfile, pdb_radio], outputs=[structure_accordion, path_to_file, pdb_html])

        with gr.Accordion(label="Sizes", open=True) as size_uncond:
            with gr.Row():
                minlen = gr.Slider(minimum=2, maximum=200, value=50, step=1, label="minlen", info="Minimum sequence length")
                maxlen = gr.Slider(minimum=3, maximum=200, value=60, step=1, label="maxlen", info="Maximum sequence length")
                steplen = gr.Slider(minimum=1, maximum=50, step=1, value=1, label="steplen", info="How frequently to select sequence length?")
            perlen = gr.Slider(minimum=1, maximum=200, step=1, value=2, label="perlen", info="How many samples per sequence length?")

        btn_conditional = gr.Button("Run conditional", visible=False)
        btn_unconditional = gr.Button("Run unconditional")
        m.change(changemode, inputs=m, outputs=[uncond, cond, btn_unconditional, btn_conditional, size_uncond])

    out = gr.HTML("", visible=True)
    predict_inputs = [pdb_radio, path_to_file, m, resample_idxs, model, minlen, maxlen, steplen, perlen]
    btn_unconditional.click(predict, inputs=predict_inputs, outputs=[input_accordion, out])
    # Conditional runs go through JS: get_js writes the selected indices into
    # resample_idxs, whose change event then runs predict (wired below).
    btn_conditional.click(fn=None, inputs=[resample_idxs], outputs=[resample_idxs], _js=get_js)

    # Hidden components exposing api_predict as the named API route "protpardelle".
    out_text = gr.Textbox(label="Output", visible=False)
    pdb_content = gr.Textbox(label="PDB Content", visible=False)
    btn_api = gr.Button("Run API", visible=False)
    btn_api.click(api_predict, inputs=[pdb_content, m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[out_text], api_name="protpardelle")
    resample_idxs.change(predict, inputs=predict_inputs, outputs=[input_accordion, out])

    protpardelleDemo.load(None, None, None, _js=load_js)

protpardelleDemo.queue()
protpardelleDemo.launch(allowed_paths=["samples"])