import os
import shutil
import tempfile

# Grid edge length and spacing shared by the gnina model setup and the molgrid
# GridMaker in ``predict``; the two must agree. (Units presumably Angstrom.)
_GRID_DIMENSION = 23.5
_GRID_RESOLUTION = 0.5


def load_file(fpath: str) -> str:
    """
    Load file content.

    Parameters
    ----------
    fpath: str
        File path

    Returns
    -------
    str
        File content, decoded as UTF-8 regardless of the platform locale
    """
    with open(fpath, "r", encoding="utf-8") as f:
        return f.read()


def load_html(html_file: str) -> str:
    """
    Load an HTML template from the ``html`` directory.

    The directory is resolved relative to the current working directory.
    """
    return load_file(os.path.join("html", html_file))


def load_md(md_file: str) -> str:
    """
    Load a Markdown snippet from the ``md`` directory.

    The directory is resolved relative to the current working directory.
    """
    return load_file(os.path.join("md", md_file))


def _uploaded_path(file_obj, what: str) -> str:
    """
    Return the filesystem path of an uploaded file.

    Parameters
    ----------
    file_obj: _TemporaryFileWrapper, str or os.PathLike
        Gradio file object (anything exposing ``.name``) or a plain path
    what: str
        Human-readable description used in the error message

    Returns
    -------
    str
        Path of the file on disk

    Raises
    ------
    ValueError
        If ``file_obj`` is None (presumably what Gradio passes when nothing
        has been uploaded)
    """
    if file_obj is None:
        raise ValueError(f"No {what} file provided.")
    # Checked before ``.name``: for ``pathlib.Path``, ``.name`` is only the basename.
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def load_protein_from_file(protein_file) -> str:
    """
    Load protein from file.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper or str
        Gradio file object (or path)

    Returns
    -------
    str
        Protein PDB file content

    Raises
    ------
    ValueError
        If no file was provided
    """
    return load_file(_uploaded_path(protein_file, "protein"))


def load_ligand_from_file(ligand_file) -> str:
    """
    Load ligand from file.

    Parameters
    ----------
    ligand_file: _TemporaryFileWrapper or str
        Gradio file object (or path)

    Returns
    -------
    str
        Ligand SDF file content

    Raises
    ------
    ValueError
        If no file was provided
    """
    return load_file(_uploaded_path(ligand_file, "ligand"))


def _wrap_html(html: str) -> str:
    """Embed an HTML fragment into the ``wrapper.html`` template."""
    return load_html("wrapper.html").replace("%%%HTML%%%", html)


# NOTE(review): the viewer functions below substitute the uploaded file text
# verbatim into the HTML templates. Confirm the templates embed %%%PDB%%% /
# %%%SDF%%% so that file content cannot break out of its enclosing
# string/script (e.g. a literal "</script>" in the upload).


def protein_html_from_file(protein_file) -> str:
    """
    Wrap 3Dmol.js code around protein PDB file.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper or str
        Gradio file object (or path)

    Returns
    -------
    str
        3Dmol.js HTML code for displaying a PDB file
    """
    protein = load_protein_from_file(protein_file)
    html = load_html("protein.html").replace("%%%PDB%%%", protein)
    return _wrap_html(html)


def ligand_html_from_file(ligand_file) -> str:
    """
    Wrap 3Dmol.js code around ligand SDF file.

    Parameters
    ----------
    ligand_file: _TemporaryFileWrapper or str
        Gradio file object (or path)

    Returns
    -------
    str
        3Dmol.js HTML code for displaying a SDF file
    """
    ligand = load_ligand_from_file(ligand_file)
    html = load_html("ligand.html").replace("%%%SDF%%%", ligand)
    return _wrap_html(html)


def protein_ligand_html_from_file(protein_file, ligand_file) -> str:
    """
    Wrap 3Dmol.js code around a protein PDB file and a ligand SDF file.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper or str
        Gradio file object (or path) for the protein
    ligand_file: _TemporaryFileWrapper or str
        Gradio file object (or path) for the ligand

    Returns
    -------
    str
        3Dmol.js HTML code for displaying the protein-ligand complex
    """
    protein = load_protein_from_file(protein_file)
    ligand = load_ligand_from_file(ligand_file)
    html = load_html("pl.html").replace("%%%PDB%%%", protein)
    html = html.replace("%%%SDF%%%", ligand)
    return _wrap_html(html)


def _write_types_file(protein_path: str, ligand_path: str, workdir: str) -> str:
    """
    Stage a protein-ligand pair in *workdir* and write a molgrid types file.

    The types file is a single line containing the two coordinate-file paths
    separated by a space. Because that line is whitespace-delimited, each input
    is first copied into *workdir* under a name whose whitespace is replaced by
    underscores. The original extension is kept, since the file format is
    presumably inferred from it.

    Parameters
    ----------
    protein_path: str
        Path of the protein file
    ligand_path: str
        Path of the ligand file
    workdir: str
        Private scratch directory; assumed to contain no whitespace itself

    Returns
    -------
    str
        Path of the written types file
    """
    staged = []
    for role, src in (("protein", protein_path), ("ligand", ligand_path)):
        safe_name = "".join("_" if ch.isspace() else ch for ch in os.path.basename(src))
        # Prefix with the role so the protein and ligand copies cannot collide.
        dst = os.path.join(workdir, f"{role}_{safe_name}")
        shutil.copyfile(src, dst)
        staged.append(dst)

    types_file = os.path.join(workdir, "data.in")
    with open(types_file, "w", encoding="utf-8") as f:
        f.write(" ".join(staged))
    return types_file


def predict(protein_file, ligand_file, cnn: str = "default"):
    """
    Run gnina-torch on protein-ligand complex.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper or str
        Gradio file object (or path) for the protein
    ligand_file: _TemporaryFileWrapper or str
        Gradio file object (or path) for the ligand
    cnn: str
        CNN model to use

    Returns
    -------
    pandas.DataFrame
        Single-row DataFrame with columns CNNscore, CNNaffinity and
        CNNvariance, rounded to 6 decimal places

    Raises
    ------
    ValueError
        If either file has not been provided
    """
    # Imported inside the function (as before) so that importing this module
    # does not require these heavy dependencies.
    import molgrid
    from gninatorch import gnina, dataloaders
    import torch
    import pandas as pd

    protein_path = _uploaded_path(protein_file, "protein")
    ligand_path = _uploaded_path(ligand_file, "ligand")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Positional order (dimension, resolution) unchanged from the original call.
    model, _ensemble = gnina.setup_gnina_model(cnn, _GRID_DIMENSION, _GRID_RESOLUTION)
    model.eval()
    model.to(device)

    example_provider = molgrid.ExampleProvider(
        data_root="",
        balanced=False,
        shuffle=False,
        default_batch_size=1,
        iteration_scheme=molgrid.IterationScheme.SmallEpoch,
    )
    grid_maker = molgrid.GridMaker(
        resolution=_GRID_RESOLUTION, dimension=_GRID_DIMENSION
    )

    # Use a private per-request directory. A fixed "data.in" in the working
    # directory would be overwritten by concurrent requests and never cleaned up.
    # FIXME: Populate the provider without a types file
    # [might require light gnina-torch refactoring].
    with tempfile.TemporaryDirectory() as workdir:
        types_file = _write_types_file(protein_path, ligand_path, workdir)

        print("Populating example provider... ", end="")
        example_provider.populate(types_file)
        print("done")

        # TODO: Allow average over different rotations
        loader = dataloaders.GriddedExamplesLoader(
            example_provider=example_provider,
            grid_maker=grid_maker,
            random_translation=0.0,  # No random translations for inference
            random_rotation=False,  # No random rotations for inference
            grids_only=True,
            device=device,
        )

        # Kept inside the temporary directory: the staged coordinate files may
        # only be read lazily when the batch is produced.
        print("Loading and gridding data... ", end="")
        batch = next(loader)
        print("done")

    print("Predicting... ", end="")
    with torch.no_grad():
        log_pose, affinity, affinity_var = model(batch)
    print("done")

    return pd.DataFrame(
        {
            # Exponentiate the last column of the model's log pose output.
            "CNNscore": [torch.exp(log_pose[:, -1]).item()],
            "CNNaffinity": [affinity.item()],
            "CNNvariance": [affinity_var.item()],
        }
    ).round(6)


if __name__ == "__main__":
    import gradio as gr

    demo = gr.Blocks()

    with demo:
        gr.Markdown(load_md("intro.md"))

        gr.Markdown(load_md("input.md"))

        with gr.Row():
            with gr.Box():
                pfile = gr.File(file_count="single", label="Protein file (PDB)")
                gr.Examples(["mols/1cbr_protein.pdb"], inputs=pfile)

                pbtn = gr.Button("View Protein")
                pbtn.click(fn=protein_html_from_file, inputs=[pfile], outputs=gr.HTML())

            with gr.Box():
                lfile = gr.File(file_count="single", label="Ligand file (SDF)")
                gr.Examples(["mols/1cbr_ligand.sdf"], inputs=lfile)

                lbtn = gr.Button("View Ligand")
                lbtn.click(fn=ligand_html_from_file, inputs=[lfile], outputs=gr.HTML())

        with gr.Box():
            with gr.Column():
                # TODO: Automatically display complex when both files are uploaded
                plbtn = gr.Button("View Protein-Ligand Complex")
                plbtn.click(
                    fn=protein_ligand_html_from_file,
                    inputs=[pfile, lfile],
                    outputs=gr.HTML(),
                )

        gr.Markdown(load_md("scoring.md"))
        with gr.Row():
            df = gr.Dataframe()

            with gr.Column():
                dd = gr.Dropdown(
                    choices=[
                        "default",
                        "redock_default2018_ensemble",
                        "general_default2018_ensemble",
                        "crossdock_default2018_ensemble",
                    ],
                    value="default",
                    label="CNN model",
                )

                with gr.Row():
                    btn = gr.Button("Score!")
                    btn.click(fn=predict, inputs=[pfile, lfile, dd], outputs=df)

        gr.Markdown(
            load_md("acknowledgements.md"),
        )
        gr.Markdown(load_md("references.md"))

    demo.launch()