Spaces:
Running
Running
File size: 3,018 Bytes
09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a bebd86d 09c907a bebd86d 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a a4eba41 09c907a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.controlled_sampling.advanced_manufacturing import (
CatalystGenerator,
AdvancedManufacturing,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate
# Module-level logger named after this module. The NullHandler emits nothing
# itself, so log output only appears if the application configures handlers.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def run_inference(
    algorithm_version: str,
    target_binding_energy: float,
    primer_smiles: str,
    length: float,
    number_of_points: int,
    number_of_steps: int,
    number_of_samples: int,
):
    """Sample catalysts for a target binding energy and draw them as a grid.

    A ``CatalystGenerator`` configuration is built from the arguments and
    wrapped in an ``AdvancedManufacturing`` algorithm whose target is
    ``target_binding_energy``. The requested number of samples is drawn and
    handed to ``draw_grid_generate`` for rendering.

    Args:
        algorithm_version: Version identifier forwarded to the configuration.
        target_binding_energy: Target value passed to the algorithm.
        primer_smiles: Primer SMILES forwarded to the configuration. When it
            is not the empty string it is also passed to the grid as the
            single seed.
        length: Forwarded to the configuration as ``generated_length``.
        number_of_points: Forwarded to the configuration unchanged.
        number_of_steps: Forwarded to the configuration unchanged.
        number_of_samples: How many samples to request from the algorithm.

    Returns:
        The result of ``draw_grid_generate`` for the samples, laid out in
        five columns.
    """
    configuration = CatalystGenerator(
        algorithm_version=algorithm_version,
        number_of_points=number_of_points,
        number_of_steps=number_of_steps,
        generated_length=length,
        primer_smiles=primer_smiles,
    )
    algorithm = AdvancedManufacturing(configuration, target=target_binding_energy)

    # Materialise the sampler output so the grid helper receives a list.
    generated = list(algorithm.sample(number_of_samples))

    # An empty primer means there is no seed to show alongside the samples.
    if primer_smiles == "":
        seed_molecules = []
    else:
        seed_molecules = [primer_smiles]

    return draw_grid_generate(samples=generated, n_cols=5, seeds=seed_molecules)
if __name__ == "__main__":
    # Preparation: collect the version of every registered algorithm whose
    # name contains "Advanced"; these populate the version dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_version"]
        for algo in all_algos
        if "Advanced" in algo["algorithm_name"]
    ]

    # Load metadata from the ``model_cards`` directory next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    # The examples file is read without a header row; missing cells are
    # replaced by empty strings rather than NaN.
    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")

    # Decode the markdown explicitly as UTF-8 so the text shown in the UI does
    # not depend on the host's locale default encoding.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    # The input widgets are listed in the same order as the parameters of
    # ``run_inference``.
    demo = gr.Interface(
        fn=run_inference,
        title="Advanced Manufacturing",
        inputs=[
            gr.Dropdown(
                algos,
                label="Algorithm version",
                value="NCCR_rnn_suzuki_aug16_smiles",
            ),
            gr.Slider(minimum=1, maximum=100, value=10, label="Target binding energy"),
            gr.Textbox(
                label="Primer SMILES",
                placeholder="FP(F)F.CP(C)c1ccccc1.[Au]",
                lines=1,
            ),
            gr.Slider(
                minimum=5,
                maximum=400,
                value=100,
                label="Maximal sequence length",
                step=1,
            ),
            gr.Slider(
                minimum=16, maximum=128, value=32, label="Number of points", step=1
            ),
            gr.Slider(
                minimum=16, maximum=128, value=50, label="Number of steps", step=1
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|