|
import logging |
|
import pathlib |
|
import gradio as gr |
|
import pandas as pd |
|
from gt4sd.algorithms.generation.diffusion import ( |
|
DiffusersGenerationAlgorithm, |
|
DDPMGenerator, |
|
DDIMGenerator, |
|
ScoreSdeGenerator, |
|
LDMTextToImageGenerator, |
|
LDMGenerator, |
|
StableDiffusionGenerator, |
|
) |
|
from gt4sd.algorithms.registry import ApplicationsRegistry |
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NullHandler discards records, so this module produces no log output on its
# own; any handlers configured elsewhere still apply.
logger.addHandler(logging.NullHandler())
|
|
|
|
|
# Names of the diffusion generator configuration classes imported by this
# module that may be selected by name at runtime.
_SUPPORTED_GENERATORS = (
    "DDPMGenerator",
    "DDIMGenerator",
    "ScoreSdeGenerator",
    "LDMTextToImageGenerator",
    "LDMGenerator",
    "StableDiffusionGenerator",
)


def run_inference(model_type: str, prompt: str):
    """Generate one sample with the selected diffusion model.

    Args:
        model_type: Name of the generator configuration class to use; must be
            one of ``_SUPPORTED_GENERATORS``.
        prompt: Text prompt forwarded to the configuration. An empty string
            means "no prompt" and builds the configuration with its defaults.

    Returns:
        The first item produced by ``DiffusersGenerationAlgorithm.sample(1)``
        (presumably an image, since the UI renders it as one).

    Raises:
        ValueError: If ``model_type`` is not a supported generator name, or
            if a non-empty prompt is given for a configuration whose
            ``modality`` is not ``"token2image"``.
    """
    if model_type not in _SUPPORTED_GENERATORS:
        raise ValueError(
            f"Unknown model type {model_type!r}, expected one of: "
            f"{', '.join(_SUPPORTED_GENERATORS)}"
        )

    # Resolve the class by name among this module's imports rather than
    # eval()-ing UI-provided text: eval allowed arbitrary code execution and
    # broke on any ordinary prompt because the prompt was spliced in unquoted.
    config_class = globals()[model_type]

    if prompt == "":
        config = config_class()
    else:
        config = config_class(prompt=prompt)

    if config.modality != "token2image" and prompt != "":
        raise ValueError(
            f"{model_type} is an unconditional generative model, please remove prompt (not={prompt})"
        )

    model = DiffusersGenerationAlgorithm(config)
    image = list(model.sample(1))[0]

    return image
|
|
|
|
|
if __name__ == "__main__":

    # Collect the registered diffusion applications offered in the dropdown:
    # keep entries whose algorithm name contains "Diff", then drop any
    # application whose name contains "GeoDiff".
    algos = [
        entry["algorithm_application"]
        for entry in ApplicationsRegistry.list_available()
        if "Diff" in entry["algorithm_name"]
        and "GeoDiff" not in entry["algorithm_application"]
    ]

    # Static assets for the UI live in "model_cards" next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    # Example rows for the interface; missing cells become empty strings.
    examples = pd.read_csv(metadata_root / "examples.csv", header=None)
    examples = examples.fillna("")

    article = (metadata_root / "article.md").read_text()
    description = (metadata_root / "description.md").read_text()

    model_selector = gr.Dropdown(
        algos, label="Diffusion model", value="StableDiffusionGenerator"
    )
    prompt_box = gr.Textbox(label="Text prompt", placeholder="A blue tree", lines=1)

    demo = gr.Interface(
        fn=run_inference,
        title="Diffusion-based image generators",
        inputs=[model_selector, prompt_box],
        outputs=gr.outputs.Image(type="pil"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|
|