# Spaces: Sleeping  (status banner captured with the page; not part of the program)
import os | |
import gradio as gr | |
from gradio_molecule3d import Molecule3D | |
import spaces | |
import subprocess | |
import glob | |
# Directory to store cached outputs
CACHE_DIR = "gradio_cached_examples"

# Representation settings handed to the Molecule3D viewer via its `reps`
# argument: a single entry drawn in "stick" style with "whiteCarbon" coloring.
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "stick",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": False,
    }
]

# Ensure the cache directory exists
os.makedirs(CACHE_DIR, exist_ok=True)
# Example FASTA inputs offered in the UI; each one has a precomputed output.
# Plain string literals: none of these paths interpolates anything, so the
# `f` prefix was dropped.
example_fasta_files = [
    "cache_examples/boltz_0.fasta",
    "cache_examples/Armadillo_6.fasta",
    "cache_examples/Covid_3.fasta",
    "cache_examples/Malaria_2.fasta",
    "cache_examples/MITOCHONDRIAL_9.fasta",
    "cache_examples/Monkeypox_4.fasta",
    "cache_examples/Plasmodium_1.fasta",
    "cache_examples/PROTOCADHERIN_8.fasta",
    "cache_examples/Vault_5.fasta",
    "cache_examples/Zipper_7.fasta",
]
# Matching `.pdb` files in `CACHE_DIR`: for every example input, take its base
# name, swap ".fasta" for ".pdb", and place it under the cache directory.
example_outputs = [
    os.path.join(CACHE_DIR, os.path.basename(fasta_file).replace(".fasta", ".pdb"))
    for fasta_file in example_fasta_files
]
# must load cached outputs
def load_cached_example_outputs(fasta_file: str, cache_dir=None) -> str:
    """Return the path of the cached `.pdb` that corresponds to a FASTA file.

    The cached file name is the base name of ``fasta_file`` with ".fasta"
    replaced by ".pdb"; only the base name is used, so the directory part of
    ``fasta_file`` is irrelevant.

    Args:
        fasta_file: Path to (or name of) the example FASTA file.
        cache_dir: Directory to look in. Defaults to the module-level
            ``CACHE_DIR`` when omitted, which is the original behavior.

    Returns:
        The path of the existing cached `.pdb` file.

    Raises:
        FileNotFoundError: If no cached `.pdb` exists for ``fasta_file``.
    """
    if cache_dir is None:
        # Resolved at call time so the module constant stays the default.
        cache_dir = CACHE_DIR

    # Find the corresponding `.pdb` file
    pdb_file = os.path.basename(fasta_file).replace(".fasta", ".pdb")
    cached_pdb_path = os.path.join(cache_dir, pdb_file)

    if not os.path.exists(cached_pdb_path):
        raise FileNotFoundError(f"Cached output not found for {pdb_file}")
    return cached_pdb_path
# handle example click
def on_example_click(fasta_file: str) -> str:
    """Return the cached output path for a selected example.

    Thin wrapper around ``load_cached_example_outputs``; any exception that
    function raises propagates unchanged.

    Args:
        fasta_file: Path of the example FASTA file that was selected.

    Returns:
        Whatever ``load_cached_example_outputs(fasta_file)`` returns.
    """
    return load_cached_example_outputs(fasta_file)
# run predictions
# @spaces.GPU(duration=120)
def predict(data, out_dir, cache="~/.boltz", checkpoint=None, devices=1,
            accelerator="gpu", recycling_steps=3, sampling_steps=50,
            diffusion_samples=1, output_format="pdb", num_workers=2,
            override=False):
    """Run ``boltz predict`` as a subprocess and report whether it succeeded.

    All arguments are echoed to stdout, then forwarded to the command line:
    each option is passed as ``--<name> <value>`` (numeric values converted
    with ``str``), ``--checkpoint`` is added only when ``checkpoint`` is
    truthy, ``--override`` only when ``override`` is truthy, and ``data`` is
    appended last as the positional input path.

    Args:
        data: Path to the input file.
        out_dir: Value for ``--out_dir``.
        cache: Value for ``--cache``.
        checkpoint: Optional value for ``--checkpoint``.
        devices: Value for ``--devices``.
        accelerator: Value for ``--accelerator``.
        recycling_steps: Value for ``--recycling_steps``.
        sampling_steps: Value for ``--sampling_steps``.
        diffusion_samples: Value for ``--diffusion_samples``.
        output_format: Value for ``--output_format``.
        num_workers: Value for ``--num_workers``.
        override: When truthy, pass the ``--override`` flag.

    Returns:
        True if the command exited with status 0; False if it exited with a
        non-zero status or could not be started at all.
    """
    arguments = {
        "data": data,
        "out_dir": out_dir,
        "cache": cache,
        "checkpoint": checkpoint,
        "devices": devices,
        "accelerator": accelerator,
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "output_format": output_format,
        "num_workers": num_workers,
        "override": override,
    }
    print("Arguments passed to `predict` function:")
    for name, value in arguments.items():
        print(f" {name}: {value}")

    # Construct the base command
    command = [
        "boltz", "predict",
        "--out_dir", out_dir,
        "--cache", cache,
        "--devices", str(devices),
        "--accelerator", accelerator,
        "--recycling_steps", str(recycling_steps),
        "--sampling_steps", str(sampling_steps),
        "--diffusion_samples", str(diffusion_samples),
        "--output_format", output_format,
        "--num_workers", str(num_workers),
    ]

    # Add optional arguments if provided
    if checkpoint:
        command.extend(["--checkpoint", checkpoint])
    if override:
        command.append("--override")

    # Add the data argument (path to the input file)
    command.append(data)

    # Run the command using subprocess. An argument list (no shell) is used,
    # so the file path is never interpreted by a shell.
    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as err:
        # The executable could not be started (e.g. `boltz` is not on PATH).
        # Report it the same way as any other failed prediction.
        print("Prediction failed :(")
        print("Error:", err)
        return False

    if result.returncode == 0:
        print("Prediction completed successfully...!")
        print(f"Output saved to: {out_dir}")
        return True

    print("Prediction failed :(")
    print("Error:", result.stderr)
    return False
# @spaces.GPU(duration=60)
def run_prediction(input_file, cache, accelerator, sampling_steps,
                   diffusion_samples, output_format, checkpoint="./ckpt/boltz1.ckpt"):
    """Run a prediction for an uploaded FASTA file and return the resulting `.pdb`.

    Args:
        input_file: The uploaded file, either a path string or an object
            exposing the path as ``.name``. ``None`` (nothing uploaded) is
            tolerated.
        cache: Forwarded to ``predict`` as ``cache``.
        accelerator: Forwarded to ``predict`` as ``accelerator``.
        sampling_steps: Forwarded to ``predict`` as ``sampling_steps``.
        diffusion_samples: Forwarded to ``predict`` as ``diffusion_samples``.
        output_format: Forwarded to ``predict`` as ``output_format``.
        checkpoint: Forwarded to ``predict`` as ``checkpoint``.

    Returns:
        Path of the most recently modified matching `.pdb` file, or ``None``
        when there is no input, the prediction reported failure, or no
        `.pdb` file was found.
    """
    if input_file is None:
        print("No input file provided.")
        return None

    # Accept both a plain path and a file wrapper that carries the path in `.name`.
    data = getattr(input_file, "name", input_file)
    print("the data : ", data)

    # Previously `out_dir` was referenced below without ever being defined
    # (NameError); it is the same directory that is handed to `predict`.
    out_dir = "./"

    # Call your predict function
    status = predict(
        data=data,
        out_dir=out_dir,
        cache=cache,
        accelerator=accelerator,
        sampling_steps=sampling_steps,
        diffusion_samples=diffusion_samples,
        output_format=output_format,
        checkpoint=checkpoint,
    )
    # Abort only on an explicit False. A `None` return (no status reported)
    # falls through to the file search, so a leftover structure from an
    # earlier run is not shown after a reported failure.
    if status is False:
        return None

    # Look first in the results folder named after this input, then fall back
    # to every results folder.
    # NOTE(review): assumes boltz writes to `boltz_results_<input stem>/` — confirm
    # against the installed boltz version; the fallback keeps the old behavior.
    stem = os.path.splitext(os.path.basename(data))[0]
    search_patterns = [
        os.path.join(out_dir, f"boltz_results_{glob.escape(stem)}", "predictions", "**", "*.pdb"),
        os.path.join(out_dir, "boltz_results*/predictions/**/*.pdb"),
    ]
    pdb_files = []
    for search_path in search_patterns:
        pdb_files = glob.glob(search_path, recursive=True)  # Enable recursive search
        if pdb_files:
            break

    if not pdb_files:
        print("No .pdb files found in the predictions folder.")
        return None

    # Return the latest .pdb file based on modification time
    return max(pdb_files, key=os.path.getmtime)
def _predict_from_ui(input_file, accelerator, sampling_steps, diffusion_samples):
    """Adapt the four UI inputs to the argument list of `run_prediction`.

    The button supplies only the file, accelerator, sampling steps and
    diffusion samples, while `run_prediction` also requires `cache` and
    `output_format` and expects `cache` in second position. Wiring the button
    directly to `run_prediction` therefore misbinds the values and fails with
    missing arguments.
    """
    return run_prediction(
        input_file,
        "~/.boltz",  # cache: same value as the default of `predict`
        accelerator,
        sampling_steps,
        diffusion_samples,
        "pdb",  # output_format: `run_prediction` looks for `.pdb` files
    )


# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation was lost — confirm the layout against the running app.
with gr.Blocks() as demo:
    gr.Markdown("# 🔬 Boltz-1: Democratizing Biomolecular Interaction Modeling 🧬")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.File(label="Upload a .fasta File", file_types=[".fasta"])
            with gr.Accordion("Advanced Settings", open=False):
                accelerator = gr.Radio(choices=["gpu", "cpu"], value="gpu", label="Accelerator")
                sampling_steps = gr.Slider(minimum=1, maximum=500, value=50, step=1, label="Sampling Steps")
                diffusion_samples = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Diffusion Samples")
            btn = gr.Button("Predict")
        with gr.Column(scale=3):
            out = Molecule3D(label="Generated Molecule", reps=reps)

    btn.click(
        _predict_from_ui,
        inputs=[inp, accelerator, sampling_steps, diffusion_samples],
        outputs=out,
    )

    # Examples resolve to precomputed outputs instead of running a prediction.
    gr.Examples(
        examples=[[fasta_file] for fasta_file in example_fasta_files],
        inputs=[inp],
        outputs=out,
        fn=on_example_click,
        cache_examples=True,
    )
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)