Spaces:
Runtime error
Runtime error
File size: 2,324 Bytes
4cd9bad 384ec64 4cd9bad 384ec64 4cd9bad 384ec64 4cd9bad 384ec64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torchvision.transforms.functional import to_pil_image
# pipeline = StableDiffusionPipeline.from_pretrained(
# pretrained_model_name_or_path="weights", torch_dtype=torch.float16
# )
# pipeline.to('cuda')
# Loaded once, at import time. Keys are concept tokens that may appear in a
# user prompt; values are the strings substituted for them by
# replace_concept_tokens (both are used with str.replace, so both are str).
concept_to_dummy_tokens_map = torch.load(f="concept_to_dummy_tokens_map.pt")
def replace_concept_tokens(text: str):
    """Return *text* with every known concept token expanded.

    Each key of the module-level ``concept_to_dummy_tokens_map`` is replaced,
    in all of its occurrences, by the mapped value. Keys are processed one at
    a time in the mapping's iteration order, so a later key operates on the
    output of the earlier substitutions.
    """
    expanded = text
    for concept, dummies in concept_to_dummy_tokens_map.items():
        expanded = expanded.replace(concept, dummies)
    return expanded
# def inference(
# prompt: str, num_inference_steps: int = 50, guidance_scale: int = 3.0
# ):
# prompt = replace_concept_tokens(prompt)
# for _ in range(3):
# img_list = pipeline(
# prompt=prompt,
# num_inference_steps=num_inference_steps,
# guidance_scale=guidance_scale,
# )
# if not img_list["nsfw_content_detected"]:
# break
# return img_list["sample"]
# Example prompt shown in the prompt textbox; it uses the "<det-logo>"
# concept token. Built from its comma-separated phrases.
DEFAULT_PROMPT = ", ".join(
    (
        "A watercolor painting on textured paper of a <det-logo> using soft strokes",
        "pastel colors",
        "incredible composition",
        "masterpiece",
    )
)
def white_imgs(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
    """Stand-in image generator: ignore all inputs and return two white images.

    Each image is a separate float32 array of shape (512, 512, 3) filled with
    1.0. The parameters mirror the UI inputs wired to this function but are
    not used.
    """
    images = []
    for _ in range(2):
        # Allocate a fresh tensor each time so the two arrays are independent.
        images.append(torch.ones(size=(512, 512, 3)).numpy())
    return images
# Build and launch the demo UI. The generate button is currently wired to
# ``white_imgs`` (placeholder output), not to a diffusion pipeline.
with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    # NOTE(review): this range covers only the 500 values just below
    # 2**31 - 1 (2147483647); confirm the narrow window is intentional.
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
        # Integer steps, matching the ``seed: int`` handler parameter.
        step=1,
    )
    # A button's visible caption is its ``value``; ``label=`` does not set it
    # and left Gradio's default caption in place.
    generate_btn = gr.Button(value="Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        # Initial content: two all-zero (black) 512x512x3 placeholder images.
        value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
    ).style(height="auto")
    generate_btn.click(
        white_imgs,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()
|