# moses / app.py
# jannisborn's picture
# update
# 8ae3405 unverified
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.conditional_generation.guacamol import (
AaeGenerator,
GraphGAGenerator,
GraphMCTSGenerator,
GuacaMolGenerator,
MosesGenerator,
OrganGenerator,
VaeGenerator,
SMILESGAGenerator,
SMILESLSTMHCGenerator,
SMILESLSTMPPOGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate
# Library-style logging: attach a NullHandler so that importing this module
# never emits "no handler" warnings; the host application configures output.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

TITLE = "GuacaMol & MOSES"

# Maps the algorithm name selected in the UI to its configuration class.
# Only the Moses generators are registered. The GuacaMol configuration classes
# imported above (GraphGAGenerator, GraphMCTSGenerator, SMILESLSTMHCGenerator,
# SMILESLSTMPPOGenerator, SMILESGAGenerator) are deliberately left out; an
# earlier revision registered every entry under "<family> - <class name>" keys
# (e.g. "GuacaMol - GraphGAGenerator") and then immediately overwrote that
# mapping with this one, so the first assignment was a dead store.
CONFIG_FACTORY = {
    "AaeGenerator": AaeGenerator,
    "VaeGenerator": VaeGenerator,
    "OrganGenerator": OrganGenerator,
}

# Maps an algorithm family to the generator class that wraps a configuration.
MODEL_FACTORY = {"Moses": MosesGenerator, "GuacaMol": GuacaMolGenerator}
def run_inference(
    algorithm_version: str,
    length: int,
    number_of_samples: int,
    # GuacaMol-only options. They are appended with defaults so that existing
    # three-argument callers (the Moses-only UI) keep working unchanged.
    population_size: int = 100,
    random_start: bool = False,
    patience: int = 4,
    generations: int = 2,
):
    """Sample from the selected generator and render the samples as a grid.

    Args:
        algorithm_version: Key into ``CONFIG_FACTORY``. A ``"<family> - "``
            prefix (e.g. ``"GuacaMol - GraphGAGenerator"``) selects the
            family; keys without a prefix are treated as Moses generators.
        length: Passed as ``max_len`` to Moses configurations and to the
            GuacaMol ``LSTMHC`` configuration.
        number_of_samples: Number of samples drawn from the model; also
            passed as ``n_samples`` to Moses configurations.
        population_size: GuacaMol only; forwarded to the configuration.
        random_start: GuacaMol only; forwarded to the configuration.
        patience: GuacaMol only; forwarded to the configuration.
        generations: GuacaMol only; forwarded to the configuration.

    Returns:
        The result of ``draw_grid_generate`` for the drawn samples.

    Raises:
        KeyError: If ``algorithm_version`` is not registered in
            ``CONFIG_FACTORY``.
        ValueError: If the family derived from ``algorithm_version`` is
            neither ``"Moses"`` nor ``"GuacaMol"``.
    """
    config_class = CONFIG_FACTORY[algorithm_version]

    # UI sliders may deliver numbers as floats; the configurations and
    # ``sample`` are given integers.
    length = int(length)
    number_of_samples = int(number_of_samples)

    # Derive the family from the optional "<family> - " prefix instead of
    # hard-coding it, so the GuacaMol branch below is reachable as soon as
    # prefixed keys are registered. Bare keys default to Moses.
    prefix, separator, _ = algorithm_version.partition(" - ")
    family = prefix if separator else "Moses"

    if family == "Moses":
        kwargs = {"n_samples": number_of_samples, "max_len": length}
    elif family == "GuacaMol":
        kwargs = {
            "population_size": int(population_size),
            "random_start": random_start,
            "patience": int(patience),
            "generations": int(generations),
        }
        # Individual GuacaMol configurations accept only a subset of the
        # options above, so trim the keyword arguments per algorithm.
        if "MCTS" in algorithm_version:
            kwargs.pop("random_start")
        if "LSTMHC" in algorithm_version:
            kwargs["max_len"] = length
            kwargs.pop("population_size")
            kwargs.pop("patience")
            kwargs.pop("generations")
        if "LSTMPPO" in algorithm_version:
            kwargs = {}
    else:
        raise ValueError(f"Unknown family {family}")

    # Safe lookup: ``family`` is "Moses" or "GuacaMol" at this point.
    model_class = MODEL_FACTORY[family]
    config = config_class(**kwargs)
    model = model_class(configuration=config, target={})
    samples = list(model.sample(number_of_samples))
    return draw_grid_generate(seeds=[], samples=samples, n_cols=5)
if __name__ == "__main__":
    # Offer only the Moses applications known to the registry. Each registry
    # entry is indexed by "algorithm_name" and "algorithm_application".
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_application"]
        for algo in all_algos
        if "Moses" in algo["algorithm_name"]
    ]

    # Load metadata shipped next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"
    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")
    # Read the markdown with an explicit encoding so the result does not
    # depend on the platform's default locale.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    # Inputs are passed to ``run_inference`` positionally, so their order here
    # must match the order of that function's parameters.
    demo = gr.Interface(
        fn=run_inference,
        title="MOSES",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="AaeGenerator"),
            gr.Slider(
                minimum=5, maximum=500, value=100, label="Sequence length", step=1
            ),
            # Disabled GuacaMol controls. When re-enabling them, keep the
            # input order in sync with ``run_inference``'s parameter order.
            # gr.Slider(
            #     minimum=5, maximum=500, value=100, label="Population size", step=1
            # ),
            # gr.Radio(choices=[True, False], label="Random start", value=False),
            # gr.Slider(minimum=1, maximum=10, value=4, label="Patience"),
            # gr.Slider(minimum=1, maximum=10, value=2, label="Generations"),
            gr.Slider(
                minimum=1, maximum=50, value=5, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)