paccmann_gp / app.py
jannisborn's picture
update
973616f unverified
raw
history blame
5.17 kB
import logging
import pathlib
from typing import List
import gradio as gr
import pandas as pd
from gt4sd.algorithms.controlled_sampling.paccmann_gp import (
PaccMannGPGenerator,
PaccMannGP,
)
from gt4sd.algorithms.controlled_sampling.paccmann_gp.implementation import (
MINIMIZATION_FUNCTIONS,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate
# Module-level logger with a NullHandler attached (the standard logging idiom
# for library-style modules that leave handler configuration to the host).
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Remove the "callable" and "molwt" objectives from gt4sd's shared registry
# before the UI lists the remaining keys as selectable property goals.
# Missing keys are tolerated. NOTE(review): presumably these objectives need
# extra parameters the demo does not collect — confirm.
for _objective in ("callable", "molwt"):
    MINIMIZATION_FUNCTIONS.pop(_objective, None)
del _objective  # keep the module namespace identical to a plain pair of pops
def run_inference(
    algorithm_version: str,
    targets: List[str],
    protein_target: str,
    temperature: float,
    length: float,
    number_of_samples: int,
    limit: float,
    number_of_steps: int,
    number_of_initial_points: int,
    number_of_optimization_rounds: int,
    sampling_variance: float,
    samples_for_evaluation: int,
    maximum_number_of_sampling_steps: int,
    seed: int,
):
    """Sample molecules with PaccMannGP and render them with ``draw_grid_generate``.

    Inputs are validated before any model configuration is built, so malformed
    requests fail fast with a clear message.

    Args:
        algorithm_version: Version label from the UI dropdown; only the text
            after the last ``"_"`` is forwarded to ``PaccMannGPGenerator``.
        targets: Names of the property goals to optimize. Must not be empty.
        protein_target: Protein sequence. Required (non-blank) when
            ``"affinity"`` is in ``targets``; ignored otherwise. Surrounding
            whitespace is stripped.
        temperature: Decoding temperature.
        length: Maximal generated sequence length (converted to ``int``).
        number_of_samples: Number of molecules to sample (converted to ``int``).
        limit: Forwarded as ``limit`` to ``PaccMannGPGenerator``. The UI
            slider feeding it sets no ``step``, so it may be fractional.
        number_of_steps, number_of_initial_points,
        number_of_optimization_rounds, sampling_variance,
        samples_for_evaluation, maximum_number_of_sampling_steps, seed:
            Forwarded unchanged to ``PaccMannGPGenerator``.

    Returns:
        The result of ``draw_grid_generate`` for the sampled molecules
        (displayed by the UI in an HTML output).

    Raises:
        ValueError: If no property goal is selected, or if ``"affinity"`` is
            selected without a non-blank protein sequence.
    """
    # --- Validate inputs before building anything expensive. ---
    if not targets:
        raise ValueError("At least one property goal must be selected.")

    # One (initially empty) parameter dict per selected objective.
    target = {name: {} for name in targets}
    if "affinity" in target:
        if not isinstance(protein_target, str) or not protein_target.strip():
            raise ValueError(
                "Protein target must be specified for affinity prediction, "
                f"got {protein_target!r}"
            )
        # Drop stray whitespace/newlines pasted into the textbox.
        protein_target = protein_target.strip()
        target["affinity"]["protein"] = protein_target
    else:
        # Without an affinity goal the sequence is unused; blank it so it is
        # not forwarded to draw_grid_generate.
        protein_target = ""

    config = PaccMannGPGenerator(
        # NOTE(review): assumes UI labels look like "<prefix>_<version>"; a
        # label without "_" is forwarded unchanged.
        algorithm_version=algorithm_version.split("_")[-1],
        batch_size=32,
        temperature=temperature,
        # Slider values may arrive as floats; lengths are whole numbers.
        generated_length=int(length),
        limit=limit,
        acquisition_function="EI",
        number_of_steps=number_of_steps,
        number_of_initial_points=number_of_initial_points,
        initial_point_generator="random",
        number_of_optimization_rounds=number_of_optimization_rounds,
        sampling_variance=sampling_variance,
        samples_for_evaluation=samples_for_evaluation,
        maximum_number_of_sampling_steps=maximum_number_of_sampling_steps,
        seed=seed,
    )
    model = PaccMannGP(config, target=target)
    samples = list(model.sample(int(number_of_samples)))
    return draw_grid_generate(
        samples=samples,
        n_cols=5,
        properties=set(target),
        protein_target=protein_target,
    )
if __name__ == "__main__":
    import ast

    # Preparation: collect the versions of every registered PaccMannGP
    # algorithm for the version dropdown.
    algos = [
        entry["algorithm_version"]
        for entry in ApplicationsRegistry.list_available()
        if "PaccMannGP" in entry["algorithm_name"]
    ]

    # Load metadata shipped in ./model_cards next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None, sep="|"
    ).fillna("")
    # Column 1 feeds the "Property goals" CheckboxGroup (second input), so it
    # is stored as a list literal. literal_eval parses literals only and,
    # unlike eval, cannot execute code embedded in the CSV.
    examples[1] = examples[1].apply(ast.literal_eval)
    # Explicit UTF-8 so the markdown loads the same regardless of platform locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    # Input order must match run_inference's positional parameters and the
    # column order of examples.csv.
    demo = gr.Interface(
        fn=run_inference,
        title="PaccMannGP",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="v0"),
            # CheckboxGroup is inherently multi-select; `multiselect` is a
            # Dropdown-only option and was ignored here.
            gr.CheckboxGroup(
                choices=list(MINIMIZATION_FUNCTIONS.keys()),
                value=["qed"],
                label="Property goals",
            ),
            gr.Textbox(
                label="Protein target",
                placeholder="MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT",
                lines=1,
            ),
            gr.Slider(minimum=0.5, maximum=2, value=1, label="Decoding temperature"),
            gr.Slider(
                minimum=5,
                maximum=400,
                value=100,
                label="Maximal sequence length",
                step=1,
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
            gr.Slider(minimum=1, maximum=8, value=4.0, label="Limit"),
            gr.Slider(minimum=1, maximum=32, value=8, label="Number of steps", step=1),
            gr.Slider(
                minimum=1, maximum=32, value=4, label="Number of initial points", step=1
            ),
            gr.Slider(
                minimum=1,
                maximum=4,
                value=1,
                label="Number of optimization rounds",
                step=1,
            ),
            gr.Slider(minimum=0.01, maximum=1, value=0.1, label="Sampling variance"),
            gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                label="Samples used for evaluation",
                step=1,
            ),
            gr.Slider(
                minimum=1,
                maximum=64,
                value=4,
                label="Maximum number of sampling steps",
                step=1,
            ),
            gr.Number(value=42, label="Seed", precision=0),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)