import gradio as gr
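# NOTE: `utils` is assumed to provide the Stable Diffusion helpers referenced below:
# tokenizer, text_encoder, token_emb_layer, pos_emb_layer, concept_embeds,
# get_output_embeds, generate_with_embs and torch_device.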
from utils import *
import random
import torch  # used directly below for torch.where and torch.manual_seed
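
# Module-level placeholders for the five style outputs; only out_img_list is
# read back below (by fn_refresh).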
is_clicked = False
out_img_list = ['', '', '', '', '']
out_state_list = [False, False, False, False, False]


def fn_query_on_load():
    # Default query shown in the search box when the app loads.
    return "Cats at sunset"


def fn_refresh():
    # Return the current set of output images.
    return out_img_list
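

# Build the Gradio UI: a single query textbox, five style-specific image outputs
# each with its own submit button, and a clear button.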
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter a query to generate images in various styles
            """)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
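
    # One column per style: an image output plus its own submit button.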
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                out1 = gr.Image(interactive=False, label='Oil Painting')
            submit1 = gr.Button("submit", variant='primary')
        with gr.Column():
            with gr.Row():
                out2 = gr.Image(interactive=False, label='Low Poly HD Style')
            submit2 = gr.Button("submit", variant='primary')
        with gr.Column():
            with gr.Row():
                out3 = gr.Image(interactive=False, label='Matrix style')
            submit3 = gr.Button("submit", variant='primary')
        with gr.Column():
            with gr.Row():
                out4 = gr.Image(interactive=False, label='Dreamy Painting')
            submit4 = gr.Button("submit", variant='primary')
        with gr.Column():
            with gr.Row():
                out5 = gr.Image(interactive=False, label='Depth Map Style')
            submit5 = gr.Button("submit", variant='primary')
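
    # Clear button resets all five outputs and the query box via clear_data below.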
    with gr.Row(visible=True):
        clear_btn = gr.ClearButton()

    def clear_data():
        return {
            out1: None,
            out2: None,
            out3: None,
            out4: None,
            out5: None,
            search_text: None
        }

    clear_btn.click(clear_data, None, [out1, out2, out3, out4, out5, search_text])
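
    # func_generate swaps the learned concept embedding for the selected style into
    # the prompt's placeholder token before running Stable Diffusion with the
    # modified text embeddings. The helpers it calls are assumed to come from utils
    # (see the import note at the top of the file).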
    def func_generate(query, concept_idx, seed):
        prompt = query + ' in the style of bulb'
        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")
        input_ids = text_input.input_ids.to(torch_device)
        # Get position embeddings
        position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
        position_embeddings = pos_emb_layer(position_ids)
        s = seed
        # Get token embeddings
        token_embeddings = token_emb_layer(input_ids)
        # The new embedding - the learned concept for the selected style
        replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
        # Insert it into the token embeddings at the position of the placeholder token (hard-coded id 22373)
        token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
        # Combine with the position embeddings
        input_embeddings = token_embeddings + position_embeddings
        # Feed through the rest of the text encoder to get the final output embeddings
        modified_output_embeddings = get_output_embeds(input_embeddings)
        # And generate an image with them, using a randomised seed offset per concept
        s = random.randint(s + 1, s + 30)
        g = torch.manual_seed(s)
        return generate_with_embs(text_input, modified_output_embeddings, generator=g)
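
    # Gradio passes component values as callback arguments, so the concept index and
    # base seed for each style are bound with lambdas instead of being listed as
    # inputs (raw integers are not valid input components).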
    submit1.click(lambda q: func_generate(q, 0, 0), search_text, out1)
    submit2.click(lambda q: func_generate(q, 1, 30), search_text, out2)
    submit3.click(lambda q: func_generate(q, 2, 60), search_text, out3)
    submit4.click(lambda q: func_generate(q, 3, 90), search_text, out4)
    submit5.click(lambda q: func_generate(q, 4, 120), search_text, out5)

# Launch the app
app.launch()