Spaces:
Runtime error
Runtime error
File size: 2,774 Bytes
477daa4 0093752 477daa4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import gradio as gr
from utils import *
import random
# Shared module-level state for the UI callbacks.
# is_clicked: set to True by func_generate while it runs, False otherwise.
is_clicked = False
# out_img_list: five gallery slots, '' meaning "no image yet"; fn_refresh
# returns this list to the gallery.
out_img_list = [''] * 5
# out_state_list: five per-slot flags, all initially False.
# NOTE(review): not referenced elsewhere in the visible code — confirm it is
# still needed.
out_state_list = [False] * 5
def fn_query_on_load() -> str:
    """Return the default query text.

    Passed as the ``value`` callable of the search ``Textbox``, so the box
    starts out pre-filled with this prompt.
    """
    return "Cats at sunset"
def fn_refresh() -> list:
    """Return the images the gallery should currently display.

    Passed as the ``value`` callable of the output ``Gallery``. Slots of
    ``out_img_list`` that still hold the ``''`` placeholder are skipped so
    the gallery is never handed an empty string as an image path. A new
    list is returned each time, so the shared module-level list itself is
    not exposed to the caller.
    """
    return [
        img for img in out_img_list
        if not (isinstance(img, str) and img == '')
    ]
# Build the Gradio UI. Event handlers are registered inside the Blocks
# context so they are attached to this app.
# NOTE(review): the nesting below was reconstructed from a paste that had
# lost its indentation — confirm the layout matches the intended design.
with gr.Blocks() as app:
    # Page header.
    with gr.Row():
        gr.Markdown(
            "\n"
            "# Stable Diffusion Image Generation\n"
            "### Enter query to generate images in various styles\n"
        )

    # Query box with its Submit / Clear buttons.
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()

    # Output gallery; its value comes from fn_refresh, with every=5.
    with gr.Row(visible=True):
        output_images = gr.Gallery(value=fn_refresh, interactive=False, every=5)

    def clear_data():
        """Return updates that reset the gallery and the search box to None."""
        return {
            output_images: None,
            search_text: None,
        }

    clear_btn.click(clear_data, None, [output_images, search_text])

    def func_generate(query):
        """Run one generation per style concept for *query*.

        Appends ' in the style of bulb' to the query, tokenizes it, and for
        each of the five entries of ``concept_embeds`` overwrites the
        embedding at the positions whose token id is 22373 before calling
        ``generate_with_embs`` with a freshly seeded generator.

        ``tokenizer``, ``text_encoder``, ``torch``, ``torch_device``,
        ``token_emb_layer``, ``pos_emb_layer``, ``concept_embeds``,
        ``get_output_embeds`` and ``generate_with_embs`` are not defined in
        this file; presumably they come from ``from utils import *``.

        Args:
            query: Text from the search box. ``None`` (the value written by
                ``clear_data``) is treated as an empty string.

        Returns:
            None. The click handler is registered with no output components.
        """
        global is_clicked
        is_clicked = True
        try:
            # `query or ''` avoids a TypeError on None + str.
            prompt = (query or '') + ' in the style of bulb'
            text_input = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = text_input.input_ids.to(torch_device)

            # Position embeddings are the same for every style.
            position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
            position_embeddings = pos_emb_layer(position_ids)

            # Positions of the placeholder token; input_ids never changes in
            # the loop, so this is looked up once.
            # NOTE(review): 22373 is presumably the tokenizer id of "bulb" —
            # confirm against the tokenizer vocabulary.
            placeholder_positions = torch.where(input_ids[0] == 22373)

            s = 0
            for i in range(5):
                # Recomputed every round because the tensor is overwritten
                # in place just below.
                token_embeddings = token_emb_layer(input_ids)

                # Swap the placeholder embedding for this style's concept.
                replacement_token_embedding = concept_embeds[i].to(torch_device)
                token_embeddings[0, placeholder_positions] = replacement_token_embedding

                # Combine with position embeddings and run them through
                # get_output_embeds to obtain the conditioning embeddings.
                input_embeddings = token_embeddings + position_embeddings
                modified_output_embeddings = get_output_embeds(input_embeddings)

                # Each style gets a strictly larger random seed than the last.
                s = random.randint(s + 1, s + 30)
                g = torch.manual_seed(s)

                # NOTE(review): the return value is discarded and
                # out_img_list[i] is passed by value, so nothing in this
                # function writes the generated image back into
                # out_img_list. Confirm how generate_with_embs publishes its
                # result; if it returns the image, it should be stored in
                # out_img_list[i].
                generate_with_embs(text_input, modified_output_embeddings, output=out_img_list[i], generator=g)
        finally:
            # Reset the flag even if generation raised.
            is_clicked = False
        return None

    submit_btn.click(
        func_generate,
        [search_text],
        None,
    )
# Launch the app.
# `queue` is a method: it has to be called before chaining `.launch()`.
# `app.queue.launch` looks up `launch` on the bound method itself and raises
# AttributeError.
app.queue().launch(share=True)
|