piyushgrover commited on
Commit
266993c
Β·
1 Parent(s): 433ae3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -38
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from utils import *
3
  import random
4
 
5
  is_clicked = False
@@ -89,43 +89,6 @@ with gr.Blocks() as app:
89
  clear_btn2.click(clear_data2, None, [out11, out12, out13, out14, out15])
90
 
91
 
92
- def func_generate(query, concept_idx, seed_start, contrast_loss=False, contrast_perc=None):
93
- prompt = query + ' in the style of bulb'
94
- text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
95
- return_tensors="pt")
96
- input_ids = text_input.input_ids.to(torch_device)
97
-
98
- # Get token embeddings
99
- position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
100
- position_embeddings = pos_emb_layer(position_ids)
101
-
102
- s = seed_start
103
-
104
- token_embeddings = token_emb_layer(input_ids)
105
- # The new embedding - our special birb word
106
- replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
107
-
108
- # Insert this into the token embeddings
109
- token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
110
-
111
- # Combine with pos embs
112
- input_embeddings = token_embeddings + position_embeddings
113
-
114
- # Feed through to get final output embs
115
- modified_output_embeddings = get_output_embeds(input_embeddings)
116
-
117
- # And generate an image with this:
118
-
119
- if contrast_loss and seed_values[concept_idx] > 0:
120
- s = seed_values[concept_idx]
121
- else:
122
- s = random.randint(s + 1, s + 30)
123
- seed_values[concept_idx] = s
124
-
125
- g = torch.manual_seed(s)
126
- return generate_with_embs(text_input, modified_output_embeddings, generator=g, contrast_loss=contrast_loss, contrast_perc=contrast_perc)
127
-
128
-
129
  def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast):
130
  if not query:
131
  raise gr.Error("No prompt provided")
 
1
  import gradio as gr
2
+ from utils import func_generate
3
  import random
4
 
5
  is_clicked = False
 
89
  clear_btn2.click(clear_data2, None, [out11, out12, out13, out14, out15])
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast):
93
  if not query:
94
  raise gr.Error("No prompt provided")