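"""Gradio demo for AI-generated protein design.

The app folds an input amino-acid sequence with the public ESM Atlas folding
API, scores the predicted structure with a TorchDrug task loaded from a config
file, samples new sequences for that backbone with the ESM-IF1 inverse-folding
model, and renders structures with py3Dmol inside the Gradio UI.
"""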
import os
import sys
import glob
import math
import pprint
import random
import time
from io import BytesIO

import numpy as np
import requests
import torch
import py3Dmol
import gradio as gr
from tqdm import tqdm

import esm
import esm.inverse_folding
from torchdrug import core, models, tasks, datasets, utils, data
from torchdrug.utils import comm

# NOTE (assumption): `util` is the project-local helper module providing
# parse_args(), load_config(), and bio_load_pdb(), which are used below but
# were not imported; adjust these imports to match the repository layout.
import util
from util import bio_load_pdb


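# Fold a sequence by POSTing it to the ESM Atlas folding endpoint; the response
# body is the predicted structure in PDB format.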
def get_pdb(seq):
    print(f'[LOG] Obtaining pdb file: {seq}.')

    url = 'https://api.esmatlas.com/foldSequence/v1/pdb/'
    r = requests.post(url, data=seq)
    pdb = r.text
    return pdb


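# Score a sequence: fold it, write the structure to a temporary PDB file, load
# it with the project-provided bio_load_pdb() helper, and run the TorchDrug
# task's predict() on the packed protein graph. `transform`, `task`, and `cfg`
# are built from the config in sample_seq() below.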
def get_score(transform, task, cfg, seq):
    print(f'[LOG] Predicting scores: {seq}.')
    pdb = get_pdb(seq)
    outpath = "data/demo/tmp_get_score.pdb"
    with open(outpath, "w") as f:
        f.write(pdb)

    pdb_files = [outpath]
    device = torch.device(cfg.gpu)

    task.eval()
    batch_size = cfg.get("batch_size", 1)
    preds = []
    for i in tqdm(range(0, len(pdb_files), batch_size)):
        proteins = []
        for pdb_file in pdb_files[i:i+batch_size]:
            protein, sequence = bio_load_pdb(pdb_file)
            proteins.append(protein)
        protein = data.Protein.pack(proteins)

        batch = {"graph": protein}
        batch = transform(batch)
        with torch.no_grad():
            pred = task.predict(batch)
        for j, value in enumerate(pred.cpu().unbind()):
            name = os.path.basename(pdb_files[i+j])[:-4]
            preds.append((name, value.item()))

    preds = sorted(preds, key=lambda x: -x[-1])
    print(preds)
    return preds[0][1], get_3dview(pdb)


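# py3Dmol rendering helpers: each returns the py3Dmol HTML wrapped in a
# sandboxed <iframe> string that Gradio can display via gr.HTML.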
def get_3dview(pdb):
    return display_pdb_by_pdb(pdb)


def display_pdb(sequence):
    return display_pdb_by_pdb(get_pdb(sequence))


def display_pdb_by_pdb(pdb):
    view = py3Dmol.view(width=500, height=500)
    view.addModel(pdb, "pdb")
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    view.zoomTo()

    output = view._make_html().replace("'", '"')
    x = f"""<!DOCTYPE html><html><center> {output} </center></html>"""

    return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
        display-capture; encrypted-media;" sandbox="allow-modals allow-forms
        allow-scripts allow-same-origin allow-popups
        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
        allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""


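# Sample new sequences for the backbone of the input sequence's predicted
# structure with ESM-IF1, score each candidate with the TorchDrug task, and
# yield the running top-3 (sequence, 3D view, score) so Gradio can stream
# intermediate results.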
def sample_seq(sequence, chain='A', num_samples=20, temperature=1):
    pdbfile = "data/demo/tmp_sample_seq_singlechain.pdb"
    with open(pdbfile, "w") as f:
        f.write(get_pdb(sequence))

    # Load the ESM-IF1 inverse-folding model.
    model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    model = model.eval()

    if torch.cuda.is_available():
        model = model.cuda()

    args, vars = util.parse_args()
    cfg = util.load_config(args.config, context=vars)

    # Seed all RNGs for reproducibility.
    seed = args.seed
    torch.manual_seed(seed + comm.get_rank())
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if args.multichain:
        structure = esm.inverse_folding.util.load_structure(pdbfile)
        coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)
        native_seq = native_seqs[chain]
        print('[LOG] Sampling multichain. Native sequence loaded from structure file:', native_seq)
    else:
        coords, native_seq = esm.inverse_folding.util.load_coords(pdbfile, chain)
        print('[LOG] Sampling singlechain. Native sequence loaded from structure file:', native_seq)

    # Build the scoring dataset, task, and transform from the config file.
    dataset = core.Configurable.load_config_dict(cfg.dataset)
    task = core.Configurable.load_config_dict(cfg.task)
    task.preprocess(dataset, None, None)
    transform = core.Configurable.load_config_dict(cfg.transform)
    if cfg.get("checkpoint") is not None:
        cfg.checkpoint = os.path.expanduser(cfg.checkpoint)
        pretrained_dict = torch.load(cfg.checkpoint, map_location=torch.device('cpu'))['model']
        model_dict = task.state_dict()
        task.load_state_dict(pretrained_dict)

    res = ""
    seq_list = []
    i = 0
    while i < num_samples:
        if args.multichain:
            sampled_seq = esm.inverse_folding.multichain_util.sample_sequence_in_complex(
                model, coords, chain, temperature=temperature)
        else:
            sampled_seq = model.sample(coords, temperature=temperature, device=torch.device('cpu'))
        print(f'[LOG] Sampled sequence: {sampled_seq}.')

        try:
            score, view3d = get_score(transform, task, cfg, sampled_seq)
            print(score)
            i += 1
        except ValueError as ve:
            print(ve)
            continue

        seq_list.append([score, sampled_seq, view3d])

        # Stream the current top-3 (sequence, 3D view, score) back to the UI.
        if len(seq_list) == 1:
            yield str(i) + " / " + str(num_samples), seq_list[0][1], None, None, seq_list[0][2], None, None, seq_list[0][0], None, None
        elif len(seq_list) == 2:
            seq_list = sorted(seq_list, key=lambda x: x[0])
            yield str(i) + " / " + str(num_samples), seq_list[1][1], seq_list[0][1], None, seq_list[1][2], seq_list[0][2], None, seq_list[1][0], seq_list[0][0], None
        else:
            seq_list = sorted(seq_list, key=lambda x: x[0])[-3:]
            yield str(i) + " / " + str(num_samples), seq_list[2][1], seq_list[1][1], seq_list[0][1], seq_list[2][2], seq_list[1][2], seq_list[0][2], seq_list[2][0], seq_list[1][0], seq_list[0][0]


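# Stream precomputed structures from the "output" directory as an animation;
# filenames are expected to carry the frame index in their second
# "_"-separated field so they can be sorted numerically.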
def show_gif():
    path = 'output'
    pdb_files = sorted(os.listdir(path), key=lambda x: int(x.split('_')[1]))
    num = len(pdb_files)
    step = 1
    i = 0
    while i < num:
        step = int(torch.tensor(i + 3).log().item())
        time.sleep(0.3)
        p = os.path.join(path, pdb_files[i])
        with open(p, 'r') as f:
            f_pdb = f.readlines()

        yield display_pdb_by_pdb(''.join(f_pdb)), pdb_files[i]
        i += step


if __name__ == "__main__":
    title = "Artificial Intelligence Generated Protein"

    css = "footer {visibility: hidden}"

    with gr.Blocks(title=title, css=css) as demo:
        output_viewer = gr.HTML()
        with gr.Row():
            gif = gr.HTML()
            it = gr.Textbox(label="Iteration")
            btn3 = gr.Button("GIF")
        btn3.click(show_gif, [], [gif, it])
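
        # Illustrative sketch (assumption, not part of the original UI): wire an
        # input sequence box and a "Sample" button to sample_seq(), which streams
        # a progress string plus the top-3 sequences, 3D views, and scores.
        # seq_in = gr.Textbox(label="Input sequence")
        # progress = gr.Textbox(label="Progress")
        # seq_outs = [gr.Textbox(label=f"Top-{k} sequence") for k in (1, 2, 3)]
        # view_outs = [gr.HTML() for _ in range(3)]
        # score_outs = [gr.Number(label=f"Top-{k} score") for k in (1, 2, 3)]
        # btn_sample = gr.Button("Sample")
        # btn_sample.click(sample_seq, [seq_in],
        #                  [progress, *seq_outs, *view_outs, *score_outs])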
    demo.queue()
    demo.launch(show_api=False, server_name="0.0.0.0", share=True)