# boltz / app.py
# Last commit cf729de by jadechoghari: "add changes"
# (page links: raw | history | blame; file size 6.16 kB)
import os
import gradio as gr
from gradio_molecule3d import Molecule3D
import spaces
import subprocess
import glob
# Directory holding the precomputed `.pdb` outputs for the bundled examples.
CACHE_DIR = "gradio_cached_examples"

# Representation settings handed to the Molecule3D viewer via its `reps` argument.
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "stick",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": False,
    }
]

# Ensure the cache directory exists (this runs at import time).
os.makedirs(CACHE_DIR, exist_ok=True)

# Example FASTA inputs offered in the UI's Examples table.
# Plain literals: these paths contain no interpolation, so f-strings are unnecessary.
example_fasta_files = [
    "cache_examples/boltz_0.fasta",
    "cache_examples/Armadillo_6.fasta",
    "cache_examples/Covid_3.fasta",
    "cache_examples/Malaria_2.fasta",
    "cache_examples/MITOCHONDRIAL_9.fasta",
    "cache_examples/Monkeypox_4.fasta",
    "cache_examples/Plasmodium_1.fasta",
    "cache_examples/PROTOCADHERIN_8.fasta",
    "cache_examples/Vault_5.fasta",
    "cache_examples/Zipper_7.fasta",
]

# Expected precomputed output for each example: same base name with a `.pdb`
# extension, stored in CACHE_DIR. splitext swaps only the final extension,
# unlike str.replace, which would rewrite any ".fasta" inside the name.
example_outputs = [
    os.path.join(CACHE_DIR, os.path.splitext(os.path.basename(fasta_file))[0] + ".pdb")
    for fasta_file in example_fasta_files
]
# Resolve an example FASTA file to its precomputed structure.
def load_cached_example_outputs(fasta_file: str, cache_dir=None) -> str:
    """Return the path of the precomputed `.pdb` for an example FASTA file.

    The cached structure is looked up as ``<cache_dir>/<fasta stem>.pdb``.

    Args:
        fasta_file: Path to the example FASTA file. An object exposing the
            path as ``.name`` (e.g. a tempfile wrapper) is also accepted.
        cache_dir: Directory holding the cached `.pdb` files. Defaults to the
            module-level ``CACHE_DIR`` (resolved at call time).

    Returns:
        Path to the cached `.pdb` file.

    Raises:
        FileNotFoundError: If no cached `.pdb` exists for this example.
    """
    if cache_dir is None:
        cache_dir = CACHE_DIR
    if isinstance(fasta_file, (str, os.PathLike)):
        fasta_path = fasta_file
    else:
        fasta_path = fasta_file.name
    # splitext swaps only the final extension; str.replace(".fasta", ".pdb")
    # would also rewrite a ".fasta" occurring earlier in the file name.
    stem = os.path.splitext(os.path.basename(fasta_path))[0]
    pdb_file = stem + ".pdb"
    cached_pdb_path = os.path.join(cache_dir, pdb_file)
    if not os.path.exists(cached_pdb_path):
        raise FileNotFoundError(f"Cached output not found for {pdb_file}")
    return cached_pdb_path
# Callback for the Gradio Examples table.
def on_example_click(fasta_file: str) -> str:
    """Map a clicked example FASTA file to its precomputed `.pdb` path.

    Thin delegate to ``load_cached_example_outputs``; any FileNotFoundError
    it raises for a missing cached structure propagates unchanged.
    """
    cached_pdb = load_cached_example_outputs(fasta_file)
    return cached_pdb
# run predictions
# @spaces.GPU(duration=120)
def predict(data, out_dir, cache="~/.boltz", checkpoint=None, devices=1,
            accelerator="gpu", recycling_steps=3, sampling_steps=50,
            diffusion_samples=1, output_format="pdb", num_workers=2,
            override=False):
    """Run ``boltz predict`` on a single input file in a subprocess.

    Args:
        data: Path to the input file; appended as the final CLI argument.
        out_dir: Forwarded to ``--out_dir``.
        cache: Forwarded to ``--cache`` after ``~`` expansion.
        checkpoint: Forwarded to ``--checkpoint`` only when truthy.
        devices: Forwarded to ``--devices``.
        accelerator: Forwarded to ``--accelerator``.
        recycling_steps: Forwarded to ``--recycling_steps``.
        sampling_steps: Forwarded to ``--sampling_steps``.
        diffusion_samples: Forwarded to ``--diffusion_samples``.
        output_format: Forwarded to ``--output_format``.
        num_workers: Forwarded to ``--num_workers``.
        override: When truthy, ``--override`` is added.

    Returns:
        ``True`` if the ``boltz`` process exited with status 0, else ``False``.

    Raises:
        FileNotFoundError: Raised by ``subprocess.run`` when the ``boltz``
            executable cannot be found.
    """
    arguments = {
        "data": data,
        "out_dir": out_dir,
        "cache": cache,
        "checkpoint": checkpoint,
        "devices": devices,
        "accelerator": accelerator,
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "output_format": output_format,
        "num_workers": num_workers,
        "override": override,
    }
    print("Arguments passed to `predict` function:")
    for name, value in arguments.items():
        print(f" {name}: {value}")

    command = [
        "boltz", "predict",
        "--out_dir", out_dir,
        # No shell is involved, so expand "~" here instead of passing it literally.
        "--cache", os.path.expanduser(cache),
        "--devices", str(devices),
        "--accelerator", accelerator,
        "--recycling_steps", str(recycling_steps),
        "--sampling_steps", str(sampling_steps),
        "--diffusion_samples", str(diffusion_samples),
        "--output_format", output_format,
        "--num_workers", str(num_workers),
    ]
    if checkpoint:
        command.extend(["--checkpoint", checkpoint])
    if override:
        command.append("--override")
    command.append(data)

    # Argument list with the default shell=False: paths are never shell-interpreted.
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode == 0:
        print("Prediction completed successfully...!")
        print(f"Output saved to: {out_dir}")
        return True
    print(f"Prediction failed :( (exit code {result.returncode})")
    print("Error:", result.stderr)
    return False
# @spaces.GPU(duration=60)
def run_prediction(input_file, cache="~/.boltz", accelerator="gpu",
                   sampling_steps=50, diffusion_samples=1,
                   output_format="pdb", checkpoint="./ckpt/boltz1.ckpt"):
    """Predict a structure for an uploaded FASTA file and return a `.pdb` path.

    Args:
        input_file: The uploaded file, given either as a filepath string or as
            an object exposing the path as ``.name``. ``None`` (nothing
            uploaded) returns ``None`` without running a prediction.
        cache: Forwarded to ``predict``.
        accelerator: Forwarded to ``predict``.
        sampling_steps: Forwarded to ``predict`` as an int (slider values may
            arrive as floats, which an integer CLI option would reject).
        diffusion_samples: Forwarded to ``predict`` as an int.
        output_format: Forwarded to ``predict``.
        checkpoint: Forwarded to ``predict``.

    Returns:
        Path of the most recently modified `.pdb` file matching
        ``./boltz_results*/predictions/**/*.pdb``, or ``None`` if none exists.
        The search is not tied to this particular run, so if the prediction
        failed an older result may be returned.
    """
    if input_file is None:
        print("No input file was provided.")
        return None
    # Newer Gradio passes a filepath string; older versions pass a tempfile
    # wrapper whose `.name` is the path. Accept either.
    if isinstance(input_file, (str, os.PathLike)):
        data = input_file
    else:
        data = input_file.name
    print("the data : ", data)

    # Defined locally: the search below previously referenced an undefined name.
    out_dir = "./"
    predict(
        data=data,
        out_dir=out_dir,
        cache=cache,
        accelerator=accelerator,
        sampling_steps=int(sampling_steps),
        diffusion_samples=int(diffusion_samples),
        output_format=output_format,
        checkpoint=checkpoint,
    )

    # Search recursively for .pdb files in the predictions folder(s).
    search_path = os.path.join(out_dir, "boltz_results*/predictions/**/*.pdb")
    pdb_files = glob.glob(search_path, recursive=True)
    if not pdb_files:
        print("No .pdb files found in the predictions folder.")
        return None
    # Pick the newest file by modification time.
    return max(pdb_files, key=os.path.getmtime)
with gr.Blocks() as demo:
    gr.Markdown("# 🔬 Boltz-1: Democratizing Biomolecular Interaction Modeling 🧬")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.File(label="Upload a .fasta File", file_types=[".fasta"])
            with gr.Accordion("Advanced Settings", open=False):
                accelerator = gr.Radio(choices=["gpu", "cpu"], value="gpu", label="Accelerator")
                sampling_steps = gr.Slider(minimum=1, maximum=500, value=50, step=1, label="Sampling Steps")
                diffusion_samples = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Diffusion Samples")
            btn = gr.Button("Predict")
        with gr.Column(scale=3):
            out = Molecule3D(label="Generated Molecule", reps=reps)

    def _predict_from_ui(input_file, accelerator_choice, steps, samples):
        """Route the four UI inputs to run_prediction by keyword.

        Gradio passes inputs positionally, and run_prediction's second
        parameter is `cache`, so wiring it to btn.click directly would shift
        every value into the wrong parameter and leave `output_format` unset.
        """
        return run_prediction(
            input_file,
            cache="~/.boltz",
            accelerator=accelerator_choice,
            sampling_steps=steps,
            diffusion_samples=samples,
            output_format="pdb",
        )

    btn.click(
        _predict_from_ui,
        inputs=[inp, accelerator, sampling_steps, diffusion_samples],
        outputs=out,
    )
    # With cache_examples=True, Gradio runs `fn` on every example ahead of time.
    gr.Examples(
        examples=[[fasta_file] for fasta_file in example_fasta_files],
        inputs=[inp],
        outputs=out,
        fn=on_example_click,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)