MilindChawre committed on
Commit 6304c5b · 1 Parent(s): 45b110b

Modifying the app code

Files changed (1)
  1. app.py +127 -83
app.py CHANGED
@@ -67,6 +67,12 @@ def image_loss(images, loss_type, device, elastic_transformer):
     else:
         return torch.tensor(0.0).to(device)

+# Update configuration
+height, width = 512, 512
+guidance_scale = 8
+num_inference_steps = 50
+loss_scale = 200
+
 def generate_images(prompt, concept):
     global pipe, device, elastic_transformer
     if pipe is None:
@@ -74,96 +80,133 @@ def generate_images(prompt, concept):
     if elastic_transformer is None:
         elastic_transformer = init_transformers(device)

-    # Configuration
-    height, width = 384, 384
-    guidance_scale = 8
-    num_inference_steps = 45
-    loss_scale = 10.0
-
-    # Create scheduler
-    scheduler = LMSDiscreteScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=1000
-    )
-    pipe.scheduler = scheduler  # Set the scheduler
-
-    # Create prompt text
+    # Create prompt text and initialize results
     prompt_text = f"{prompt} {concept}"
+    all_images = []  # Changed from images to all_images

-    # Predefined seeds for each loss function
-    seeds = {
-        'none': 42,
-        'blue': 123,
-        'elastic': 456,
-        'symmetry': 789,
-        'saturation': 1000
-    }
-
+    # Process each loss type
     loss_functions = ['none', 'blue', 'elastic', 'symmetry', 'saturation']
-    images = []
     progress = gr.Progress()

-    # Generate image for each loss function
     for idx, loss_type in enumerate(loss_functions):
         progress(idx/len(loss_functions), f"Generating {loss_type} image...")
-        generator = torch.manual_seed(seeds[loss_type])

-        # Generate base image
         try:
-            output = pipe(
-                prompt_text,
-                height=height,
-                width=width,
-                num_inference_steps=num_inference_steps,
-                guidance_scale=guidance_scale,
-                generator=generator
+            # Better memory management
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            # Move inputs to correct device and dtype
+            # Remove incorrect device movement
+            # text_input = text_input.to(device)  # Remove this line
+            # uncond_input = uncond_input.to(device)  # Remove this line
+            # latents = latents.to(dtype=pipe.vae.dtype, device=device)  # Remove this line
+
+            # Initialize scheduler and process text first
+            scheduler = LMSDiscreteScheduler(
+                beta_start=0.00085,
+                beta_end=0.012,
+                beta_schedule="scaled_linear",
+                num_train_timesteps=1000
             )
-        except Exception as e:
-            print(f"Error generating image: {e}")
-            return None
-
-        # Apply loss function if not 'none'
-        if loss_type != 'none':
-            try:
-                # Convert PIL image to tensor and move to device
-                image_tensor = T.ToTensor()(output.images[0]).unsqueeze(0).to(device)
-                # Apply loss and update image
-                loss = image_loss(image_tensor, loss_type, device, elastic_transformer)
-                image_tensor = image_tensor - loss_scale * loss
-                # Move back to CPU and convert to PIL
-                image = T.ToPILImage()(image_tensor.cpu().squeeze(0).clamp(0, 1))
-            except Exception as e:
-                print(f"Error applying {loss_type} loss: {e}")
-                image = output.images[0]  # Use original image if loss fails
-        else:
-            image = output.images[0]
-
-        # Add image with its label
-        try:
-            # Ensure image is in correct format (PIL.Image)
-            if not isinstance(image, Image.Image):
-                print(f"Warning: Converting {loss_type} image to PIL format")
-                image = Image.fromarray(image)
+            scheduler.set_timesteps(num_inference_steps)
+            scheduler.timesteps = scheduler.timesteps.to(device)
+
+            # Process text embeddings
+            text_input = pipe.tokenizer(
+                [prompt_text],
+                padding='max_length',
+                max_length=pipe.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt"
+            )
+
+            with torch.no_grad():
+                text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]
+
+            uncond_input = pipe.tokenizer(
+                [""] * 1,
+                padding="max_length",
+                max_length=text_input.input_ids.shape[-1],
+                return_tensors="pt"
+            )
+
+            with torch.no_grad():
+                uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
+
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+            # Generate initial latents with correct dtype
+            generator = torch.manual_seed(idx * 1000)
+            latents = torch.randn(
+                (1, pipe.unet.config.in_channels, height // 8, width // 8),
+                generator=generator,
+            )
+            latents = latents.to(device=device, dtype=pipe.unet.dtype)
+            latents = latents * scheduler.init_noise_sigma
+
+            # Diffusion process
+            for i, t in enumerate(scheduler.timesteps):
+                latent_model_input = torch.cat([latents] * 2)
+                sigma = scheduler.sigmas[i]
+                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+                # Move latent_model_input to correct dtype
+                latent_model_input = latent_model_input.to(dtype=pipe.unet.dtype)
+
+                with torch.no_grad():
+                    noise_pred = pipe.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=text_embeddings
+                    )["sample"]
+
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # Apply loss every 5 steps if not 'none'
+                if loss_type != 'none' and i % 5 == 0:
+                    latents = latents.detach().requires_grad_()
+                    latents_x0 = latents - sigma * noise_pred
+
+                    # Decode to image space for loss computation
+                    with torch.set_grad_enabled(True):  # Enable gradients for loss computation
+                        denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
+                        denoised_images = denoised_images.requires_grad_()  # Enable gradients for images
+                        loss = image_loss(denoised_images, loss_type, device, elastic_transformer)
+                        cond_grad = torch.autograd.grad(loss * loss_scale, latents)[0]
+
+                    latents = latents.detach() - cond_grad * sigma**2
+
+                latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+            # Proper latent to image conversion
+            latents = (1 / 0.18215) * latents
+            with torch.no_grad():
+                image = pipe.vae.decode(latents).sample
+
+            image = (image / 2 + 0.5).clamp(0, 1)
+            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+            image = (image * 255).round().astype("uint8")
+            pil_image = Image.fromarray(image[0])
+
+            # Add image with its label
+            all_images.append((pil_image, f"{loss_type.capitalize()} Loss"))

-            # Add tuple of (image, label) to list
-            images.append((image, f"{loss_type.capitalize()} Loss"))
-            print(f"Added {loss_type} image to gallery")  # Debug print
         except Exception as e:
-            print(f"Error adding {loss_type} image to gallery: {e}")
-            continue
-
-        # Clear GPU memory after each image
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            gc.collect()
+            print(f"Error generating {loss_type} image: {e}")
+            continue  # Continue to next loss type instead of returning None

-    # Return all generated images
-    print(f"Returning {len(images)} images")
-    if not images:
+    # At the end of the function
+    try:
+        if len(all_images) == 0:
+            raise Exception("No images were generated successfully")
+        return [img for img, _ in all_images]
+    except Exception as e:
+        print(f"Error in generate_images: {e}")
         return None
-    return images

 def create_interface():
     default_prompts = [
@@ -187,14 +230,13 @@ def create_interface():
             gr.Dropdown(choices=concepts, label="Select SD Concept")
         ],
         outputs=gr.Gallery(
-            label="Generated Images (From Left to Right: Original, Blue Loss, Elastic Loss, Symmetry Loss, Saturation Loss)",
+            label="Generated Images",
             show_label=True,
             elem_id="gallery",
             columns=5,
             rows=1,
-            height=512,
-            object_fit="contain"
-        ),  # Simplified Gallery definition
+            height="auto"
+        ),
         title="Stable Diffusion using Text Inversion",
         description="""Generate images using Stable Diffusion with different style concepts. The output shows 5 images side by side:
         1. Original Image (No Loss)
@@ -204,16 +246,18 @@
         5. Saturation Loss - Modifies color saturation

         Note: Image generation may take several minutes. Please be patient while the images are being processed.""",
-        flagging_mode="never"  # Updated from allow_flagging
+        cache_examples=False,
+        max_batch_size=1,
+        flagging_mode="never"
     )

     return interface

 if __name__ == "__main__":
     interface = create_interface()
-    interface.queue(max_size=5)  # Simplified queue configuration
+    interface.queue(max_size=5)  # Remove concurrency_count parameter
     interface.launch(
         share=True,
         server_name="0.0.0.0",
-        max_threads=1
+        server_port=7860
     )
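A note on the new denoising loop: the doubled latent batch (torch.cat([latents] * 2)) serves classifier-free guidance, so a single UNet pass yields both the unconditional and the text-conditioned noise prediction, and the final prediction extrapolates from the former toward the latter. A minimal sketch of just that combination step, matching the torch.cat([uncond_embeddings, text_embeddings]) ordering in the diff (cfg_combine is an illustrative name, not an app.py function):

import torch

def cfg_combine(noise_pred, guidance_scale=8.0):
    # noise_pred is the UNet output for the doubled batch: the first half is
    # the unconditional prediction, the second half the text-conditioned one.
    noise_uncond, noise_text = noise_pred.chunk(2)
    # Extrapolate from the unconditional prediction toward the conditioned one
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)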
 
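The other key addition is the gradient-guidance step: every fifth step the loop estimates the fully denoised latents (the x0 prediction), decodes them with the VAE, scores the decoded image with the selected loss, and nudges the latents against the gradient of that score, scaled by sigma**2. A minimal standalone sketch of that update, assuming a diffusers StableDiffusionPipeline VAE and an LMS-style scheduler (guided_step and loss_fn are illustrative names, not part of app.py):

import torch

def guided_step(vae, latents, noise_pred, sigma, loss_fn, loss_scale=200):
    # One loss-guided latent update, mirroring the pattern in the loop above.
    latents = latents.detach().requires_grad_()
    # Estimate of the fully denoised latents at this step (x0 prediction)
    latents_x0 = latents - sigma * noise_pred
    # Decode to [0, 1] image space; 0.18215 is the SD v1 latent scale factor
    images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
    loss = loss_fn(images) * loss_scale
    # The gradient of the loss w.r.t. the latents steers the sample
    grad = torch.autograd.grad(loss, latents)[0]
    return latents.detach() - grad * sigma**2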
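One quirk worth flagging: the loop collects (image, label) tuples in all_images, but generate_images returns only the bare images, so the per-loss captions never reach the UI. gr.Gallery also accepts (image, caption) pairs directly, so a variant that keeps the labels could return the tuples unchanged; a hypothetical sketch, not what this commit does:

def finish(all_images):
    # Hypothetical helper: return (PIL.Image, caption) pairs so gr.Gallery
    # renders the per-loss labels instead of discarding them.
    if not all_images:
        return None
    return all_images  # instead of [img for img, _ in all_images]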