File size: 1,796 Bytes
1825dee
 
 
60fae37
85ad568
 
60fae37
375ee1a
573f2cc
 
 
 
 
 
 
 
f07faaf
573f2cc
 
 
 
 
 
 
 
f07faaf
573f2cc
 
 
 
 
 
375ee1a
42f5e17
7d63449
375ee1a
fee06b2
375ee1a
60fae37
 
 
7d63449
 
 
 
 
1f2402d
 
 
7d63449
8198e38
 
7d63449
60fae37
6bb464c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import spaces


import gradio as gr
from gradio_molecule3d import Molecule3D
from gradio_cofoldinginput import CofoldingInput

import os
import urllib.request

CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MODEL_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1.ckpt"


def _download_if_missing(url, destination, description):
    """Download *url* to *destination* unless that file already exists.

    Args:
        url: Remote location of the file.
        destination: Local file path (string) to store the download at.
        description: Human-readable name of the artifact, used in the
            progress message printed before the download starts.
    """
    if os.path.exists(destination):
        return
    print(
        f"Downloading the {description} to {destination}. You may "
        "change the cache directory with the --cache flag."
    )
    # Fetch into a temporary sibling and rename afterwards, so an interrupted
    # download never leaves a truncated file that a later run would mistake
    # for a complete one.
    partial = destination + ".part"
    urllib.request.urlretrieve(url, partial)  # noqa: S310
    os.replace(partial, destination)


# "~" is not expanded by the filesystem APIs, and a plain string does not
# support the "/" operator, so resolve the cache directory explicitly and
# make sure it exists before anything is written into it.
cache = os.path.expanduser("~/.boltz")
os.makedirs(cache, exist_ok=True)

# Download the CCD dictionary
ccd = os.path.join(cache, "ccd.pkl")
_download_if_missing(CCD_URL, ccd, "CCD dictionary")

# Download model
model = os.path.join(cache, "boltz1.ckpt")
_download_if_missing(MODEL_URL, model, "model weights")



@spaces.GPU(duration=120)
def predict(jobname, inputs, recycling_steps, sampling_steps, diffusion_samples):
    """Run ``boltz predict`` on ``ligand.fasta`` and return the first model.

    Args:
        jobname: Job name entered in the UI. Currently unused.
        inputs: Value of the co-folding input component. Currently unused:
            the prediction always reads ``ligand.fasta`` from the working
            directory.
            TODO(review): serialize ``inputs`` to the input file once the
            component's payload schema is confirmed.
        recycling_steps: Number of recycling steps passed to the CLI.
        sampling_steps: Number of sampling steps passed to the CLI.
        diffusion_samples: Number of diffusion samples passed to the CLI
            (at least one is always requested so that model 0 exists).

    Returns:
        Relative path of the first predicted structure, in PDB format.

    Raises:
        gr.Error: If the ``boltz`` command cannot be run, exits with a
            non-zero status, or does not produce the expected output file.
    """
    import subprocess  # local import: only this function needs it

    # Slider values may arrive as floats; the CLI expects integers.
    command = [
        "boltz",
        "predict",
        "ligand.fasta",
        "--output_format",
        "pdb",
        "--recycling_steps",
        str(int(recycling_steps)),
        "--sampling_steps",
        str(int(sampling_steps)),
        "--diffusion_samples",
        str(max(1, int(diffusion_samples))),
        # Without this, results of an earlier run would be reused even though
        # the settings changed.
        "--override",
    ]
    try:
        # Argument list (no shell) and check=True so a failed run is reported
        # instead of silently returning a path that was never written.
        subprocess.run(command, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        raise gr.Error(f"Boltz prediction failed: {err}") from err

    # The extension has to match the --output_format requested above.
    result = "boltz_results_ligand/predictions/ligand/ligand_model_0.pdb"
    if not os.path.exists(result):
        raise gr.Error(f"Boltz finished but did not produce {result}")
    return result

# Gradio UI: a "Main" tab with the job name, co-folding input and 3D viewer,
# and a "Settings" tab with the sampling parameters handed to `predict`.
with gr.Blocks() as blocks:
    gr.Markdown("# Boltz-1")
    with gr.Tab("Main"):
        jobname = gr.Textbox(label="Jobname")
        inp = CofoldingInput(label="Input")
        out = Molecule3D(label="Output")
    with gr.Tab("Settings"):
        # All three settings are integer counts, hence step=1.
        recycling_steps = gr.Slider(
            value=3, minimum=0, step=1, label="Recycling steps"
        )
        # Slider's default maximum is 100, which is below the default of 200,
        # so the upper bound has to be raised explicitly.
        sampling_steps = gr.Slider(
            value=200, minimum=0, maximum=1000, step=1, label="Sampling steps"
        )
        # At least one sample is needed for a result to be shown.
        diffusion_samples = gr.Slider(
            value=1, minimum=1, step=1, label="Diffusion samples"
        )

    btn = gr.Button("predict")

    btn.click(
        fn=predict,
        inputs=[jobname, inp, recycling_steps, sampling_steps, diffusion_samples],
        outputs=[out],
        api_name="predict",
    )

blocks.launch(ssr_mode=False)