File size: 2,086 Bytes
997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a fbaddc2 997984a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.diffusion import (
DiffusersGenerationAlgorithm,
DDPMGenerator,
DDIMGenerator,
ScoreSdeGenerator,
LDMTextToImageGenerator,
LDMGenerator,
StableDiffusionGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def run_inference(model_type: str, prompt: str):
    """Generate a single image with the selected diffusion model.

    Args:
        model_type: name of one of the diffusion configuration classes
            imported in this module (e.g. ``"StableDiffusionGenerator"``),
            as offered by the demo's model dropdown.
        prompt: text prompt. Must be empty unless the chosen configuration's
            ``modality`` is ``"token2image"``.

    Returns:
        The first item yielded by ``DiffusersGenerationAlgorithm.sample(1)``
        (presumably a PIL image; the demo renders it with a PIL image output).

    Raises:
        ValueError: if ``model_type`` is not a known configuration, or if a
            non-empty prompt is given to a model whose modality is not
            ``"token2image"``.
    """
    # Resolve the configuration class from an explicit whitelist instead of
    # ``eval``. The previous f-string inserted the raw, unquoted prompt into
    # Python source, so prompts containing spaces and the empty prompt were
    # SyntaxErrors, and arbitrary user text from the UI was executed as code.
    configurations = {
        "DDPMGenerator": DDPMGenerator,
        "DDIMGenerator": DDIMGenerator,
        "ScoreSdeGenerator": ScoreSdeGenerator,
        "LDMTextToImageGenerator": LDMTextToImageGenerator,
        "LDMGenerator": LDMGenerator,
        "StableDiffusionGenerator": StableDiffusionGenerator,
    }
    try:
        configuration_class = configurations[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown diffusion model {model_type!r}; "
            f"expected one of {sorted(configurations)}"
        ) from None

    config = configuration_class(prompt=prompt)
    if config.modality != "token2image" and prompt != "":
        raise ValueError(
            f"{model_type} is an unconditional generative model, please remove prompt (not={prompt})"
        )
    model = DiffusersGenerationAlgorithm(config)
    # Only one sample is requested; take it without materializing a list.
    image = next(iter(model.sample(1)))
    return image
if __name__ == "__main__":
# Preparation (retrieve all available algorithms)
all_algos = ApplicationsRegistry.list_available()
algos = [
x["algorithm_application"]
for x in list(filter(lambda x: "Diff" in x["algorithm_name"], all_algos))
]
algos = [a for a in algos if not "GeoDiff" in a]
# Load metadata
metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
""
)
with open(metadata_root.joinpath("article.md"), "r") as f:
article = f.read()
with open(metadata_root.joinpath("description.md"), "r") as f:
description = f.read()
demo = gr.Interface(
fn=run_inference,
title="Diffusion-based image generators",
inputs=[
gr.Dropdown(
algos, label="Diffusion model", value="StableDiffusionGenerator"
),
gr.Textbox(label="Text prompt", placeholder="A blue tree", lines=1),
],
outputs=gr.outputs.Image(type="pil"),
article=article,
description=description,
examples=examples.values.tolist(),
)
demo.launch(debug=True, show_error=True)
|