File size: 3,815 Bytes
477daa4
 
 
 
 
 
0093752
477daa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d13c68f
 
 
 
 
 
 
 
 
477daa4
d13c68f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477daa4
 
 
d13c68f
 
 
 
 
477daa4
 
 
 
d13c68f
477daa4
 
d13c68f
477daa4
 
 
 
 
 
 
 
 
d13c68f
 
 
 
 
477daa4
d13c68f
 
477daa4
d13c68f
 
477daa4
d13c68f
 
477daa4
d13c68f
477daa4
d13c68f
 
 
 
 
 
 
 
 
 
477daa4
d13c68f
 
 
 
 
477daa4
d13c68f
 
 
 
 
477daa4
d13c68f
 
 
 
 
477daa4
d13c68f
477daa4
d13c68f
 
477daa4
 
 
 
 
 
d13c68f
477daa4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from utils import *
import random


# Module-level state. Nothing visible in this file modifies these values;
# fn_refresh simply returns out_img_list as-is.
is_clicked = False
out_img_list = [""] * 5  # presumably one slot per style column -- confirm
out_state_list = [False] * 5

def fn_query_on_load():
    """Return the default query text placed in the search box on page load."""
    default_query = "Cats at sunset"
    return default_query

def fn_refresh():
    """Return the module-level ``out_img_list`` object itself (not a copy).

    Nothing visible in this file calls this function or updates the list.
    """
    current_images = out_img_list
    return current_images


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter query to generate images in various styles
            """)

    # Query input; its initial value comes from fn_query_on_load.
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)

    # One column per style. Column k is wired (below) to concept_embeds[k - 1].
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                out1 = gr.Image(interactive=False, label='Oil Painting')
                submit1 = gr.Button("submit", variant='primary')

        with gr.Column():
            with gr.Row():
                out2 = gr.Image(interactive=False, label='Low Poly HD Style')
                submit2 = gr.Button("submit", variant='primary')

        with gr.Column():
            with gr.Row():
                out3 = gr.Image(interactive=False, label='Matrix style')
                submit3 = gr.Button("submit", variant='primary')

        with gr.Column():
            with gr.Row():
                out4 = gr.Image(interactive=False, label='Dreamy Painting')
                submit4 = gr.Button("submit", variant='primary')

        with gr.Column():
            with gr.Row():
                out5 = gr.Image(interactive=False, label='Depth Map Style')
                submit5 = gr.Button("submit", variant='primary')

    with gr.Row(visible=True):
        clear_btn = gr.ClearButton()

    def clear_data():
        """Reset all five output images and the query box to empty (None)."""
        return {
            out1: None,
            out2: None,
            out3: None,
            out4: None,
            out5: None,
            search_text: None
        }

    clear_btn.click(clear_data, None, [out1, out2, out3, out4, out5, search_text])

    def func_generate(query, concept_idx, seed):
        """Generate one image for ``query`` in the style of a learned concept.

        The prompt is ``query`` followed by " in the style of bulb". The token
        embedding at the placeholder position is overwritten with
        ``concept_embeds[concept_idx]``. Position embeddings are added, and the
        result goes through ``get_output_embeds`` and then
        ``generate_with_embs`` (both helpers come from ``utils``).

        Args:
            query: Text from the search box. ``None`` (the value after Clear)
                is treated as an empty query instead of raising TypeError.
            concept_idx: Index into ``concept_embeds`` that selects the style.
            seed: Base seed. The seed actually used is drawn uniformly from
                ``[seed + 1, seed + 30]``, so repeated clicks give new images.

        Returns:
            Whatever ``generate_with_embs`` returns; it is shown in a gr.Image.
        """
        # NOTE(review): presumably the CLIP tokenizer id of "bulb", the
        # placeholder word appended to the prompt below. Confirm against the
        # tokenizer vocab if the tokenizer or the placeholder word changes.
        placeholder_token_id = 22373

        prompt = (query or '') + ' in the style of bulb'
        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")
        input_ids = text_input.input_ids.to(torch_device)

        # Position embeddings for every slot of the padded sequence. The
        # length comes from the tokenized input rather than a hard-coded 77.
        seq_len = input_ids.shape[1]
        position_ids = text_encoder.text_model.embeddings.position_ids[:, :seq_len]
        position_embeddings = pos_emb_layer(position_ids)

        # Token embeddings, with the placeholder slot(s) overwritten by the
        # learned concept embedding for the selected style.
        token_embeddings = token_emb_layer(input_ids)
        replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
        token_embeddings[0, torch.where(input_ids[0] == placeholder_token_id)] = replacement_token_embedding

        input_embeddings = token_embeddings + position_embeddings
        modified_output_embeddings = get_output_embeds(input_embeddings)

        # Draw from a 30-wide window above the base seed. With the base seeds
        # wired below (0, 30, ..., 120), different columns never share a seed.
        generation_seed = random.randint(seed + 1, seed + 30)
        generator = torch.manual_seed(generation_seed)
        return generate_with_embs(text_input, modified_output_embeddings, generator=generator)

    def _bind_style(concept_idx, seed):
        """Return a one-input click handler with the style index and base seed fixed."""
        def handler(query):
            return func_generate(query, concept_idx, seed)
        return handler

    # Gradio event inputs must be components, so plain ints cannot go in the
    # inputs list. Each column's concept index and base seed are bound in a
    # closure instead, and only the query textbox is passed as input.
    submit1.click(_bind_style(0, 0), inputs=search_text, outputs=out1)
    submit2.click(_bind_style(1, 30), inputs=search_text, outputs=out2)
    submit3.click(_bind_style(2, 60), inputs=search_text, outputs=out3)
    submit4.click(_bind_style(3, 90), inputs=search_text, outputs=out4)
    submit5.click(_bind_style(4, 120), inputs=search_text, outputs=out5)


# Start the Gradio server for the Blocks app defined above.
app.launch()