geodiff / app.py
jannisborn's picture
update
d2d894e unverified
raw
history blame
1.93 kB
import logging
import pathlib
import pickle
import gradio as gr
from typing import Dict, Any
import pandas as pd
from gt4sd.algorithms.generation.diffusion import (
DiffusersGenerationAlgorithm,
GeoDiffGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate
from rdkit import Chem
# Module logger; the NullHandler keeps logging silent unless the host
# application configures handlers of its own.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def run_inference(prompt_file: Any, prompt_id: int, number_of_samples: int):
    """Sample molecules with GeoDiff for one prompt and render them as a grid.

    Args:
        prompt_file: The uploaded prompt pickle. Either a filesystem path
            (``str`` / ``pathlib`` path) or a file-like wrapper that exposes
            its path via ``.name``.
        prompt_id: Key selecting one prompt when the pickle holds a dict whose
            keys are all ``int``; ignored when the dict is a single prompt.
        number_of_samples: How many samples to draw from the model. Converted
            to ``int`` because UI sliders may deliver floats (e.g. ``2.0``).

    Returns:
        The result of ``draw_grid_generate`` for the SMILES of the sampled
        molecules, laid out in 5 columns.

    Raises:
        TypeError: If the pickle does not contain a dict.
        ValueError: If the pickle contains an empty dict.
        KeyError: If the pickle is a dict of prompts keyed by ``int`` and
            ``prompt_id`` is not one of its keys.
    """
    # Depending on the Gradio version / File ``type`` setting, the component
    # hands over either a plain path or a tempfile wrapper. Check for real
    # paths first: ``pathlib.PurePath.name`` is only the basename.
    if isinstance(prompt_file, (str, pathlib.PurePath)):
        path = prompt_file
    else:
        path = prompt_file.name

    # SECURITY: unpickling can execute arbitrary code embedded in the file.
    # This demo unpickles user uploads, so only expose it where uploads are
    # trusted (or move the prompt format to an inert one such as JSON).
    with open(path, "rb") as f:
        prompts = pickle.load(f)

    if not isinstance(prompts, dict):
        raise TypeError(
            f"Expected the prompt file to contain a dict, got {type(prompts).__name__}."
        )
    if not prompts:
        # Without this check, all() over no keys is vacuously True and the
        # lookup below fails with an unhelpful KeyError.
        raise ValueError("The prompt file contains an empty dict.")

    # A dict keyed purely by ints is a collection of prompts; any other dict
    # is treated as a single prompt.
    if all(isinstance(key, int) for key in prompts):
        if prompt_id not in prompts:
            raise KeyError(
                f"Prompt ID {prompt_id} not found; available IDs: {sorted(prompts)}"
            )
        prompt = prompts[prompt_id]
    else:
        prompt = prompts

    config = GeoDiffGenerator(prompt=prompt)
    model = DiffusersGenerationAlgorithm(config)
    results = list(model.sample(int(number_of_samples)))
    smiles = [Chem.MolToSmiles(mol) for mol in results]
    return draw_grid_generate(samples=smiles, n_cols=5)
if __name__ == "__main__":
    # Metadata (example prompts, article, description) lives next to this app.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # Each example row is [prompt file, prompt ID, number of samples]. The
    # path is passed as a string, the most widely supported form for Gradio
    # example files.
    example_file = str(metadata_root.joinpath("mol_dct.pkl"))
    examples = [
        [example_file, 0, 2],
        [example_file, 1, 2],
    ]

    # Read the markdown explicitly as UTF-8 so startup does not depend on the
    # host's locale encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="GeoDiff",
        inputs=[
            gr.File(file_types=[".pkl"], label="GeoDiff prompt"),
            gr.Number(value=0, label="Prompt ID", precision=0),
            gr.Slider(minimum=1, maximum=5, value=2, label="Number of samples", step=1),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)