Commit
·
663ae65
1
Parent(s):
269022b
Making changes in random seed value generation
Browse files
app.py
CHANGED
@@ -50,8 +50,10 @@ def init_transformers(device):
|
|
50 |
# Add after init_transformers and before generate_images
|
51 |
def image_loss(images, loss_type, device, elastic_transformer):
|
52 |
if loss_type == 'blue':
|
53 |
-
|
54 |
-
|
|
|
|
|
55 |
elif loss_type == 'elastic':
|
56 |
transformed_imgs = elastic_transformer(images)
|
57 |
error = torch.abs(transformed_imgs - images).mean()
|
@@ -69,8 +71,8 @@ def image_loss(images, loss_type, device, elastic_transformer):
|
|
69 |
|
70 |
# Update configuration for faster generation
|
71 |
height, width = 384, 384 # Reduced from 512x512 to 384x384
|
72 |
-
guidance_scale = 7.5
|
73 |
-
num_inference_steps =
|
74 |
loss_scale = 150
|
75 |
|
76 |
def generate_images(prompt, concept):
|
@@ -88,6 +90,10 @@ def generate_images(prompt, concept):
|
|
88 |
loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
|
89 |
progress = gr.Progress()
|
90 |
|
|
|
|
|
|
|
|
|
91 |
for idx, loss_type in enumerate(loss_functions):
|
92 |
try:
|
93 |
print(f"\n[{loss_type.upper()}] Starting image generation...")
|
@@ -139,8 +145,9 @@ def generate_images(prompt, concept):
|
|
139 |
|
140 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
141 |
|
142 |
-
# Generate initial latents with
|
143 |
-
|
|
|
144 |
latents = torch.randn(
|
145 |
(1, pipe.unet.config.in_channels, height // 8, width // 8),
|
146 |
generator=generator,
|
|
|
50 |
# Add after init_transformers and before generate_images
|
51 |
def image_loss(images, loss_type, device, elastic_transformer):
|
52 |
if loss_type == 'blue':
|
53 |
+
# Reduced target blue value from 0.9 to 0.6 for more subtle effect
|
54 |
+
error = torch.abs(images[:,2] - 0.6).mean()
|
55 |
+
# Apply a lower scale specifically for blue loss
|
56 |
+
return (error * 0.3).to(device) # Reduced scaling factor
|
57 |
elif loss_type == 'elastic':
|
58 |
transformed_imgs = elastic_transformer(images)
|
59 |
error = torch.abs(transformed_imgs - images).mean()
|
|
|
71 |
|
72 |
# Update configuration for faster generation
|
73 |
height, width = 384, 384 # Reduced from 512x512 to 384x384
|
74 |
+
guidance_scale = 8 # Increased from 7.5 to 8 for better prompt adherence
|
75 |
+
num_inference_steps = 45 # Using 45 steps for better quality
|
76 |
loss_scale = 150
|
77 |
|
78 |
def generate_images(prompt, concept):
|
|
|
90 |
loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
|
91 |
progress = gr.Progress()
|
92 |
|
93 |
+
# Generate one random seed for all loss types
|
94 |
+
random_seed = torch.randint(1, 10000, (1,)).item() # Random seed between 1 and 9999
|
95 |
+
print(f"\nUsing random seed {random_seed} for all images")
|
96 |
+
|
97 |
for idx, loss_type in enumerate(loss_functions):
|
98 |
try:
|
99 |
print(f"\n[{loss_type.upper()}] Starting image generation...")
|
|
|
145 |
|
146 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
147 |
|
148 |
+
# Generate initial latents with random seed
|
149 |
+
# Use the same seed for all loss types
|
150 |
+
generator = torch.manual_seed(random_seed)
|
151 |
latents = torch.randn(
|
152 |
(1, pipe.unet.config.in_channels, height // 8, width // 8),
|
153 |
generator=generator,
|