import gradio as gr import os def load_html(html_file: str): with open(os.path.join("html", html_file), "r") as f: return f.read() def load_protein_from_file(protein_file) -> str: """ Parameters ---------- protein_file: _TemporaryFileWrapper GradIO file object Returns ------- str Protein PDB file content """ with open(protein_file.name, "r") as f: return f.read() def load_ligand_from_file(ligand_file): with open(ligand_file.name, "r") as f: return f.read() def protein_html_from_file(protein_file): protein = load_protein_from_file(protein_file) protein_html = load_html("protein.html") html = protein_html.replace("%%%PDB%%%", protein) wrapper = load_html("wrapper.html") return wrapper.replace("%%%HTML%%%", html) def ligand_html_from_file(ligand_file): ligand = load_ligand_from_file(ligand_file) ligand_html = load_html("ligand.html") html = ligand_html.replace("%%%SDF%%%", ligand) wrapper = load_html("wrapper.html") return wrapper.replace("%%%HTML%%%", html) def protein_ligand_html_from_file(protein_file, ligand_file): protein = load_protein_from_file(protein_file) ligand = load_ligand_from_file(ligand_file) protein_ligand_html = load_html("pl.html") html = protein_ligand_html.replace("%%%PDB%%%", protein) html = html.replace("%%%SDF%%%", ligand) wrapper = load_html("wrapper.html") return wrapper.replace("%%%HTML%%%", html) def predict(protein_file, ligand_file, cnn="default"): import molgrid from gninatorch import gnina, dataloaders import torch import pandas as pd device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device) model, ensemble = gnina.setup_gnina_model(cnn, 23.5, 0.5) model.eval() model.to(device) example_provider = molgrid.ExampleProvider( data_root="", balanced=False, shuffle=False, default_batch_size=1, iteration_scheme=molgrid.IterationScheme.SmallEpoch, ) with open("data.in", "w") as f: f.write(protein_file.name) f.write(" ") f.write(ligand_file.name) print("Populating example provider... ", end="") example_provider.populate("data.in") print("done") grid_maker = molgrid.GridMaker(resolution=0.5, dimension=23.5) # TODO: Allow average over different rotations loader = dataloaders.GriddedExamplesLoader( example_provider=example_provider, grid_maker=grid_maker, random_translation=0.0, # No random translations for inference random_rotation=False, # No random rotations for inference grids_only=True, device=device, ) print("Loading and gridding data... ", end="") batch = next(loader) print("done") print("Predicting... ", end="") with torch.no_grad(): log_pose, affinity, affinity_var = model(batch) print("done") return pd.DataFrame( { "CNNscore": [torch.exp(log_pose[:, -1]).item()], "CNNaffinity": [affinity.item()], "CNNvariance": [affinity_var.item()], } ) demo = gr.Blocks() with demo: gr.Markdown("# Protein and Ligand") with gr.Row(): with gr.Box(): pfile = gr.File(file_count="single") pbtn = gr.Button("View") protein = gr.HTML() pbtn.click(fn=protein_html_from_file, inputs=[pfile], outputs=protein) with gr.Box(): lfile = gr.File(file_count="single") lbtn = gr.Button("View") ligand = gr.HTML() lbtn.click(fn=ligand_html_from_file, inputs=[lfile], outputs=ligand) gr.Markdown("# Protein-Ligand Complex") with gr.Row(): plcomplex = gr.HTML() # TODO: Automatically display complex when both files are uploaded plbtn = gr.Button("View") plbtn.click( fn=protein_ligand_html_from_file, inputs=[pfile, lfile], outputs=plcomplex ) gr.Markdown("# Gnina-Torch") with gr.Row(): df = gr.Dataframe() btn = gr.Button("Score!") btn.click(fn=predict, inputs=[pfile, lfile], outputs=df) demo.launch()