HuggingFace Spaces page header (scrape artifact) — the Space's build status was reported as "Runtime error".
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torchvision.transforms.functional import to_pil_image

# The real Stable Diffusion pipeline is disabled (presumably no GPU on this
# Space — TODO confirm); the `white_imgs` placeholder below stands in for it.
# pipeline = StableDiffusionPipeline.from_pretrained(
#     pretrained_model_name_or_path="weights", torch_dtype=torch.float16
# )
# pipeline.to('cuda')

# Mapping from human-readable concept tokens (e.g. "<det-logo>") to the dummy
# token strings used by the fine-tuned model.
# NOTE(review): torch.load unpickles arbitrary objects — only ever load a
# trusted checkpoint file here.
concept_to_dummy_tokens_map = torch.load("concept_to_dummy_tokens_map.pt")
def replace_concept_tokens(text: str, mapping=None) -> str:
    """Replace each concept token in *text* with its dummy-token string.

    Args:
        text: Prompt text, possibly containing concept tokens such as
            "<det-logo>".
        mapping: Optional concept-token -> dummy-tokens mapping. Defaults to
            the module-level ``concept_to_dummy_tokens_map`` (backward
            compatible with the original single-argument form).

    Returns:
        The prompt with every known concept token substituted.
    """
    if mapping is None:
        mapping = concept_to_dummy_tokens_map
    for concept_token, dummy_tokens in mapping.items():
        text = text.replace(concept_token, dummy_tokens)
    return text
# def inference(
#     prompt: str, num_inference_steps: int = 50, guidance_scale: int = 3.0
# ):
#     prompt = replace_concept_tokens(prompt)
#     for _ in range(3):
#         img_list = pipeline(
#             prompt=prompt,
#             num_inference_steps=num_inference_steps,
#             guidance_scale=guidance_scale,
#         )
#         if not img_list["nsfw_content_detected"]:
#             break
#     return img_list["sample"]
# Example prompt shown as the textbox placeholder; "<det-logo>" is the learned
# concept token substituted by `replace_concept_tokens`.
DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
    " pastel colors, incredible composition, masterpiece"
)
def white_imgs(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
    """Placeholder generator: return two all-white 512x512 RGB images.

    Stands in for the disabled Stable Diffusion pipeline so the UI stays
    interactive. All arguments are accepted only to match the Gradio
    ``inputs`` wiring and are intentionally ignored.

    Returns:
        A list of two ``numpy.ndarray`` of shape (512, 512, 3) filled with 1.0.
    """
    return [torch.ones(512, 512, 3).numpy() for _ in range(2)]
# Gradio UI: prompt + sampler controls feeding the placeholder image generator.
with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
    )
    # BUG FIX: gr.Button takes its caption as `value` (the first positional
    # argument), not `label` — the original `gr.Button(label="Generate")`
    # rendered with the default caption instead of "Generate".
    generate_btn = gr.Button("Generate")
    # Seed the gallery with two black placeholder images.
    gallery = gr.Gallery(
        label="Generated Images",
        value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
    ).style(height="auto")  # legacy gradio 3.x styling API, kept as-is
    generate_btn.click(
        white_imgs,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()