File size: 5,638 Bytes
4fb3c5e
 
 
 
 
 
 
 
 
 
 
70d5056
4fb3c5e
 
 
ea424ac
4fb3c5e
ea424ac
4fb3c5e
 
 
 
 
7d37aeb
 
ea424ac
 
7d37aeb
 
ea424ac
4fb3c5e
 
ea424ac
 
 
 
7d37aeb
 
ea424ac
 
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
1489344
4fb3c5e
 
61d3740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61d3740
4fb3c5e
61d3740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fb3c5e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python

from __future__ import annotations

import argparse

import gradio as gr

from model import Model

# Markdown heading rendered at the top of the app.
TITLE = ('# Anime Face Generation with '
         '[Diffusers](https://github.com/huggingface/diffusers)')
# Execution-time note displayed at the top of the "Advanced Mode" tab.
DESCRIPTION = ('Expected execution time on Hugging Face Spaces: '
               '5s (DDIM, 20 steps), '
               '6s (PNDM, 20 steps), '
               '247s (DDPM, 1000 steps)')
# Raw HTML visitor-counter badge rendered below the tabs.
FOOTER = ('<img id="visitor-badge" '
          'src="https://visitor-badge.glitch.me/badge?page_id=hysts.diffusers-anime-faces" '
          'alt="visitor badge" />')


def get_sample_image_url(file_name: str) -> str:
    """Return the absolute URL of a pre-rendered sample image.

    Args:
        file_name: Name of the image file inside the Space's ``samples``
            directory.

    Returns:
        The ``resolve/main`` URL of that file on the Hugging Face Hub.
    """
    base_url = ('https://huggingface.co/spaces/hysts/diffusers-anime-faces'
                '/resolve/main/samples')
    return '{}/{}'.format(base_url, file_name)


def get_sample_image_markdown(name: str) -> str:
    """Build the Markdown shown in the "Sample Images" tab for *name*.

    Args:
        name: A dropdown entry of the form ``'<model-name> (<details>)'``.
            Only the entries listed in ``sample_configs`` below are known.

    Returns:
        A Markdown snippet listing the image size, seed range, scheduler and
        step count, followed by the sample-grid image.

    Raises:
        ValueError: If *name* is not one of the known entries, including the
            empty string.
    """
    # entry -> (scheduler label, step count, suffix appended to the file stem)
    sample_configs = {
        'ddpm-128-exp000 (DDPM)': ('DDPM', 1000, ''),
        'ddpm-128-exp000 (DDIM, 20 steps)': ('DDIM', 20, '-ddim-20steps'),
    }
    # Validate before splitting: ''.split()[0] would raise IndexError instead
    # of the ValueError callers are meant to see.
    if name not in sample_configs:
        raise ValueError(f'Unknown sample image set: {name!r}')
    scheduler, steps, suffix = sample_configs[name]
    model_name = name.split()[0]
    url = get_sample_image_url(f'{model_name}{suffix}.png')
    # NOTE(review): the lines below keep their leading indentation; presumably
    # gr.Markdown dedents its value before rendering — confirm for the pinned
    # Gradio version, otherwise this would render as an indented code block.
    text = f'''
            - size: 128x128
            - seed: 0-99
            - scheduler: {scheduler}
            - steps: {steps}

            ![sample images]({url})'''
    return text


def update_scheduler_type(name: str) -> dict:
    """Return a Slider update that fits the step slider to scheduler *name*.

    The slider is hidden for ``'DDPM'`` and shown otherwise. Its range is
    4-100 for ``'PNDM'`` and 1-200 for every other scheduler, and its value
    is always reset to 20.

    Args:
        name: The scheduler selected in the radio button.

    Returns:
        The update dict produced by ``gr.Slider.update``.
    """
    low, high = (4, 100) if name == 'PNDM' else (1, 200)
    hide_slider = name == 'DDPM'
    return gr.Slider.update(value=20,
                            visible=not hide_slider,
                            minimum=low,
                            maximum=high)


def create_simple_demo(model: Model) -> gr.Blocks:
    """Build the "Simple Mode" tab: a Generate button and an output image.

    Args:
        model: Provides ``run_simple``. It is called with no inputs, and its
            return value is displayed in the image component.

    Returns:
        The constructed ``gr.Blocks``.
    """
    with gr.Blocks() as simple_tab:
        generate_btn = gr.Button('Generate')
        grid_image = gr.Image(show_label=False, elem_id='result-grid')
        generate_btn.click(fn=model.run_simple,
                           inputs=None,
                           outputs=grid_image)
    return simple_tab


def create_advanced_demo(model: Model) -> gr.Blocks:
    """Build the "Advanced Mode" tab with scheduler, step and seed controls.

    Layout:
        * Left column: the non-interactive model dropdown, the scheduler
          radio, the step and seed sliders, and a Run button.
        * Right column: the output image.

    Events:
        * Changing the model or the scheduler calls ``model.set_pipeline``
          with both current values.
        * Changing the scheduler also refits the step slider via
          ``update_scheduler_type``, with ``queue=False``.
        * Run calls ``model.run(model_name, scheduler, steps, seed)`` and
          shows its result.

    Args:
        model: Supplies ``MODEL_NAMES``, ``set_pipeline`` and ``run``.

    Returns:
        The constructed ``gr.Blocks``.
    """
    with gr.Blocks() as advanced_tab:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                with gr.Group():
                    # interactive=False: the model is shown but cannot be
                    # switched from the UI.
                    model_dropdown = gr.Dropdown(choices=model.MODEL_NAMES,
                                                 value=model.MODEL_NAMES[0],
                                                 label='Model',
                                                 interactive=False)
                    scheduler_radio = gr.Radio(
                        choices=['DDPM', 'DDIM', 'PNDM'],
                        value='DDIM',
                        label='Scheduler')
                    # Initial range/value match what update_scheduler_type
                    # returns for the default 'DDIM' scheduler.
                    steps_slider = gr.Slider(minimum=1,
                                             maximum=200,
                                             step=1,
                                             value=20,
                                             label='Number of Steps')
                    seed_slider = gr.Slider(minimum=0,
                                            maximum=100000,
                                            step=1,
                                            value=1234,
                                            label='Seed')
                    run_btn = gr.Button('Run')
            with gr.Column():
                output_image = gr.Image(show_label=False, elem_id='result')

        model_dropdown.change(fn=model.set_pipeline,
                              inputs=[model_dropdown, scheduler_radio],
                              outputs=None)
        # Two listeners on the scheduler: first refit the slider (unqueued),
        # then rebuild the pipeline.
        scheduler_radio.change(fn=update_scheduler_type,
                               inputs=scheduler_radio,
                               outputs=steps_slider,
                               queue=False)
        scheduler_radio.change(fn=model.set_pipeline,
                               inputs=[model_dropdown, scheduler_radio],
                               outputs=None)
        run_btn.click(fn=model.run,
                      inputs=[
                          model_dropdown,
                          scheduler_radio,
                          steps_slider,
                          seed_slider,
                      ],
                      outputs=output_image)
    return advanced_tab


def create_sample_image_view_demo() -> gr.Blocks:
    """Build the "Sample Images" tab, which previews pre-rendered grids.

    A dropdown selects a sample set, and a Markdown component below it shows
    the output of ``get_sample_image_markdown`` for the current selection.

    Returns:
        The constructed ``gr.Blocks``.
    """
    sample_sets = [
        'ddpm-128-exp000 (DDPM)',
        'ddpm-128-exp000 (DDIM, 20 steps)',
    ]
    default_set = sample_sets[0]
    with gr.Blocks() as sample_tab:
        with gr.Row():
            selector = gr.Dropdown(sample_sets,
                                   value=default_set,
                                   label='Model')
        with gr.Row():
            # Render the default selection up front so the tab is populated
            # before any change event fires.
            preview = gr.Markdown(get_sample_image_markdown(selector.value))

        selector.change(fn=get_sample_image_markdown,
                        inputs=selector,
                        outputs=preview)
    return sample_tab


def main():
    """Parse ``--device``, assemble the three tabs and launch the app."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', type=str, default='cpu')
    device = cli.parse_args().device
    model = Model(device)

    # (tab label, builder) in display order.
    tab_builders = (
        ('Simple Mode', lambda: create_simple_demo(model)),
        ('Advanced Mode', lambda: create_advanced_demo(model)),
        ('Sample Images', create_sample_image_view_demo),
    )

    with gr.Blocks(css='style.css') as app:
        gr.Markdown(TITLE)
        with gr.Tabs():
            for label, build_tab in tab_builders:
                with gr.TabItem(label):
                    build_tab()
        gr.Markdown(FOOTER)
    app.launch(enable_queue=True, share=False)


if __name__ == '__main__':
    main()