import gradio as gr
import urllib.request
import re
import sys
import warnings
import torch
import torch.nn as nn
from utils.helpers import *
from utils.voxelization import processStructures
from utils.model import Model
import numpy as np
import os
import moleculekit
print(moleculekit.__version__)
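# Run a full prediction: resolve the input (uploaded PDB file, 4-letter PDB code, or
# UniProt accession fetched from the AlphaFold database), select residues, voxelize
# their environments, run the 3D CNN, and write a probability cube plus predicted
# probe sites. Returns a status message and the HTML visualization.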
def update(inp, file, mode, custom_resids, clustering_threshold, distance_cutoff):
try:
filepath = file.name
    except AttributeError:
        # no file uploaded: fall back to the PDB code / UniProt accession from the textbox
        print("no file uploaded, using PDB/UniProt code")
try:
pdb_file = inp
if (
re.match(
"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}",
pdb_file,
).group()
== pdb_file
):
urllib.request.urlretrieve(
f"https://alphafold.ebi.ac.uk/files/AF-{pdb_file}-F1-model_v2.pdb",
f"files/{pdb_file}.pdb",
)
filepath = f"files/{pdb_file}.pdb"
except AttributeError:
if len(inp) == 4:
pdb_file = inp
urllib.request.urlretrieve(
f"http://files.rcsb.org/download/{pdb_file.lower()}.pdb1",
f"files/{pdb_file}.pdb",
)
filepath = f"files/{pdb_file}.pdb"
else:
return "pdb code must be 4 letters or Uniprot code does not match", ""
identifier = os.path.basename(filepath)
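    # Select the residues whose environments will be voxelized: every residue,
    # a user-supplied list, or (default) the canonical metal-coordinating residues
    # ASP, CYS, GLU and HIS.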
if mode == "All residues":
print("using all residues")
ids = get_all_protein_resids(filepath)
elif len(custom_resids) != 0:
print("using listed residues", custom_resids)
ids = get_all_resids_from_list(filepath, custom_resids.replace(",", " "))
else:
print("using metalbinding")
ids = get_all_metalbinding_resids(filepath)
print(filepath)
print(ids)
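    # Voxelize the environment around each selected residue; abort with a readable
    # error message if the structure cannot be processed.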
try:
voxels, prot_centers, prot_N, prots = processStructures(filepath, ids)
except Exception as e:
print(e)
return (
"Error",
f"""
Something went wrong during voxelization. Reset the custom residues and other input fields and check the error message:
{e}
""",
)
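    # Run the 3D CNN on the voxelized environments to get per-voxel metal probabilities.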
    voxels = voxels.to(device)  # .to() is not in-place; reassign so inference runs on the selected device
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
output = model(voxels)
print(output.shape)
prot_v = np.vstack(prot_centers)
output_v = output.flatten().cpu().detach().numpy()
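    # Map the per-voxel predictions back onto a common grid (0.5 A spacing) spanning
    # the protein bounding box and average overlapping values.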
bb = get_bb(prot_v)
gridres = 0.5
grid, box_N = create_grid_fromBB(bb, voxelSize=gridres)
probability_values = get_probability_mean(grid, prot_v, output_v)
print(probability_values.shape)
write_cubefile(
bb,
probability_values,
box_N,
outname=f"output/metal_{identifier}.cube",
gridres=gridres,
)
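    # Cluster high-probability grid points into discrete predicted metal sites and
    # write them out as probe pseudo-atoms for visualization.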
message = find_unique_sites(
probability_values,
grid,
writeprobes=True,
probefile=f"output/probes_{identifier}.pdb",
threshold=distance_cutoff,
p=clustering_threshold,
)
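    # Free the voxel tensor and any cached GPU memory before returning the results.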
del voxels
torch.cuda.empty_cache()
return message, molecule(
filepath,
f"output/probes_{identifier}.pdb",
f"output/metal_{identifier}.cube",
)
def read_mol(molpath):
    # Return the contents of a structure file (PDB or cube) as a single string.
    with open(molpath, "r") as fp:
        return fp.read()
def molecule(pdb, probes, cube):
    # Embed the input structure, the predicted probe sites and the probability cube
    # into the HTML molecular viewer template (with an isovalue slider defaulting to
    # 0.5); the full HTML/JavaScript template string is omitted here.
    mol = read_mol(pdb)
    probes = read_mol(probes)
    cubefile = read_mol(cube)
    x = (
        """
        Isovalue
        0.5
        """
    )
    return f""""""
def set_examples(example):
n, code, resids = example
return [n, code, resids]
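# Instantiate Metal3D and load the trained weights (0.5 A voxels, 16 A box),
# running on the GPU when one is available.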
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.to(device)
model.load_state_dict(
torch.load(
"weights/metal_0.5A_v3_d0.2_16Abox.pth",
        map_location=device,
)
)
model.eval()
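# Minimal sanity-check sketch (not part of the app): with the 16 A box and 0.5 A
# resolution used above, each voxelized environment is a 32x32x32 grid; assuming the
# usual 8 moleculekit property channels, a dummy forward pass would be:
#   dummy = torch.zeros(1, 8, 32, 32, 32, device=device)
#   with torch.no_grad():
#       print(model(dummy).shape)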
metal3d = gr.Blocks()
with metal3d:
gr.Markdown("# Metal3D")
gr.Markdown(
""" Inference using CPU-only, can be quite slow for more than 20 residues. Use [Colab notebook](https://colab.research.google.com/github/lcbc-epfl/metal-site-prediction/blob/main/Metal3D/ColabMetal.ipynb) for GPU acceleration
"""
)
with gr.Tabs():
with gr.TabItem("Input"):
inp = gr.Textbox(
placeholder="PDB Code or Uniprot identifier or upload file below",
label="Input molecule",
)
file = gr.File(file_count="single", type="file")
with gr.TabItem("Settings"):
with gr.Row():
mode = gr.Radio(
["All metalbinding residues (ASP, CYS, GLU, HIS)", "All residues"],
label="Residues to use for prediction",
)
custom_resids = gr.Textbox(
placeholder="Comma separated list of residues",
label="Custom residues",
)
with gr.Row():
clustering_threshold = gr.Slider(
minimum=0.15,
maximum=1,
value=0.15,
step=0.05,
label="Clustering threshold",
)
distance_cutoff = gr.Slider(
minimum=1,
maximum=10,
value=7,
step=0.5,
label="Clustering distance cutoff",
)
btn = gr.Button("Run")
n = gr.Textbox(label="Label", visible=False)
examples = gr.Dataset(
components=[n, inp, custom_resids],
samples=[
["HCA2", "2CBA", ""],
["Nickel in GB1 dimer", "6F5N", ""],
["Zebrafish palmitoyltransferase ZDHHC15B PDB", "6BMS", ""],
[
"Human palmitoyltransferase ZDHHC23 AlphaFold",
"Q8IYP9",
"280,273,263,260,274,277,274,287",
],
],
)
examples.click(fn=set_examples, inputs=examples, outputs=examples.components)
gr.Markdown("# Output")
out = gr.Textbox(label="status")
mol = gr.HTML()
btn.click(
fn=update,
inputs=[inp, file, mode, custom_resids, clustering_threshold, distance_cutoff],
outputs=[out, mol],
)
metal3d.launch(share=True)