import random

import gradio as gr
import torch

# utils is expected to provide the model pieces used below: tokenizer,
# text_encoder, torch_device, token_emb_layer, pos_emb_layer, concept_embeds,
# get_output_embeds and generate_with_embs
from utils import *


# Shared state: a simple busy flag (set while generation runs) and the image
# slots that the gallery polls via fn_refresh below
is_clicked = False
out_img_list = ['', '', '', '', '']
out_state_list = [False, False, False, False, False]
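
# concept_embeds (imported from utils) is assumed to hold five learned
# textual-inversion style embeddings, e.g. loaded roughly along these lines:
#   learned = torch.load('learned_embeds.bin')   # {'<style-token>': tensor}
#   concept_embeds = list(learned.values())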

def fn_query_on_load():
    # Default query shown when the app loads
    return "Cats at sunset"

def fn_refresh():
    # Polled by the gallery (every=5) to pick up newly generated images
    return out_img_list
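

# ---------------------------------------------------------------------------
# For reference: a minimal sketch of what utils.generate_with_embs might look
# like, assuming utils also exposes the usual diffusers pieces (vae, unet,
# scheduler). It is the standard latent-diffusion sampling loop with
# classifier-free guidance; the real implementation in utils.py may differ,
# so every name here is an assumption, not the actual API. It is given a
# distinct name so it never shadows the real function.
# ---------------------------------------------------------------------------
from PIL import Image

def generate_with_embs_sketch(text_input, text_embeddings, generator,
                              height=512, width=512,
                              num_inference_steps=30, guidance_scale=7.5):
    # Unconditional embeddings for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""], padding="max_length", max_length=max_length,
                             return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Seeded random latents at 1/8 the target resolution
    latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8),
                          generator=generator).to(torch_device)
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop: predict noise for uncond/cond, guide, then step
    for t in scheduler.timesteps:
        latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeddings).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Decode the final latents to a PIL image
    with torch.no_grad():
        image = vae.decode(latents / vae.config.scaling_factor).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    array = (image[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype("uint8")
    return Image.fromarray(array)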


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter query to generate images in various styles
            """)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)

            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()

    with gr.Row(visible=True):
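        # The gallery polls fn_refresh every 5 seconds (requires the queue
        # enabled at launch) so images appear as generation finishes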
        output_images = gr.Gallery(value=fn_refresh, interactive=False, every=5)


    def clear_data():
        # Returning a dict keyed by the output components clears both at once
        return {
            output_images: None,
            search_text: None
        }


    clear_btn.click(clear_data, None, [output_images, search_text])


    def func_generate(query):
        global is_clicked
        is_clicked = True
        prompt = query + ' in the style of bulb'
        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")
        input_ids = text_input.input_ids.to(torch_device)

        # Position embeddings for the full 77-token CLIP sequence
        position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
        position_embeddings = pos_emb_layer(position_ids)

        s = 0
        for i in range(5):
            token_embeddings = token_emb_layer(input_ids)
            # The learned concept embedding for style i
            replacement_token_embedding = concept_embeds[i].to(torch_device)

            # Swap it in at the position of the placeholder token
            # (22373 is assumed to be the tokenizer's id for 'bulb')
            token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding

            # Combine with position embeddings
            input_embeddings = token_embeddings + position_embeddings

            # Feed through the rest of the text encoder to get the final output embeddings
            modified_output_embeddings = get_output_embeds(input_embeddings)

            # Advance the seed so each style gets a different but reproducible start
            s = random.randint(s + 1, s + 30)
            g = torch.manual_seed(s)
            # Store the result so fn_refresh surfaces it in the gallery
            # (assumes generate_with_embs returns the generated image)
            out_img_list[i] = generate_with_embs(text_input, modified_output_embeddings, generator=g)

        is_clicked = False

        return None


    submit_btn.click(
        func_generate,
        [search_text],
        None
    )


# Launch the app; queue() must be called for the gallery's `every` polling to work
app.queue().launch(share=True)