|
import logging |
|
import pathlib |
|
import pickle |
|
import gradio as gr |
|
from typing import Dict, Any |
|
import pandas as pd |
|
from gt4sd.algorithms.generation.diffusion import ( |
|
DiffusersGenerationAlgorithm, |
|
GeoDiffGenerator, |
|
) |
|
from utils import draw_grid_generate |
|
from rdkit import Chem |
|
|
|
logger = logging.getLogger(__name__) |
|
logger.addHandler(logging.NullHandler()) |
|
|
|
|
|
def run_inference(prompt_file: str, prompt_id: int, number_of_samples: int):
    """Sample molecules with GeoDiff from a pickled prompt and render them as a grid.

    Args:
        prompt_file: The uploaded ``.pkl`` prompt file. Depending on the Gradio
            version this arrives either as a temp-file wrapper exposing ``.name``
            or as a plain path string; both are accepted.
        prompt_id: Which prompt to use when the unpickled object has keys that
            are all ints. Ignored when the pickle is treated as a single prompt.
        number_of_samples: How many samples to request from the model.

    Returns:
        The result of ``draw_grid_generate`` for the SMILES strings of the
        sampled molecules, laid out in 5 columns (shown by the app in a
        ``gr.HTML`` output).

    Raises:
        ValueError: If the pickle holds an empty collection of prompts.
        KeyError: If ``prompt_id`` is not one of the available prompt keys.
    """
    # Older Gradio releases pass a tempfile wrapper, newer ones a str path.
    path = getattr(prompt_file, "name", prompt_file)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load prompt files from trusted sources; do not expose this demo to
    # untrusted uploads.
    with open(path, "rb") as f:
        prompts = pickle.load(f)

    if all(isinstance(key, int) for key in prompts.keys()):
        # Keys are all ints: the pickle is a collection of prompts to pick from.
        # An empty dict also lands here because all() of nothing is True.
        if not prompts:
            raise ValueError("The uploaded prompt file contains no prompts.")
        # gr.Number may deliver a float even with precision=0.
        index = int(prompt_id)
        if index not in prompts:
            raise KeyError(
                f"Prompt ID {index} not found; available IDs: {sorted(prompts)}"
            )
        prompt = prompts[index]
    else:
        # Otherwise the whole unpickled object is used as a single prompt.
        prompt = prompts

    config = GeoDiffGenerator(prompt=prompt)
    model = DiffusersGenerationAlgorithm(config)
    # Slider values can be floats; the sampler presumably expects an int count.
    molecules = list(model.sample(int(number_of_samples)))
    smiles = [Chem.MolToSmiles(molecule) for molecule in molecules]

    return draw_grid_generate(samples=smiles, n_cols=5)
|
|
|
|
|
if __name__ == "__main__":
    # Demo assets (example prompt pickle and markdown text) live in a
    # "model_cards" directory next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    example_prompts = str(metadata_root / "mol_dct.pkl")
    # Each row matches the inputs below: (prompt file, prompt ID, number of samples).
    examples = [
        [example_prompts, 0, 2],
        [example_prompts, 1, 2],
    ]

    # Read with an explicit encoding: the default text encoding is
    # locale-dependent (e.g. cp1252 on Windows) and can garble the markdown.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="GeoDiff",
        inputs=[
            gr.File(file_types=[".pkl"], label="GeoDiff prompt"),
            gr.Number(value=0, label="Prompt ID", precision=0),
            gr.Slider(minimum=1, maximum=5, value=2, label="Number of samples", step=1),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)
|
|