File size: 4,300 Bytes
223c6f1
 
 
 
8be0786
 
223c6f1
 
effa659
 
 
 
 
 
 
89b3a2c
 
 
223c6f1
89b3a2c
223c6f1
 
 
 
 
 
 
89b3a2c
 
 
223c6f1
 
 
8be0786
223c6f1
 
 
 
 
 
 
ef42685
223c6f1
 
 
 
 
 
 
 
 
 
ef42685
223c6f1
 
 
 
 
 
 
 
89b3a2c
 
223c6f1
 
 
 
 
 
 
 
 
8be0786
223c6f1
 
 
 
 
 
 
 
8be0786
89b3a2c
 
 
223c6f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8be0786
89b3a2c
8be0786
 
223c6f1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python

from __future__ import annotations

import gradio as gr

from model import AppModel

# Markdown shown at the top of the page (first component rendered in main()):
# a title linking to the upstream CogView2 repository plus guidance on which
# input language to use.
DESCRIPTION = '''# <a href="https://github.com/THUDM/CogView2">CogView2</a> (text2image)

This application accepts English or Chinese as input.
In general, Chinese input produces better results than English input.
But the translation model used in this app may mistranslate and the results could be poor.
So, it is also a good idea to input the translation results from other translation services.
'''
# Markdown shown below the input/output columns: attribution for the demo this
# app was adapted from and a link to the Space used for translation.
NOTES = '''
- This app is adapted from <a href="https://github.com/hysts/CogView2_demo">https://github.com/hysts/CogView2_demo</a>. It would be recommended to use the repo if you want to run the app yourself.
- [This Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) is used for translation from English to Chinese.
'''
# Raw HTML <img> tag for a visitor badge; passed to gr.Markdown as the last
# component on the page.
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=THUDM.CogView2" />'


def set_example_text(example: list) -> dict:
    """Build a Textbox update that fills in the text of a clicked example.

    Args:
        example: The selected example row; only its first element is used.

    Returns:
        Whatever ``gr.Textbox.update`` returns for ``value=example[0]``.
    """
    selected_text = example[0]
    return gr.Textbox.update(value=selected_text)


def _load_samples(path: str = 'samples.txt') -> list[list[str]]:
    """Read example prompts from *path* as one single-element row per line.

    Args:
        path: Text file holding one example prompt per line.

    Returns:
        A list of ``[prompt]`` rows, in file order, with surrounding
        whitespace stripped and blank lines skipped.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # The app accepts Chinese input, so the prompts may be non-ASCII; decode
    # explicitly as UTF-8 instead of relying on the host's locale default.
    with open(path, encoding='utf-8') as f:
        prompts = (line.strip() for line in f)
        # Blank lines would otherwise show up as empty example rows.
        return [[prompt] for prompt in prompts if prompt]


def main():
    """Build the CogView2 Gradio demo and launch it.

    Creates the ``AppModel`` (max inference batch size 4, first stage only),
    lays out the input controls in the left column and the translated text
    plus image outputs in the right column, wires the "Run" button to
    ``model.run_with_translation`` and the example dataset to
    ``set_example_text``, then launches the app with queueing enabled.
    """
    # Kept under a distinct name so it is not confused with the
    # `only_first_stage` checkbox component created below.
    first_stage_only = True
    max_inference_batch_size = 4
    model = AppModel(max_inference_batch_size, first_stage_only)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            # Left column: all user inputs.
            with gr.Column():
                with gr.Group():
                    text = gr.Textbox(label='Input Text')
                    translate = gr.Checkbox(label='Translate to Chinese',
                                            value=False)
                    style = gr.Dropdown(
                        choices=[
                            'none',
                            'mainbody',
                            'photo',
                            'flat',
                            'comics',
                            'oil',
                            'sketch',
                            'isometric',
                            'chinese',
                            'watercolor',
                        ],
                        value='mainbody',
                        label='Style')
                    seed = gr.Slider(0,
                                     100000,
                                     step=1,
                                     value=1234,
                                     label='Seed')
                    # The checkbox is hidden when the model was built for the
                    # first stage only; its value is still sent with each run.
                    only_first_stage = gr.Checkbox(
                        label='Only First Stage',
                        value=first_stage_only,
                        visible=not first_stage_only)
                    num_images = gr.Slider(1,
                                           16,
                                           step=1,
                                           value=8,
                                           label='Number of Images')
                    samples = _load_samples('samples.txt')
                    examples = gr.Dataset(components=[text], samples=samples)
                    run_button = gr.Button('Run')

            # Right column: translated prompt and generated images.
            with gr.Column():
                with gr.Group():
                    translated_text = gr.Textbox(label='Translated Text')
                    with gr.Tabs():
                        with gr.TabItem('Output (Grid View)'):
                            result_grid = gr.Image(show_label=False)
                        with gr.TabItem('Output (Gallery)'):
                            result_gallery = gr.Gallery(show_label=False)

        gr.Markdown(NOTES)
        gr.Markdown(FOOTER)

        run_button.click(fn=model.run_with_translation,
                         inputs=[
                             text,
                             translate,
                             style,
                             seed,
                             only_first_stage,
                             num_images,
                         ],
                         outputs=[
                             translated_text,
                             result_grid,
                             result_gallery,
                         ])
        # Clicking an example row copies its text into the input textbox.
        examples.click(fn=set_example_text,
                       inputs=examples,
                       outputs=examples.components)

    demo.launch(enable_queue=True)


# Build and launch the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()