Bisect_iitm_submission_2 / inference_app.py
Sukanyaaa's picture
Update inference_app.py
d7f69ca verified
raw
history blame
8.71 kB
import time
import json
import gradio as gr
from gradio_molecule3d import Molecule3D
import torch
from torch_geometric.data import HeteroData
import numpy as np
from loguru import logger
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO
from pinder.core.loader.geodata import structure2tensor
from pinder.core.loader.structure import Structure
from src.models.pinder_module import PinderLitModule
# Optional dependency: torch-cluster provides `knn_graph`, which is used to
# build the k-nearest-neighbour edges of the input graph. A failed import is
# logged rather than raised, and recorded in `torch_cluster_installed`.
try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Adjacent string literals are concatenated verbatim, so every fragment
    # carries its own trailing space to keep the sentences apart.
    logger.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
def get_props_pdb(pdb_file):
    r"""
    Loads a PDB file and converts it to the property mapping returned by
    ``structure2tensor``.

    Atom-level properties come from every atom in the file; residue-level
    properties come from the subset of atoms whose name is ``"CA"``.

    Parameters:
    - pdb_file: PDB file handed to ``Structure.read_pdb``.

    Returns:
    - The result of ``structure2tensor`` (callers in this file index it with
      the keys ``"atom_types"`` and ``"atom_coordinates"``).
    """
    full_structure = Structure.read_pdb(pdb_file)

    # Boolean mask selecting atoms named "CA"; one such atom per residue is
    # presumably expected by structure2tensor -- TODO confirm.
    is_calpha = np.isin(full_structure.atom_name, ["CA"])
    calpha_only = full_structure[is_calpha].copy()

    return structure2tensor(
        atom_coordinates=full_structure.coord,
        atom_types=full_structure.atom_name,
        element_types=full_structure.element,
        residue_coordinates=calpha_only.coord,
        residue_types=calpha_only.res_name,
        residue_ids=calpha_only.res_id,
    )
def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
    r"""
    Builds the heterogeneous ligand/receptor graph that is fed to the model.

    For each of the two structures a node store is filled with the atom types
    (``x``) and atom coordinates (``pos``) from ``get_props_pdb``, and a
    self-relation edge store is filled with k-nearest-neighbour edges computed
    by ``knn_graph`` on those coordinates.

    Parameters:
    - pdb_1: PDB file of the ligand.
    - pdb_2: PDB file of the receptor.
    - k (int): Number of neighbours passed to ``knn_graph``.
    - device (torch.device): Device the finished graph is moved to.

    Returns:
    - HeteroData: Graph with ``"ligand"`` and ``"receptor"`` node types.

    Raises:
    - ImportError: If torch-cluster (which provides ``knn_graph``) could not
      be imported when this module was loaded.
    """
    # The module-level import of knn_graph is optional; fail with a clear
    # message here instead of a NameError on the first knn_graph call.
    if not torch_cluster_installed:
        raise ImportError(
            "create_graph requires torch-cluster (knn_graph), which is not installed."
        )

    data = HeteroData()
    for node_type, pdb_path in (("ligand", pdb_1), ("receptor", pdb_2)):
        props = get_props_pdb(pdb_path)
        data[node_type].x = props["atom_types"]
        data[node_type].pos = props["atom_coordinates"]
        # Edges connect each atom to its k nearest atoms of the same molecule.
        data[node_type, node_type].edge_index = knn_graph(data[node_type].pos, k=k)

    # Edges are computed first; the whole graph is moved to the device after.
    return data.to(device)
def update_pdb_coordinates_from_tensor(
    input_filename, output_filename, coordinates_tensor
):
    r"""
    Updates atom coordinates in a PDB file with new transformed coordinates provided in a tensor.

    Only each atom's ``coord`` is replaced; every other atom field (occupancy,
    B-factor, altloc, serial number, ...) is left as parsed, because assigning
    ``coord`` does not touch them.

    Coordinates are matched to atoms positionally, in Bio.PDB's
    model -> chain -> residue -> atom iteration order. If the number of atoms
    and the number of coordinate rows differ, only the common prefix is
    updated and a warning is logged.

    Parameters:
    - input_filename (str): Path to the original PDB file.
    - output_filename (str): Path to the new PDB file to save updated coordinates.
    - coordinates_tensor (torch.Tensor): Tensor of shape (1, N, 3) with transformed coordinates.

    Returns:
    - str: ``output_filename``.
    """
    # Drop the leading batch dimension and convert to nested Python lists,
    # one [x, y, z] triple per atom.
    new_coordinates = coordinates_tensor.squeeze(0).tolist()

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure", input_filename)

    # Materialise the flattened atom sequence so its length can be checked.
    atoms = [
        atom
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
    ]

    # zip() below stops at the shorter sequence; surface a mismatch instead of
    # silently leaving some atoms at their original positions.
    if len(atoms) != len(new_coordinates):
        logger.warning(
            f"Atom count mismatch for {input_filename}: structure has "
            f"{len(atoms)} atoms but {len(new_coordinates)} coordinates were "
            "provided; only the common prefix is updated."
        )

    for atom, (new_x, new_y, new_z) in zip(atoms, new_coordinates):
        atom.coord = np.array([new_x, new_y, new_z])

    # Save the updated structure to a new PDB file
    io = PDBIO()
    io.set_structure(structure)
    io.save(output_filename)

    return output_filename
def merge_pdb_files(file1, file2, output_file):
    r"""
    Merges two PDB files by concatenating them into a single file.

    The terminating ``END`` record of the first file (together with any blank
    lines that follow it) is dropped, so the merged file is terminated only by
    whatever the second file ends with. If the first file does not end with an
    ``END`` record, none of its records are removed.

    Parameters:
    - file1 (str): Path to the first PDB file (e.g., receptor).
    - file2 (str): Path to the second PDB file (e.g., ligand).
    - output_file (str): Path to the output file where the merged structure will be saved.

    Returns:
    - str: ``output_file``.
    """
    # Read both inputs before opening the output, so an output path that
    # equals one of the inputs is not truncated before it has been read.
    with open(file1, "r") as f1:
        first_lines = f1.readlines()
    with open(file2, "r") as f2:
        second_content = f2.read()

    # Trailing blank lines would otherwise hide the END record.
    while first_lines and not first_lines[-1].strip():
        first_lines.pop()

    # Compare the whole record name so that "END" is removed but a record
    # such as "ENDMDL" (or a final ATOM line) is kept.
    if first_lines and first_lines[-1].split()[:1] == ["END"]:
        first_lines.pop()

    # Keep the second file from being glued onto an unterminated last line.
    if first_lines and not first_lines[-1].endswith("\n"):
        first_lines[-1] += "\n"

    with open(output_file, "w") as outfile:
        outfile.writelines(first_lines)
        outfile.write(second_content)

    print(f"Merged PDB saved to {output_file}")
    return output_file
def predict(
    input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2
):
    r"""
    Runs the model on two input structures and writes the predicted complex.

    Parameters:
    - input_seq_1, input_msa_1: Sequence and MSA of protein 1 (accepted for
      interface compatibility; not used by this function).
    - input_protein_1: PDB file of protein 1; used as the ligand.
    - input_seq_2, input_msa_2: Sequence and MSA of protein 2 (accepted for
      interface compatibility; not used by this function).
    - input_protein_2: PDB file of protein 2; used as the receptor.

    Returns:
    - tuple: ``(out_pdb, metrics_json, run_time)`` where ``out_pdb`` is the
      path of the merged PDB file (``"output.pdb"``), ``metrics_json`` is a
      JSON string of metrics, and ``run_time`` is the elapsed time in seconds.
    """
    # Monotonic clock: elapsed time is unaffected by system clock changes.
    start_time = time.perf_counter()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    data = create_graph(input_protein_1, input_protein_2, k=10, device=device)
    logger.info("Created graph data")

    # map_location lets a checkpoint saved on a GPU be deserialized on a
    # CPU-only host, matching the CPU fallback chosen above.
    # NOTE(review): assumes PinderLitModule is a Lightning module whose
    # load_from_checkpoint accepts `map_location` -- confirm.
    model = PinderLitModule.load_from_checkpoint(
        "./checkpoints/epoch_010.ckpt", map_location=device
    )
    model = model.to(device)
    model.eval()
    logger.info("Loaded model")

    with torch.no_grad():
        receptor_coords, ligand_coords = model(data)

    # Write each predicted pose back onto a copy of its input structure.
    file1 = update_pdb_coordinates_from_tensor(
        input_protein_1, "holo_ligand.pdb", ligand_coords
    )
    file2 = update_pdb_coordinates_from_tensor(
        input_protein_2, "holo_receptor.pdb", receptor_coords
    )
    out_pdb = merge_pdb_files(file1, file2, "output.pdb")

    # return an output pdb file with the protein and two chains A and B.
    # also return a JSON with any metrics you want to report
    # NOTE: these metric values are fixed constants, not computed from the
    # prediction.
    metrics = {"mean_plddt": 80, "binding_affinity": 2}

    run_time = time.perf_counter() - start_time
    return out_pdb, json.dumps(metrics), run_time
# Gradio user interface: two input columns (one per protein), a run button,
# an example row, and the outputs that `predict` fills in.
with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("EquiMPNN MOdel")

    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            # This input is protein 1; the label previously said "Protein 2".
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # define any options here
    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    # One example row; values are given in the order of the component list
    # that follows (sequence 1, structure 1, sequence 2, structure 2).
    gr.Examples(
        [
            [
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_A.pdb",
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_B.pdb",
            ],
        ],
        [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
    )

    # Representations handed to the Molecule3D viewer: a cartoon and a
    # side-chain stick entry for each of chains A and B.
    reps = [
        {
            "model": 0,
            "style": "cartoon",
            "chain": "A",
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "style": "cartoon",
            "chain": "B",
            "color": "greenCarbon",
        },
        {
            "model": 0,
            "chain": "A",
            "style": "stick",
            "sidechain": True,
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "chain": "B",
            "style": "stick",
            "sidechain": True,
            "color": "greenCarbon",
        },
    ]

    # outputs
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    # The input order must match the parameter order of `predict`; the output
    # order must match the tuple it returns.
    btn.click(
        predict,
        inputs=[
            input_seq_1,
            input_msa_1,
            input_protein_1,
            input_seq_2,
            input_msa_2,
            input_protein_2,
        ],
        outputs=[out, metrics, run_time],
    )

app.launch()