"""Gradio demo: fold sequences with ESM Atlas, redesign them with ESM-IF1,
score the designs with a torchdrug task and visualise structures with py3Dmol.

NOTE: the project-local helpers ``util`` (config / argument parsing) and
``gearnet.dataset.bio_load_pdb`` are imported lazily inside the functions that
need them, so the GIF replay demo launched from ``__main__`` does not require
those modules to be importable.
"""
import glob
import math
import os
import pprint
import random
import sys
import time
from io import BytesIO

import esm
import esm.inverse_folding
import gradio as gr
import numpy as np
import py3Dmol
import requests
import torch
from torchdrug import core, data, datasets, models, tasks, utils
from torchdrug.utils import comm
from tqdm import tqdm


def get_pdb(seq, timeout=300):
    """Fold ``seq`` with the public ESM Atlas endpoint and return PDB text.

    Args:
        seq: amino-acid sequence, posted as the raw request body.
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The response body (PDB-format text).

    Raises:
        requests.HTTPError: if the service answers with an error status, so an
            error page is never mistaken for a structure.
        requests.Timeout: if the service does not answer within ``timeout``.
    """
    print(f'[LOG] Obtaining pdb files: {seq}.')
    url = 'https://api.esmatlas.com/foldSequence/v1/pdb/'
    response = requests.post(url, data=seq, timeout=timeout)
    response.raise_for_status()
    return response.text


def get_score(transform, task, cfg, seq):
    """Fold ``seq`` and score its structure with a torchdrug ``task``.

    Args:
        transform: callable applied to the ``{"graph": packed_protein}`` batch.
        task: torchdrug task whose ``predict`` returns one value per protein.
        cfg: config object; only ``cfg.get("batch_size", 1)`` is read here
            (presumably an EasyDict-like mapping - confirm against ``util``).
        seq: amino-acid sequence to fold and score.

    Returns:
        ``(best_score, viewer_html)`` where ``viewer_html`` renders the folded
        structure.
    """
    # Project-local loader; imported lazily (see module docstring).
    from gearnet.dataset import bio_load_pdb

    print(f'[LOG] Predicting scores: {seq}.')
    pdb = get_pdb(seq)
    outpath = "data/demo/tmp_get_score.pdb"
    os.makedirs(os.path.dirname(outpath), exist_ok=True)
    with open(outpath, "w") as f:
        f.write(pdb)
    pdb_files = [outpath]

    task.eval()
    batch_size = cfg.get("batch_size", 1)
    preds = []
    for start in tqdm(range(0, len(pdb_files), batch_size)):
        batch_files = pdb_files[start:start + batch_size]
        proteins = []
        for pdb_file in batch_files:
            protein, _sequence = bio_load_pdb(pdb_file)
            proteins.append(protein)
        batch = transform({"graph": data.Protein.pack(proteins)})
        with torch.no_grad():
            pred = task.predict(batch)
        for pdb_file, value in zip(batch_files, pred.cpu().unbind()):
            name = os.path.splitext(os.path.basename(pdb_file))[0]
            preds.append((name, value.item()))

    # Highest score first.
    preds.sort(key=lambda item: item[-1], reverse=True)
    print(preds)
    return preds[0][1], get_3dview(pdb)


def _pdb_to_html(pdb, zoom=True):
    """Render PDB text as a 500x500 py3Dmol cartoon embedded in an iframe.

    Args:
        pdb: PDB-format text.
        zoom: whether to call ``view.zoomTo()`` before rendering.

    Returns:
        An ``<iframe>`` HTML string suitable for a ``gr.HTML`` component.
    """
    view = py3Dmol.view(width=500, height=500)
    view.addModel(pdb, "pdb")
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    if zoom:
        view.zoomTo()
    # The page is embedded via a single-quoted ``srcdoc`` attribute, so any
    # single quote in the generated HTML would terminate the attribute early.
    output = view._make_html().replace("'", '"')
    page = f"<!DOCTYPE html><html>{output}</html>"
    return (
        '<iframe style="width: 100%; height: 520px" name="result" '
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin '
        'allow-popups allow-top-navigation-by-user-activation allow-downloads" '
        'allowfullscreen="" frameborder="0" '
        f"srcdoc='{page}'></iframe>"
    )


def get_3dview(pdb):
    """Return viewer HTML for PDB text (no automatic zoom)."""
    return _pdb_to_html(pdb, zoom=False)


def display_pdb(sequence):
    """Fold ``sequence`` via :func:`get_pdb` and return viewer HTML."""
    return _pdb_to_html(get_pdb(sequence))


def display_pdb_by_pdb(pdb):
    """Return viewer HTML for PDB text, zoomed to the model."""
    return _pdb_to_html(pdb)


def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and request deterministic cuDNN."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _load_score_task(cfg):
    """Build the torchdrug scoring task and batch transform described by ``cfg``.

    Loads ``cfg.checkpoint`` (its ``'model'`` entry) into the task when set.

    Returns:
        ``(task, transform)``.
    """
    dataset = core.Configurable.load_config_dict(cfg.dataset)
    task = core.Configurable.load_config_dict(cfg.task)
    task.preprocess(dataset, None, None)
    transform = core.Configurable.load_config_dict(cfg.transform)
    if cfg.get("checkpoint") is not None:
        cfg.checkpoint = os.path.expanduser(cfg.checkpoint)
        state = torch.load(cfg.checkpoint, map_location=torch.device('cpu'))['model']
        task.load_state_dict(state)
    return task, transform


def _rank_outputs(ranked, k=3):
    """Flatten a best-first ``[score, seq, view]`` list into Gradio outputs.

    Returns ``k`` sequences, then ``k`` views, then ``k`` scores, padding
    missing slots with ``None``.
    """
    top = list(ranked[:k])
    top += [[None, None, None]] * (k - len(top))
    scores, seqs, views = zip(*top)
    return (*seqs, *views, *scores)


def sample_seq(sequence, chain='A', num_samples=20, temperature=1):
    """Redesign ``sequence`` with ESM-IF1 and stream the best-scoring designs.

    The backbone of ``sequence`` is folded with :func:`get_pdb`. Then
    ``num_samples`` sequences are sampled from it with the ESM-IF1 inverse
    folding model and each is scored with :func:`get_score`. Samples whose
    scoring raises ``ValueError`` are discarded and re-drawn.

    Run options (config path, seed, ``multichain`` flag) come from
    ``util.parse_args()``, which presumably reads the process command line.

    Yields:
        After every successfully scored sample, a 10-tuple:
        ``("i / num_samples", seq1, seq2, seq3, view1, view2, view3,
        score1, score2, score3)``. The top three are ordered best first, and
        slots not yet filled are ``None``.
    """
    # Project-local config helpers; imported lazily (see module docstring).
    import util

    pdbfile = "data/demo/tmp_sample_seq_singlechain.pdb"
    os.makedirs(os.path.dirname(pdbfile), exist_ok=True)
    with open(pdbfile, "w") as f:
        f.write(get_pdb(sequence))

    model, _alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    model = model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    # Sample on whichever device the model lives on; a hard-coded CPU device
    # mismatches a CUDA model.
    device = next(model.parameters()).device

    args, cfg_vars = util.parse_args()
    cfg = util.load_config(args.config, context=cfg_vars)
    _seed_everything(args.seed)

    if args.multichain:
        structure = esm.inverse_folding.util.load_structure(pdbfile)
        coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)
        native_seq = native_seqs[chain]
        print('[LOG] Sampling multichain. Native sequence loaded from structure file:', native_seq)
    else:
        coords, native_seq = esm.inverse_folding.util.load_coords(pdbfile, chain)
        print('[LOG] Sampling singlechain. Native sequence loaded from structure file:', native_seq)

    task, transform = _load_score_task(cfg)

    ranked = []  # best-first [score, sequence, viewer_html], at most 3 entries
    i = 0
    while i < num_samples:
        if args.multichain:
            sampled_seq = esm.inverse_folding.multichain_util.sample_sequence_in_complex(
                model, coords, chain, temperature=temperature)
        else:
            sampled_seq = model.sample(coords, temperature=temperature, device=device)
        print(f'[LOG] Sampling sequence: {sampled_seq}.')
        try:
            score, view3d = get_score(transform, task, cfg, sampled_seq)
        except ValueError as ve:
            # Unscorable sample: log it and draw another without counting it.
            print(ve)
            continue
        print(score)
        i += 1
        ranked.append([score, sampled_seq, view3d])
        ranked = sorted(ranked, key=lambda item: item[0], reverse=True)[:3]
        yield (f"{i} / {num_samples}", *_rank_outputs(ranked))


def show_gif(path='output', delay=0.3):
    """Replay the PDB snapshots stored in ``path`` as an animation.

    Files are ordered by the integer after the first ``_`` in their name.
    They are visited with a step of ``int(log(i + 3))``, so later snapshots
    are skipped through faster.

    Args:
        path: directory holding the snapshot files.
        delay: seconds to pause before each frame.

    Yields:
        ``(viewer_html, file_name)`` for each displayed snapshot. The name
        belongs to the snapshot being shown.
    """
    pdb_files = sorted(os.listdir(path), key=lambda name: int(name.split('_')[1]))
    i = 0
    while i < len(pdb_files):
        step = int(math.log(i + 3))  # always >= 1 because i + 3 > e
        time.sleep(delay)
        name = pdb_files[i]
        with open(os.path.join(path, name), 'r') as f:
            pdb = f.read()
        i += step
        yield display_pdb_by_pdb(pdb), name


def main():
    """Build and launch the Gradio UI for the snapshot replay demo."""
    title = "Artificial Intelligence Generated Protein"
    css = "footer {visibility: hidden}"
    with gr.Blocks(title=title, css=css) as demo:
        output_viewer = gr.HTML()  # not wired to any event in this UI
        with gr.Row():
            gif = gr.HTML()
            it = gr.Textbox(label="Iteration")
        btn3 = gr.Button("GIF")
        btn3.click(show_gif, [], [gif, it])
    # The queue is required for generator (streaming) callbacks.
    demo.queue()
    demo.launch(show_api=False, server_name="0.0.0.0", share=True)


if __name__ == "__main__":
    main()