Spaces:
Runtime error
Runtime error
Commit
Β·
266993c
1
Parent(s):
433ae3e
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from utils import
|
3 |
import random
|
4 |
|
5 |
is_clicked = False
|
@@ -89,43 +89,6 @@ with gr.Blocks() as app:
|
|
89 |
clear_btn2.click(clear_data2, None, [out11, out12, out13, out14, out15])
|
90 |
|
91 |
|
92 |
-
def func_generate(query, concept_idx, seed_start, contrast_loss=False, contrast_perc=None):
|
93 |
-
prompt = query + ' in the style of bulb'
|
94 |
-
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
|
95 |
-
return_tensors="pt")
|
96 |
-
input_ids = text_input.input_ids.to(torch_device)
|
97 |
-
|
98 |
-
# Get token embeddings
|
99 |
-
position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
|
100 |
-
position_embeddings = pos_emb_layer(position_ids)
|
101 |
-
|
102 |
-
s = seed_start
|
103 |
-
|
104 |
-
token_embeddings = token_emb_layer(input_ids)
|
105 |
-
# The new embedding - our special birb word
|
106 |
-
replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
|
107 |
-
|
108 |
-
# Insert this into the token embeddings
|
109 |
-
token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
|
110 |
-
|
111 |
-
# Combine with pos embs
|
112 |
-
input_embeddings = token_embeddings + position_embeddings
|
113 |
-
|
114 |
-
# Feed through to get final output embs
|
115 |
-
modified_output_embeddings = get_output_embeds(input_embeddings)
|
116 |
-
|
117 |
-
# And generate an image with this:
|
118 |
-
|
119 |
-
if contrast_loss and seed_values[concept_idx] > 0:
|
120 |
-
s = seed_values[concept_idx]
|
121 |
-
else:
|
122 |
-
s = random.randint(s + 1, s + 30)
|
123 |
-
seed_values[concept_idx] = s
|
124 |
-
|
125 |
-
g = torch.manual_seed(s)
|
126 |
-
return generate_with_embs(text_input, modified_output_embeddings, generator=g, contrast_loss=contrast_loss, contrast_perc=contrast_perc)
|
127 |
-
|
128 |
-
|
129 |
def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast):
|
130 |
if not query:
|
131 |
raise gr.Error("No prompt provided")
|
|
|
1 |
import gradio as gr
|
2 |
+
from utils import func_generate
|
3 |
import random
|
4 |
|
5 |
is_clicked = False
|
|
|
89 |
clear_btn2.click(clear_data2, None, [out11, out12, out13, out14, out15])
|
90 |
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast):
|
93 |
if not query:
|
94 |
raise gr.Error("No prompt provided")
|