File size: 6,159 Bytes
d4e347d
cf729de
d4e347d
cf729de
 
 
 
 
 
 
 
 
 
d4e347d
cf729de
 
 
 
 
 
 
 
 
d4e347d
cf729de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4e347d
 
cf729de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4e347d
 
cf729de
 
d4e347d
 
cf729de
 
 
d4e347d
 
cf729de
 
 
 
 
 
d4e347d
 
cf729de
 
 
 
 
 
 
 
d4e347d
cf729de
d4e347d
cf729de
 
 
d4e347d
cf729de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4e347d
 
 
 
cf729de
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import gradio as gr
from gradio_molecule3d import Molecule3D
import spaces
import subprocess
import glob


# Folder that holds the precomputed `.pdb` outputs for the bundled examples
CACHE_DIR = "gradio_cached_examples"


# Representation settings passed to the Molecule3D viewer (`reps=` argument)
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "stick",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": False,
    },
]

# Create the cache folder up front; this is a no-op when it already exists
os.makedirs(CACHE_DIR, exist_ok=True)

# Example FASTA inputs offered in the UI
example_fasta_files = [
    "cache_examples/boltz_0.fasta",
    "cache_examples/Armadillo_6.fasta",
    "cache_examples/Covid_3.fasta",
    "cache_examples/Malaria_2.fasta",
    "cache_examples/MITOCHONDRIAL_9.fasta",
    "cache_examples/Monkeypox_4.fasta",
    "cache_examples/Plasmodium_1.fasta",
    "cache_examples/PROTOCADHERIN_8.fasta",
    "cache_examples/Vault_5.fasta",
    "cache_examples/Zipper_7.fasta",
]

# One `.pdb` path inside CACHE_DIR per example, in the same order as the
# FASTA list: same base name, with `.fasta` swapped for `.pdb`
example_outputs = [
    os.path.join(CACHE_DIR, pdb_name)
    for pdb_name in (
        os.path.basename(fasta_path).replace(".fasta", ".pdb")
        for fasta_path in example_fasta_files
    )
]

# must load cached outputs
def load_cached_example_outputs(fasta_file: str) -> str:
    """Return the cached `.pdb` path that corresponds to *fasta_file*.

    Only the base name of ``fasta_file`` is used: ``.fasta`` is replaced by
    ``.pdb`` and the result is looked up inside ``CACHE_DIR``.

    Args:
        fasta_file: Path (or bare file name) of the example FASTA file.

    Returns:
        The path of the matching `.pdb` file inside ``CACHE_DIR``.

    Raises:
        FileNotFoundError: If that `.pdb` file does not exist.
    """
    pdb_file = os.path.basename(fasta_file).replace(".fasta", ".pdb")
    candidate = os.path.join(CACHE_DIR, pdb_file)

    # Guard clause: fail loudly when the precomputed output is missing
    if not os.path.exists(candidate):
        raise FileNotFoundError(f"Cached output not found for {pdb_file}")
    return candidate

# handle example click
def on_example_click(fasta_file: str) -> str:
    """Resolve a selected example FASTA file to its precomputed `.pdb` path.

    Delegates to ``load_cached_example_outputs`` and lets any exception it
    raises propagate to the caller.
    """
    cached_pdb_path = load_cached_example_outputs(fasta_file)
    return cached_pdb_path

# run predictions
# @spaces.GPU(duration=120)
def predict(data, out_dir, cache="~/.boltz", checkpoint=None, devices=1,
            accelerator="gpu", recycling_steps=3, sampling_steps=50,
            diffusion_samples=1, output_format="pdb", num_workers=2,
            override=False):
    """Run the ``boltz predict`` command-line tool in a subprocess.

    Every argument is echoed to stdout first, then forwarded to the CLI. The
    outcome is only reported via ``print``; a non-zero exit code does not
    raise.

    Args:
        data: Path to the input file; passed as the final positional argument.
        out_dir: Value for ``--out_dir``.
        cache: Value for ``--cache``.
        checkpoint: Value for ``--checkpoint``; the flag is omitted when falsy.
        devices: Value for ``--devices``.
        accelerator: Value for ``--accelerator``.
        recycling_steps: Value for ``--recycling_steps``.
        sampling_steps: Value for ``--sampling_steps``.
        diffusion_samples: Value for ``--diffusion_samples``.
        output_format: Value for ``--output_format``.
        num_workers: Value for ``--num_workers``.
        override: When truthy, the bare ``--override`` flag is appended.

    Returns:
        None.
    """
    # Emit the whole argument summary with a single print call
    summary = [
        "Arguments passed to `predict` function:",
        f"  data: {data}",
        f"  out_dir: {out_dir}",
        f"  cache: {cache}",
        f"  checkpoint: {checkpoint}",
        f"  devices: {devices}",
        f"  accelerator: {accelerator}",
        f"  recycling_steps: {recycling_steps}",
        f"  sampling_steps: {sampling_steps}",
        f"  diffusion_samples: {diffusion_samples}",
        f"  output_format: {output_format}",
        f"  num_workers: {num_workers}",
        f"  override: {override}",
    ]
    print("\n".join(summary))

    # Options that are always present, in the order they appear on the CLI
    always_present = (
        ("--out_dir", out_dir),
        ("--cache", cache),
        ("--devices", str(devices)),
        ("--accelerator", accelerator),
        ("--recycling_steps", str(recycling_steps)),
        ("--sampling_steps", str(sampling_steps)),
        ("--diffusion_samples", str(diffusion_samples)),
        ("--output_format", output_format),
        ("--num_workers", str(num_workers)),
    )
    argv = ["boltz", "predict"]
    for flag, value in always_present:
        argv += [flag, value]

    # Optional pieces come after the fixed options
    if checkpoint:
        argv += ["--checkpoint", checkpoint]
    if override:
        argv.append("--override")

    # The input path is the last element of the command line
    argv.append(data)

    # Argument list with no shell involved; output is captured as text
    completed = subprocess.run(argv, capture_output=True, text=True)
    if completed.returncode != 0:
        print("Prediction failed :(")
        print("Error:", completed.stderr)
        return

    print("Prediction completed successfully...!")
    print(f"Output saved to: {out_dir}")

# @spaces.GPU(duration=60)
def run_prediction(input_file, cache, accelerator, sampling_steps,
                   diffusion_samples, output_format, checkpoint="./ckpt/boltz1.ckpt"):
    """Run a prediction for an uploaded file and return the newest `.pdb` path.

    Args:
        input_file: The uploaded input. Either an object exposing the file
            path as ``.name`` or a plain path string; ``None`` means nothing
            was uploaded.
        cache: Forwarded to ``predict`` as ``cache``.
        accelerator: Forwarded to ``predict`` as ``accelerator``.
        sampling_steps: Forwarded to ``predict`` as ``sampling_steps``.
        diffusion_samples: Forwarded to ``predict`` as ``diffusion_samples``.
        output_format: Forwarded to ``predict`` as ``output_format``.
        checkpoint: Forwarded to ``predict`` as ``checkpoint``.

    Returns:
        Path of the most recently modified `.pdb` file found under
        ``boltz_results*/predictions`` in the output directory, or ``None``
        when there is no input file or no `.pdb` file was found.
    """
    # Nothing uploaded: bail out instead of failing on `.name` below
    if input_file is None:
        print("No input file provided.")
        return None

    # Accept both an object carrying the path in `.name` and a bare path string
    data = getattr(input_file, "name", input_file)
    print("the data : ", data)

    # One directory is used both for the prediction output and for the search
    # below (the search previously referenced an undefined `out_dir`)
    out_dir = "./"

    predict(
        data=data,
        out_dir=out_dir,
        cache=cache,
        accelerator=accelerator,
        sampling_steps=sampling_steps,
        diffusion_samples=diffusion_samples,
        output_format=output_format,
        checkpoint=checkpoint,
    )

    # Recursively look for .pdb files in any results folder's predictions tree
    search_path = os.path.join(out_dir, "boltz_results*/predictions/**/*.pdb")
    pdb_files = glob.glob(search_path, recursive=True)

    if not pdb_files:
        print("No .pdb files found in the predictions folder.")
        return None

    # The newest file by modification time is taken to be this run's output
    return max(pdb_files, key=os.path.getmtime)


def _run_prediction_from_ui(input_file, accelerator, sampling_steps, diffusion_samples):
    """Adapt the four UI inputs to the full ``run_prediction`` signature.

    The click handler receives one value per input component, in order.
    ``run_prediction`` also requires ``cache`` and ``output_format``, which
    have no UI control, so they are filled in here with the same defaults
    that ``predict`` declares.
    """
    return run_prediction(
        input_file,
        cache="~/.boltz",
        accelerator=accelerator,
        # Slider values end up stringified on the command line; keep them integral
        sampling_steps=int(sampling_steps),
        diffusion_samples=int(diffusion_samples),
        output_format="pdb",
    )


with gr.Blocks() as demo:
    gr.Markdown("# 🔬 Boltz-1: Democratizing Biomolecular Interaction Modeling 🧬")

    with gr.Row():
        # Left column: file upload, advanced settings and the run button
        with gr.Column(scale=1):
            inp = gr.File(label="Upload a .fasta File", file_types=[".fasta"])

            with gr.Accordion("Advanced Settings", open=False):
                accelerator = gr.Radio(choices=["gpu", "cpu"], value="gpu", label="Accelerator")
                sampling_steps = gr.Slider(minimum=1, maximum=500, value=50, step=1, label="Sampling Steps")
                diffusion_samples = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Diffusion Samples")
            btn = gr.Button("Predict")

        # Right column: 3D viewer for the resulting structure
        with gr.Column(scale=3):
            out = Molecule3D(label="Generated Molecule", reps=reps)

    # The adapter's parameters match these four inputs one-to-one
    btn.click(
        _run_prediction_from_ui,
        inputs=[inp, accelerator, sampling_steps, diffusion_samples],
        outputs=out,
    )

    # Examples are served from precomputed `.pdb` files rather than re-run
    gr.Examples(
        examples=[[fasta_file] for fasta_file in example_fasta_files],
        inputs=[inp],
        outputs=out,
        fn=on_example_click,
        cache_examples=True,
    )


if __name__ == "__main__":
    demo.launch(share=True, debug=True)