venkat-natchi commited on
Commit
893d436
1 Parent(s): 751f9c2

Update image_generator.py

Browse files
Files changed (1) hide show
  1. image_generator.py +4 -3
image_generator.py CHANGED
@@ -164,7 +164,7 @@ def generate_image_from_embeddings(
164
  loss_selection, additional_prompt):
165
  height = 512
166
  width = 512
167
- num_inference_steps = 20
168
  guidance_scale = 8
169
  generator = torch.manual_seed(seed)
170
  batch_size = 1
@@ -221,7 +221,7 @@ def generate_image_from_embeddings(
221
  )
222
 
223
  #### ADDITIONAL GUIDANCE ###
224
- if i % 3 == 0:
225
  # Requires grad on the latents
226
  latents = latents.detach().requires_grad_()
227
 
@@ -358,7 +358,6 @@ def generate_image_per_prompt_style(text_in, style_in,
358
  style_file = STYLE_EMBEDDINGS[style_in]
359
  print(f"style_file: {style_file}")
360
 
361
- prompt = text_in
362
 
363
  style_seed = STYLE_SEEDS[idx]
364
 
@@ -370,6 +369,8 @@ def generate_image_per_prompt_style(text_in, style_in,
370
  embedding = load_embedding_bin(file_path)
371
  style_key = f"<{style_key}>"
372
 
 
 
373
  gen_style_image = generate_image_per_style(prompt, embedding, style_seed, style_key)
374
 
375
  gen_loss_image = generate_image_per_loss(prompt, embedding, style_seed, style_key, loss, additional_prompt)
 
164
  loss_selection, additional_prompt):
165
  height = 512
166
  width = 512
167
+ num_inference_steps = 50
168
  guidance_scale = 8
169
  generator = torch.manual_seed(seed)
170
  batch_size = 1
 
221
  )
222
 
223
  #### ADDITIONAL GUIDANCE ###
224
+ if i % 5 == 0:
225
  # Requires grad on the latents
226
  latents = latents.detach().requires_grad_()
227
 
 
358
  style_file = STYLE_EMBEDDINGS[style_in]
359
  print(f"style_file: {style_file}")
360
 
 
361
 
362
  style_seed = STYLE_SEEDS[idx]
363
 
 
369
  embedding = load_embedding_bin(file_path)
370
  style_key = f"<{style_key}>"
371
 
372
+ prompt = f"{text_in} {style_key}"
373
+
374
  gen_style_image = generate_image_per_style(prompt, embedding, style_seed, style_key)
375
 
376
  gen_loss_image = generate_image_per_loss(prompt, embedding, style_seed, style_key, loss, additional_prompt)