File size: 1,653 Bytes
1825dee
 
 
60fae37
85ad568
 
60fae37
375ee1a
573f2cc
 
 
 
 
 
8704b30
76c9d92
f07faaf
573f2cc
 
8704b30
573f2cc
 
8704b30
76c9d92
f07faaf
8704b30
573f2cc
8704b30
573f2cc
 
375ee1a
42f5e17
7d63449
375ee1a
fee06b2
375ee1a
60fae37
 
 
7d63449
 
 
 
 
1f2402d
 
 
7d63449
8198e38
 
7d63449
60fae37
6bb464c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import spaces


import gradio as gr
from gradio_molecule3d import Molecule3D
from gradio_cofoldinginput import CofoldingInput

import os
import urllib.request

# Remote locations of the Boltz-1 assets fetched at start-up.
CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MODEL_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1.ckpt"

# Local cache directory. `~` must be expanded explicitly: os.path.exists and
# urlretrieve would otherwise treat it as a literal directory named "~".
cache = os.path.expanduser("~/.boltz")
os.makedirs(cache, exist_ok=True)


def _download_if_missing(url, path, what):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The payload is first written to ``path + ".part"`` and only moved into
    place once the transfer finished, so an interrupted download is not
    mistaken for a complete file on the next start-up.

    Args:
        url: Remote URL to fetch.
        path: Destination file path.
        what: Human-readable description used in the progress message.
    """
    if os.path.exists(path):
        return
    print(f"Downloading {what} to {path}")
    tmp_path = f"{path}.part"
    urllib.request.urlretrieve(url, tmp_path)
    os.replace(tmp_path, path)


# Download the CCD dictionary
ccd = os.path.join(cache, "ccd.pkl")
_download_if_missing(CCD_URL, ccd, "the CCD dictionary")

# Download model
model = os.path.join(cache, "boltz1.ckpt")
_download_if_missing(MODEL_URL, model, "the model weights")



@spaces.GPU(duration=120)
def predict(jobname, inputs, recycling_steps, sampling_steps, diffusion_samples):
    """Run ``boltz predict`` on ``ligand.fasta`` and return the first model's path.

    Args:
        jobname: Job name from the UI.
            NOTE(review): currently unused.
        inputs: Value of the CofoldingInput component.
            NOTE(review): currently unused; the prediction always reads
            ``ligand.fasta``. The structure of this value is not visible
            here, so it is not converted to FASTA yet.
        recycling_steps: Forwarded to ``--recycling_steps``.
        sampling_steps: Forwarded to ``--sampling_steps`` (clamped to >= 1).
        diffusion_samples: Forwarded to ``--diffusion_samples`` (clamped to >= 1).

    Returns:
        Path of ``ligand_model_0.cif`` inside ``boltz_results_ligand``.

    Raises:
        gr.Error: If ``boltz`` cannot be started, exits with a non-zero
            status, or does not produce the expected file.
    """
    import subprocess

    # Slider values may arrive as floats; the CLI expects integers.
    cmd = [
        "boltz", "predict", "ligand.fasta",
        # mmCIF so the written file matches the ``.cif`` path returned below.
        "--output_format", "mmcif",
        "--recycling_steps", str(int(recycling_steps)),
        "--sampling_steps", str(max(1, int(sampling_steps))),
        "--diffusion_samples", str(max(1, int(diffusion_samples))),
    ]
    try:
        # Argument list, no shell: nothing is interpolated into a shell string.
        subprocess.run(cmd, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        raise gr.Error(f"Boltz prediction failed: {err}") from err

    result = "boltz_results_ligand/predictions/ligand/ligand_model_0.cif"
    if not os.path.exists(result):
        raise gr.Error(f"Boltz finished but {result} was not produced.")
    return result

with gr.Blocks() as blocks:
    gr.Markdown("# Boltz-1")
    with gr.Tab("Main"):
        jobname = gr.Textbox(label="Jobname")
        inp = CofoldingInput(label="Input")
        out = Molecule3D(label="Output")
    with gr.Tab("Settings"):
        # gr.Slider defaults to minimum=0, maximum=100, so ranges are set
        # explicitly: the 200-step default would otherwise exceed the maximum.
        # Zero sampling steps / diffusion samples are presumably not
        # meaningful for Boltz, hence minimum=1 for those two.
        recycling_steps = gr.Slider(
            value=3, minimum=0, maximum=100, step=1, label="Recycling steps"
        )
        sampling_steps = gr.Slider(
            value=200, minimum=1, maximum=1000, step=1, label="Sampling steps"
        )
        diffusion_samples = gr.Slider(
            value=1, minimum=1, maximum=100, step=1, label="Diffusion samples"
        )

    btn = gr.Button("predict")

    btn.click(
        fn=predict,
        inputs=[jobname, inp, recycling_steps, sampling_steps, diffusion_samples],
        outputs=[out],
        api_name="predict",
    )

blocks.launch(ssr_mode=False)