"""Gradio demo app for the GT4SD GeoDiff diffusion generator."""

import logging
import pathlib
import pickle
from typing import Any, Dict

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.diffusion import (
    DiffusersGenerationAlgorithm,
    GeoDiffGenerator,
)
from rdkit import Chem

from utils import draw_grid_generate

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def _resolve_upload_path(prompt_file: Any) -> pathlib.Path:
    """Return the filesystem path of the ``gr.File`` input.

    Accepts either a plain path (``str`` / ``pathlib.Path``) or a
    file-like wrapper that exposes its on-disk path as ``.name``.
    """
    # Check for path types first: pathlib.Path also has a ``.name``
    # attribute (the bare file name), which would drop the directory.
    if isinstance(prompt_file, (str, pathlib.PurePath)):
        return pathlib.Path(prompt_file)
    # NOTE(review): which form gr.File delivers depends on the installed
    # gradio version; confirm against the pinned release.
    return pathlib.Path(prompt_file.name)


def _load_prompts(path: pathlib.Path) -> Any:
    """Unpickle and return the prompt object stored at *path*."""
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # This file is a user upload, so only expose the demo where uploads are
    # trusted.
    with path.open("rb") as f:
        return pickle.load(f)


def run_inference(prompt_file: Any, prompt_id: int, number_of_samples: int):
    """Sample molecules from GeoDiff for a pickled prompt and render them.

    Args:
        prompt_file: the uploaded ``.pkl`` file. This is either a path or a
            file-like wrapper exposing ``.name``. The pickle holds either a
            dict mapping integer ids to prompts, or a single prompt object
            (used as-is).
        prompt_id: which prompt to use when the pickle is keyed by integers.
            Ignored otherwise.
        number_of_samples: how many samples to draw from the model.

    Returns:
        Whatever ``draw_grid_generate`` produces for the sampled SMILES.
        The interface displays it in a ``gr.HTML`` output, so it is
        presumably an HTML string.

    Raises:
        ValueError: if the pickled prompt mapping is empty.
        KeyError: if the prompts are keyed by integers and ``prompt_id`` is
            not one of them.
    """
    prompts = _load_prompts(_resolve_upload_path(prompt_file))

    keys = list(prompts.keys())
    if not keys:
        # Without this guard, all() over no keys is vacuously True and the
        # lookup below fails with an unhelpful KeyError.
        raise ValueError("The uploaded prompt file contains no prompts.")

    if all(isinstance(key, int) for key in keys):
        # Gradio numeric inputs may arrive as floats; normalise for lookup.
        prompt_id = int(prompt_id)
        if prompt_id not in prompts:
            raise KeyError(
                f"Prompt ID {prompt_id} not found; available IDs: {sorted(keys)}"
            )
        prompt = prompts[prompt_id]
    else:
        # Keys are not all ints: treat the whole object as a single prompt.
        prompt = prompts

    # Slider values may arrive as floats; the sampler expects a count.
    n_samples = int(number_of_samples)

    config = GeoDiffGenerator(prompt=prompt)
    model = DiffusersGenerationAlgorithm(config)
    logger.info("Sampling %d molecules with GeoDiff", n_samples)
    results = list(model.sample(n_samples))

    smiles = [Chem.MolToSmiles(mol) for mol in results]
    return draw_grid_generate(samples=smiles, n_cols=5)


if __name__ == "__main__":
    # Load metadata (example prompts and model-card text) stored next to
    # this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    example_prompts = str(metadata_root.joinpath("mol_dct.pkl"))

    # Each row: [prompt file, prompt id, number of samples].
    examples = [
        [example_prompts, 0, 2],
        [example_prompts, 1, 2],
    ]

    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="GeoDiff",
        inputs=[
            gr.File(file_types=[".pkl"], label="GeoDiff prompt"),
            gr.Number(value=0, label="Prompt ID", precision=0),
            gr.Slider(
                minimum=1, maximum=5, value=2, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)