MilindChawre commited on
Commit
663ae65
·
1 Parent(s): 269022b

Use a single random seed (1–9999), generated once per run, for all loss types instead of the fixed per-loss-type seed (idx * 1000)

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -50,8 +50,10 @@ def init_transformers(device):
50
  # Add after init_transformers and before generate_images
51
  def image_loss(images, loss_type, device, elastic_transformer):
52
  if loss_type == 'blue':
53
- error = torch.abs(images[:,2] - 0.9).mean()
54
- return error.to(device)
 
 
55
  elif loss_type == 'elastic':
56
  transformed_imgs = elastic_transformer(images)
57
  error = torch.abs(transformed_imgs - images).mean()
@@ -69,8 +71,8 @@ def image_loss(images, loss_type, device, elastic_transformer):
69
 
70
  # Update configuration for faster generation
71
  height, width = 384, 384 # Reduced from 512x512 to 384x384
72
- guidance_scale = 7.5
73
- num_inference_steps = 30
74
  loss_scale = 150
75
 
76
  def generate_images(prompt, concept):
@@ -88,6 +90,10 @@ def generate_images(prompt, concept):
88
  loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
89
  progress = gr.Progress()
90
 
 
 
 
 
91
  for idx, loss_type in enumerate(loss_functions):
92
  try:
93
  print(f"\n[{loss_type.upper()}] Starting image generation...")
@@ -139,8 +145,9 @@ def generate_images(prompt, concept):
139
 
140
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
141
 
142
- # Generate initial latents with correct dtype
143
- generator = torch.manual_seed(idx * 1000)
 
144
  latents = torch.randn(
145
  (1, pipe.unet.config.in_channels, height // 8, width // 8),
146
  generator=generator,
 
50
  # Add after init_transformers and before generate_images
51
  def image_loss(images, loss_type, device, elastic_transformer):
52
  if loss_type == 'blue':
53
+ # Reduced target blue value from 0.9 to 0.6 for more subtle effect
54
+ error = torch.abs(images[:,2] - 0.6).mean()
55
+ # Apply a lower scale specifically for blue loss
56
+ return (error * 0.3).to(device) # Reduced scaling factor
57
  elif loss_type == 'elastic':
58
  transformed_imgs = elastic_transformer(images)
59
  error = torch.abs(transformed_imgs - images).mean()
 
71
 
72
  # Update configuration for faster generation
73
  height, width = 384, 384 # Reduced from 512x512 to 384x384
74
+ guidance_scale = 8 # Increased from 7.5 to 8 for better prompt adherence
75
+ num_inference_steps = 45 # Using 45 steps for better quality
76
  loss_scale = 150
77
 
78
  def generate_images(prompt, concept):
 
90
  loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
91
  progress = gr.Progress()
92
 
93
+ # Generate one random seed for all loss types
94
+ random_seed = torch.randint(1, 10000, (1,)).item() # Random seed between 1 and 9999
95
+ print(f"\nUsing random seed {random_seed} for all images")
96
+
97
  for idx, loss_type in enumerate(loss_functions):
98
  try:
99
  print(f"\n[{loss_type.upper()}] Starting image generation...")
 
145
 
146
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
147
 
148
+ # Generate initial latents with random seed
149
+ # Use the same seed for all loss types
150
+ generator = torch.manual_seed(random_seed)
151
  latents = torch.randn(
152
  (1, pipe.unet.config.in_channels, height // 8, width // 8),
153
  generator=generator,