detsd_demo / app.py
Garrett Goon
testing basics
384ec64
raw
history blame
2.32 kB
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torchvision.transforms.functional import to_pil_image
# pipeline = StableDiffusionPipeline.from_pretrained(
# pretrained_model_name_or_path="weights", torch_dtype=torch.float16
# )
# pipeline.to('cuda')
# Lookup from each concept token a user may type in a prompt (presumably
# including "<det-logo>", given the prompt label below — confirm) to the
# replacement text it expands into; consumed by replace_concept_tokens.
# Values must be str, because they are passed to str.replace.
# The path is resolved relative to the current working directory.
# NOTE(review): torch.load is pickle-based, so this must only ever point at a
# trusted artifact (presumably the file committed alongside this script).
concept_to_dummy_tokens_map = torch.load(f="concept_to_dummy_tokens_map.pt")
def replace_concept_tokens(
    text: str, concept_map: "dict[str, str] | None" = None
) -> str:
    """Expand every concept token in ``text`` into its replacement string.

    Args:
        text: Prompt text that may contain concept tokens.
        concept_map: Mapping from concept token to the string that replaces
            it. When omitted (or ``None``) the module-level
            ``concept_to_dummy_tokens_map`` is used, which preserves the
            original behavior; passing a mapping explicitly lets the function
            be reused or tested without loading the ``.pt`` file.

    Returns:
        ``text`` with every occurrence of every concept token replaced.

    Replacements are applied one key at a time, in the mapping's iteration
    order. A replacement string that itself contains another concept token may
    therefore be expanded again if that token is processed later.
    """
    if concept_map is None:
        # Resolved at call time so the default still tracks the global map.
        concept_map = concept_to_dummy_tokens_map
    for concept_token, dummy_tokens in concept_map.items():
        text = text.replace(concept_token, dummy_tokens)
    return text
# def inference(
# prompt: str, num_inference_steps: int = 50, guidance_scale: int = 3.0
# ):
# prompt = replace_concept_tokens(prompt)
# for _ in range(3):
# img_list = pipeline(
# prompt=prompt,
# num_inference_steps=num_inference_steps,
# guidance_scale=guidance_scale,
# )
# if not img_list["nsfw_content_detected"]:
# break
# return img_list["sample"]
# Example prompt shown as placeholder text in the prompt box. It demonstrates
# where the "<det-logo>" concept token goes in a full prompt.
DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> using soft strokes, "
    "pastel colors, incredible composition, masterpiece"
)
def white_imgs(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
    """Stand-in Generate handler that returns two blank white images.

    Every argument is ignored; the signature only mirrors the UI inputs wired
    to the Generate button, so a real inference function can replace it.

    Returns:
        A list of two numpy arrays of shape (512, 512, 3), filled with 1.0, in
        torch's default dtype (float32 unless the default has been changed).
    """
    height, width, channels = 512, 512, 3

    def blank_white():
        # A fresh tensor per call, so the two returned arrays never share memory.
        return torch.ones((height, width, channels)).numpy()

    return [blank_white(), blank_white()]
# Build and launch the demo UI. The model-backed `inference` function is
# commented out in this file, so Generate is currently wired to the
# `white_imgs` stub.
with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        # Shown only as placeholder text; an untouched box submits "".
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    # Range is the 500 values just below 2**31 - 1 (= 2147483647); the default
    # is the midpoint of that range. Integer step for consistency with
    # num_inference_steps, since seeds are whole numbers.
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
        step=1,
    )
    # A Button's caption is its `value`; `label` is not a Button parameter, so
    # `label="Generate"` left the default caption on the button.
    generate_btn = gr.Button(value="Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        # Two black 512x512 RGB placeholders until the first generation.
        value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
        # NOTE(review): `.style()` is deprecated in later gradio 3.x releases
        # and removed in 4.x — revisit if gradio is upgraded.
    ).style(height="auto")
    generate_btn.click(
        white_imgs,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )
demo.launch()