import glob
import html
import math
import os
import pprint
import random
import sys
import time
from io import BytesIO

import esm
import esm.inverse_folding
import gradio as gr
import numpy as np
import py3Dmol
import requests
import torch
from torchdrug import core, models, tasks, datasets, utils, data
from torchdrug.utils import comm
from tqdm import tqdm

# import util
# from gearnet import dataset, model
# from gearnet.dataset import bio_load_pdb
def get_pdb(seq, timeout=300):
    """Fold a protein sequence with the ESM Atlas ESMFold web API.

    Args:
        seq: Protein sequence, sent verbatim as the POST request body.
        timeout: Seconds to wait for the server before giving up. Folding can
            be slow, but without any limit a stalled connection would block
            the caller indefinitely.

    Returns:
        The predicted structure as PDB-format text.

    Raises:
        requests.HTTPError: If the API responds with a 4xx/5xx status (rather
            than silently returning the error body as if it were a PDB file).
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    print(f'[LOG] Obtaining pdb files: {seq}.')
    # Local alternative (requires a GPU):
    #   model = esm.pretrained.esmfold_v1().eval().cuda()
    #   with torch.no_grad():
    #       pdb = model.infer_pdb(seq)
    url = 'https://api.esmatlas.com/foldSequence/v1/pdb/'
    response = requests.post(url, data=seq, timeout=timeout)
    response.raise_for_status()
    return response.text
def get_score(transform, task, cfg, seq):
    """Fold ``seq``, score the predicted structure, and render it.

    The sequence is folded via ``get_pdb`` and written to a fixed scratch
    file. The file is loaded into a protein graph, wrapped as
    ``{"graph": ...}``, and passed through ``transform``. It is then scored
    with ``task.predict`` under ``torch.no_grad()``.

    NOTE(review): ``bio_load_pdb`` is not imported in this file (its import
    is commented out at the top), so this raises NameError until restored.

    Args:
        transform: Callable applied to the ``{"graph": ...}`` batch dict.
        task: Model exposing ``eval()`` and ``predict(batch)``.
        cfg: Config providing ``gpu`` and, optionally, ``batch_size``.
        seq: Protein sequence to fold and score.

    Returns:
        ``(score, viewer)``: the highest predicted score as a Python number,
        and ``get_3dview`` applied to the folded PDB text.
    """
    print(f'[LOG] Predicting scores: {seq}.')
    pdb_text = get_pdb(seq)
    scratch = "data/demo/tmp_get_score.pdb"
    with open(scratch, "w") as handle:
        handle.write(pdb_text)
    files = [scratch]

    # Resolved from the config but not used: neither the task nor the batch
    # is moved to this device below.
    device = torch.device(cfg.gpu)
    task.eval()
    step = cfg.get("batch_size", 1)

    scored = []
    for start in tqdm(range(0, len(files), step)):
        chunk = files[start:start + step]
        # bio_load_pdb returns a (protein, sequence) pair; only the graph is kept.
        graphs = [protein for protein, _sequence in map(bio_load_pdb, chunk)]
        batch = transform({"graph": data.Protein.pack(graphs)})
        with torch.no_grad():
            pred = task.predict(batch)
        for offset, value in enumerate(pred.cpu().unbind()):
            stem = os.path.basename(files[start + offset])[:-4]
            scored.append((stem, value.item()))

    # Highest score first.
    scored.sort(key=lambda item: -item[-1])
    print(scored)
    return scored[0][1], get_3dview(pdb_text)
def get_3dview(pdb):
    """Render PDB text as an embeddable, interactive 3Dmol.js viewer.

    The py3Dmol HTML (markup plus the viewer's JavaScript) is wrapped in a
    standalone document and placed in an ``<iframe srcdoc=...>``. This lets
    its scripts run in their own browsing context when the returned string
    is inserted into a page (e.g. a Gradio ``gr.HTML`` output); scripts in
    markup inserted via ``innerHTML`` are not executed.

    Args:
        pdb: Structure in PDB text format.

    Returns:
        An HTML ``<iframe>`` string showing the structure as a
        spectrum-coloured cartoon.
    """
    view = py3Dmol.view(width=500, height=500)
    view.addModel(pdb, "pdb")
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    # view.zoomTo()
    document = ('<!DOCTYPE html><html><body style="margin: 0;">'
                f'{view._make_html()}</body></html>')
    # html.escape(quote=True) escapes &, <, >, " and ', so the whole document
    # can sit inside the double-quoted srcdoc attribute without rewriting the
    # viewer's JavaScript or the embedded PDB text.
    return ('<iframe style="width: 100%; height: 510px; border: none;" '
            f'srcdoc="{html.escape(document, quote=True)}"></iframe>')
def display_pdb(sequence):
    """Fold ``sequence`` with ``get_pdb`` and return an embeddable 3D viewer.

    The py3Dmol HTML is wrapped in a standalone document inside an
    ``<iframe srcdoc=...>`` so its scripts run when the string is inserted
    into a page (scripts inserted via ``innerHTML`` are not executed).

    Args:
        sequence: Protein sequence to fold.

    Returns:
        An HTML ``<iframe>`` string showing the folded structure as a
        spectrum-coloured cartoon, zoomed to fit.
    """
    view = py3Dmol.view(width=500, height=500)
    view.addModel(get_pdb(sequence), "pdb")
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    view.zoomTo()
    document = ('<!DOCTYPE html><html><body style="margin: 0;">'
                f'{view._make_html()}</body></html>')
    # html.escape(quote=True) escapes &, <, >, " and ', so the whole document
    # can sit inside the double-quoted srcdoc attribute without rewriting the
    # viewer's JavaScript or the embedded PDB text.
    return ('<iframe style="width: 100%; height: 510px; border: none;" '
            f'srcdoc="{html.escape(document, quote=True)}"></iframe>')
def display_pdb_by_pdb(pdb):
    """Return an embeddable 3D viewer for structure text that is already in PDB format.

    The py3Dmol HTML is wrapped in a standalone document inside an
    ``<iframe srcdoc=...>`` so its scripts run when the string is inserted
    into a page (scripts inserted via ``innerHTML`` are not executed).

    Args:
        pdb: Structure in PDB text format.

    Returns:
        An HTML ``<iframe>`` string showing the structure as a
        spectrum-coloured cartoon, zoomed to fit.
    """
    view = py3Dmol.view(width=500, height=500)
    view.addModel(pdb, "pdb")
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    view.zoomTo()
    document = ('<!DOCTYPE html><html><body style="margin: 0;">'
                f'{view._make_html()}</body></html>')
    # html.escape(quote=True) escapes &, <, >, " and ', so the whole document
    # can sit inside the double-quoted srcdoc attribute without rewriting the
    # viewer's JavaScript or the embedded PDB text.
    return ('<iframe style="width: 100%; height: 510px; border: none;" '
            f'srcdoc="{html.escape(document, quote=True)}"></iframe>')
def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs and force deterministic cuDNN."""
    # Only inherited by subprocesses started later; the running interpreter's
    # hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _load_scorer(cfg):
    """Build the scoring task and graph transform from ``cfg``, loading weights if configured."""
    dataset = core.Configurable.load_config_dict(cfg.dataset)
    task = core.Configurable.load_config_dict(cfg.task)
    task.preprocess(dataset, None, None)
    transform = core.Configurable.load_config_dict(cfg.transform)
    if cfg.get("checkpoint") is not None:
        cfg.checkpoint = os.path.expanduser(cfg.checkpoint)
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(cfg.checkpoint, map_location=torch.device('cpu'))['model']
        task.load_state_dict(state_dict)
    return task, transform


def _top3_outputs(seq_list, done, total):
    """Build the 10-tuple for the UI from up to three ``[score, seq, view]`` entries.

    ``seq_list`` must be sorted by ascending score. The output is ordered
    best-first and padded with ``None`` when fewer than three entries exist.
    """
    ranked = seq_list[::-1] + [[None, None, None]] * (3 - len(seq_list))
    progress = str(done) + " / " + str(total)
    return (progress,
            *(entry[1] for entry in ranked),
            *(entry[2] for entry in ranked),
            *(entry[0] for entry in ranked))


def sample_seq(sequence, chain='A', num_samples=20, temperature=1):
    """Design sequences for the fold of ``sequence`` and stream the best three.

    Workflow:

    * ``sequence`` is folded via ``get_pdb`` and written to a scratch PDB file.
    * New sequences are sampled for that backbone with ESM-IF1
      (``esm_if1_gvp4_t16_142M_UR50``).
    * Each sample is scored with ``get_score``.

    Options come from ``util.parse_args()`` / ``util.load_config`` (uses
    ``args.config``, ``args.seed``, ``args.multichain``).

    NOTE(review): ``util`` is not imported in this file (its import is
    commented out at the top), so this raises NameError until restored.

    Args:
        sequence: Protein sequence whose predicted structure is redesigned.
        chain: Chain id to read coordinates from and to sample for.
        num_samples: Number of successfully scored samples to produce.
        temperature: Sampling temperature passed to ESM-IF1.

    Yields:
        After every successful sample, a 10-tuple:
        ``(progress, seq1, seq2, seq3, view1, view2, view3,
        score1, score2, score3)``. It holds the best three so far,
        best first, with ``None`` where fewer exist. Samples for which
        ``get_score`` raises ValueError are discarded and re-drawn (no limit).
    """
    pdbfile = "data/demo/tmp_sample_seq_singlechain.pdb"
    with open(pdbfile, "w") as f:
        f.write(get_pdb(sequence))

    model, _alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    model = model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    # Sample on the model's own device; hard-coding CPU here fails once the
    # model has been moved to CUDA. ESM's multichain helper does the same.
    device = next(model.parameters()).device

    args, cli_vars = util.parse_args()
    cfg = util.load_config(args.config, context=cli_vars)
    _seed_everything(args.seed)

    if args.multichain:
        structure = esm.inverse_folding.util.load_structure(pdbfile)
        coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)
        native_seq = native_seqs[chain]
        print('[LOG] Sampling multichain. Native sequence loaded from structure file:', native_seq)
    else:
        coords, native_seq = esm.inverse_folding.util.load_coords(pdbfile, chain)
        print('[LOG] Sampling singlechain. Native sequence loaded from structure file:', native_seq)

    # Scorer used by get_score.
    task, transform = _load_scorer(cfg)

    seq_list = []
    i = 0
    while i < num_samples:
        if args.multichain:
            sampled_seq = esm.inverse_folding.multichain_util.sample_sequence_in_complex(
                model, coords, chain, temperature=temperature)
        else:
            sampled_seq = model.sample(coords, temperature=temperature, device=device)
        print(f'[LOG] Sampling sequence: {sampled_seq}.')
        try:
            score, view3d = get_score(transform, task, cfg, sampled_seq)
        except ValueError as ve:
            # Unscoreable sample: report it and draw another one.
            print(ve)
            continue
        print(score)
        i += 1
        seq_list.append([score, sampled_seq, view3d])
        # Keep only the three best (ascending order, stable on ties).
        seq_list = sorted(seq_list, key=lambda x: x[0])[-3:]
        yield _top3_outputs(seq_list, i, num_samples)
def show_gif(path='output', delay=0.3):
    """Stream the PDB snapshots in ``path`` as an animation.

    Files are ordered by the integer in the second ``_``-separated field of
    their names (e.g. ``step_12_x.pdb`` -> 12). From frame index ``i``, the
    next frame shown is ``i + int(ln(i + 3))``, so later parts of long
    trajectories are played back more sparsely.

    Args:
        path: Directory containing the PDB snapshots.
        delay: Seconds to wait before emitting each frame.

    Yields:
        ``(viewer_html, file_name)`` for each displayed frame. ``file_name``
        is the snapshot actually being shown.
    """
    pdb_files = sorted(os.listdir(path), key=lambda name: int(name.split('_')[1]))
    i = 0
    # Strict bound: index len(pdb_files) would be out of range.
    while i < len(pdb_files):
        name = pdb_files[i]
        time.sleep(delay)
        with open(os.path.join(path, name), 'r') as f:
            pdb = f.read()
        yield display_pdb_by_pdb(pdb), name
        # i + 3 >= 3 > e, so the stride is always at least 1.
        i += int(math.log(i + 3))
if __name__ == "__main__":
    title = "Artificial Intelligence Generated Protein"
    # Hide Gradio's default footer.
    css = "footer {visibility: hidden}"
    with gr.Blocks(title=title, css=css) as demo:
        output_viewer = gr.HTML()
        with gr.Row():
            gif = gr.HTML()
            it = gr.Textbox(label="Iteration")
        btn3 = gr.Button("GIF")
        # show_gif takes no UI inputs and streams (viewer HTML, file name) pairs.
        btn3.click(show_gif, [], [gif, it])
    # Enable the request queue; streaming (generator) callbacks such as
    # show_gif rely on it in Gradio 3.x.
    demo.queue()
    demo.launch(show_api=False, server_name="0.0.0.0", share=True)
    # demo.launch(show_api=False, share=True)