import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torchvision.transforms.functional import to_pil_image

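# NOTE: the real Stable Diffusion pipeline is left disabled below; until it is
# re-enabled, the demo serves placeholder images from `white_imgs` instead.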
# pipeline = StableDiffusionPipeline.from_pretrained(
#     pretrained_model_name_or_path="weights", torch_dtype=torch.float16
# )

# pipeline.to('cuda')

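# Maps each concept token (e.g. "<det-logo>") to the dummy tokens the model
# expects in its place, so prompts can be rewritten before inference.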
concept_to_dummy_tokens_map = torch.load("concept_to_dummy_tokens_map.pt")


def replace_concept_tokens(text: str):
    for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
        text = text.replace(concept_token, dummy_tokens)
    return text


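# Real inference path (currently disabled): rewrites concept tokens in the
# prompt, then retries up to three times if the safety checker flags the
# output as NSFW.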
# def inference(
#     prompt: str, num_inference_steps: int = 50, guidance_scale: float = 3.0
# ):
#     prompt = replace_concept_tokens(prompt)
#     for _ in range(3):
#         img_list = pipeline(
#             prompt=prompt,
#             num_inference_steps=num_inference_steps,
#             guidance_scale=guidance_scale,
#         )
#         if not any(img_list["nsfw_content_detected"]):
#             break
#     return img_list["sample"]

DEFAULT_PROMPT = (
    "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
    " pastel colors, incredible composition, masterpiece"
)


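# Placeholder handler for the "Generate" button: ignores its inputs and returns
# two blank white 512x512 images so the UI can be exercised without the model.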
def white_imgs(
    prompt: str, guidance_scale: float, num_inference_steps: int, seed: int
):
    return [torch.ones(512, 512, 3).numpy() for _ in range(2)]


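# Gradio UI: prompt textbox, guidance/steps/seed sliders, a generate button,
# and a gallery that displays the (placeholder) results.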
with gr.Blocks() as demo:
    prompt = gr.Textbox(
        label="Prompt including the token '<det-logo>'",
        placeholder=DEFAULT_PROMPT,
        interactive=True,
    )
    guidance_scale = gr.Slider(
        minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
    )
    num_inference_steps = gr.Slider(
        minimum=25,
        maximum=60,
        value=40,
        label="Num Inference Steps",
        interactive=True,
        step=1,
    )
    seed = gr.Slider(
        minimum=2147483147,
        maximum=2147483647,
        value=2147483397,
        label="Seed",
        interactive=True,
    )
    generate_btn = gr.Button("Generate")
    gallery = gr.Gallery(
        label="Generated Images",
        value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
    ).style(height="auto")

    generate_btn.click(
        white_imgs,
        inputs=[prompt, guidance_scale, num_inference_steps, seed],
        outputs=gallery,
    )

demo.launch()