Vijish commited on
Commit
47273de
Β·
verified Β·
1 Parent(s): fb048af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -15
app.py CHANGED
@@ -12,7 +12,6 @@ import os
12
  import random
13
  import gc
14
 
15
-
16
  def clear_memory():
17
  gc.collect()
18
  if torch.cuda.is_available():
@@ -208,6 +207,8 @@ def load_controlnet_model(controlnet_type):
208
 
209
  # Preprocessing functions for each ControlNet type
210
  def preprocess_canny(image):
 
 
211
  if isinstance(image, Image.Image):
212
  image = np.array(image)
213
  if image.dtype != np.uint8:
@@ -217,8 +218,9 @@ def preprocess_canny(image):
217
  image = np.concatenate([image, image, image], axis=2)
218
  return Image.fromarray(image)
219
 
220
-
221
  def preprocess_depth(image, target_size=(1024, 1024)):
 
 
222
  if isinstance(image, Image.Image):
223
  img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
224
  else:
@@ -234,6 +236,8 @@ def preprocess_depth(image, target_size=(1024, 1024)):
234
  return Image.fromarray(depth_img_resized)
235
 
236
  def preprocess_openpose(image):
 
 
237
  if isinstance(image, Image.Image):
238
  image = np.array(image)
239
  image = openpose_processor(image, hand_and_face=False, output_type='cv2')
@@ -273,21 +277,24 @@ def process_image_batch(images, pipe, prompt, negative_prompt, progress, batch_s
273
  clear_memory() # Clear memory after each batch
274
  return all_processed_images
275
 
276
-
277
  # Define the function to generate images
278
- def generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input, progress=gr.Progress(track_tqdm=True)):
279
  global controlnet_pipe, pipe, reference_pipe
280
 
281
  clear_memory()
282
 
 
 
283
  if use_controlnet:
284
  if controlnet_type not in controlnet_models or controlnet_models[controlnet_type] is None:
285
  raise ValueError(f"{controlnet_type} model not loaded. Please load the model first.")
286
 
287
  if mode == "Single Image":
288
- control_images = [control_image]
289
  else:
290
- control_images = [Image.open(img).convert("RGB") for img in batch_images_input]
 
 
291
 
292
  preprocessed_images = []
293
  for img in tqdm(control_images, desc="Preprocessing images"):
@@ -300,10 +307,16 @@ def generate_images_with_progress(prompt, negative_prompt, batch_count, use_cont
300
  else: # Reference
301
  preprocessed_images.append(img)
302
 
303
- if controlnet_type == "Reference":
304
- images = process_image_batch(preprocessed_images, reference_pipe, prompt, negative_prompt, progress)
305
- else:
306
- images = process_image_batch(preprocessed_images, controlnet_pipe, prompt, negative_prompt, progress)
 
 
 
 
 
 
307
  else:
308
  if 'controlnet_pipe' in globals() and controlnet_pipe is not None:
309
  controlnet_pipe.to("cpu")
@@ -331,13 +344,27 @@ def generate_images_with_progress(prompt, negative_prompt, batch_count, use_cont
331
  clear_memory()
332
  return images
333
 
334
-
335
-
336
  # Function to extract PNG metadata
337
  def extract_png_info(image_path):
338
  metadata = image_path.info # This is a dictionary containing key-value pairs of metadata
339
  return metadata
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  # Define the Gradio interface
342
  with gr.Blocks() as demo:
343
  gr.Markdown("# Image Generation with Custom Prompts and Styles")
@@ -356,6 +383,7 @@ with gr.Blocks() as demo:
356
  controlnet_type = gr.Dropdown(choices=["Canny", "Depth", "OpenPose", "Reference"], label="ControlNet Type")
357
  controlnet_status = gr.Textbox(label="ControlNet Status", value="", interactive=False)
358
  mode = gr.Radio(choices=["Single Image", "Batch"], label="Mode", value="Single Image")
 
359
 
360
  with gr.Tabs() as tabs:
361
  with gr.TabItem("Single Image"):
@@ -368,6 +396,18 @@ with gr.Blocks() as demo:
368
  png_image = gr.Image(label="Upload PNG Image", type='pil')
369
  metadata_output = gr.JSON(label="PNG Metadata")
370
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  with gr.Column(scale=2):
372
  style_images_gallery = gr.Gallery(
373
  label="Choose a Style",
@@ -398,10 +438,37 @@ with gr.Blocks() as demo:
398
 
399
  controlnet_type.change(fn=update_controlnet, inputs=controlnet_type, outputs=controlnet_status)
400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  generate_button = gr.Button("Generate Images")
402
  generate_button.click(
403
- generate_images_with_progress,
404
- inputs=[prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input],
405
  outputs=gallery
406
  )
407
 
@@ -419,4 +486,4 @@ with gr.Blocks() as demo:
419
  if __name__ == "__main__":
420
  # Your Gradio interface setup here
421
  demo.launch(auth=("roland", "roland"), debug=True)
422
- clear_memory()
 
12
  import random
13
  import gc
14
 
 
15
  def clear_memory():
16
  gc.collect()
17
  if torch.cuda.is_available():
 
207
 
208
  # Preprocessing functions for each ControlNet type
209
  def preprocess_canny(image):
210
+ if isinstance(image, str):
211
+ image = Image.open(image).convert("RGB")
212
  if isinstance(image, Image.Image):
213
  image = np.array(image)
214
  if image.dtype != np.uint8:
 
218
  image = np.concatenate([image, image, image], axis=2)
219
  return Image.fromarray(image)
220
 
 
221
  def preprocess_depth(image, target_size=(1024, 1024)):
222
+ if isinstance(image, str):
223
+ image = Image.open(image).convert("RGB")
224
  if isinstance(image, Image.Image):
225
  img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
226
  else:
 
236
  return Image.fromarray(depth_img_resized)
237
 
238
  def preprocess_openpose(image):
239
+ if isinstance(image, str):
240
+ image = Image.open(image).convert("RGB")
241
  if isinstance(image, Image.Image):
242
  image = np.array(image)
243
  image = openpose_processor(image, hand_and_face=False, output_type='cv2')
 
277
  clear_memory() # Clear memory after each batch
278
  return all_processed_images
279
 
 
280
  # Define the function to generate images
281
+ def generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_images, progress=gr.Progress(track_tqdm=True)):
282
  global controlnet_pipe, pipe, reference_pipe
283
 
284
  clear_memory()
285
 
286
+ chunk_size = 2 # Adjust this number based on your memory capacity
287
+
288
  if use_controlnet:
289
  if controlnet_type not in controlnet_models or controlnet_models[controlnet_type] is None:
290
  raise ValueError(f"{controlnet_type} model not loaded. Please load the model first.")
291
 
292
  if mode == "Single Image":
293
+ control_images = [control_images] if isinstance(control_images, Image.Image) else control_images
294
  else:
295
+ if not control_images:
296
+ raise ValueError("No images provided for batch processing.")
297
+ control_images = [Image.open(img).convert("RGB") if isinstance(img, str) else img for img in control_images]
298
 
299
  preprocessed_images = []
300
  for img in tqdm(control_images, desc="Preprocessing images"):
 
307
  else: # Reference
308
  preprocessed_images.append(img)
309
 
310
+ images = []
311
+ for i in range(0, len(preprocessed_images), chunk_size):
312
+ chunk = preprocessed_images[i:i+chunk_size]
313
+ if controlnet_type == "Reference":
314
+ images_chunk = process_image_batch(chunk, reference_pipe, prompt, negative_prompt, progress)
315
+ else:
316
+ images_chunk = process_image_batch(chunk, controlnet_pipe, prompt, negative_prompt, progress)
317
+ images.extend(images_chunk)
318
+ clear_memory()
319
+
320
  else:
321
  if 'controlnet_pipe' in globals() and controlnet_pipe is not None:
322
  controlnet_pipe.to("cpu")
 
344
  clear_memory()
345
  return images
346
 
 
 
347
  # Function to extract PNG metadata
348
  def extract_png_info(image_path):
349
  metadata = image_path.info # This is a dictionary containing key-value pairs of metadata
350
  return metadata
351
 
352
+ # Load images from the specified folder
353
+ def load_images_from_folder(folder_path):
354
+ images = []
355
+ for filename in os.listdir(folder_path):
356
+ if filename.endswith(('.png', '.jpg', '.jpeg')):
357
+ img_path = os.path.join(folder_path, filename)
358
+ img = Image.open(img_path).convert("RGB")
359
+ images.append((filename, img))
360
+ return images
361
+
362
+ # Folder path where images are stored
363
+ image_folder_path = "control" # Update this path to your folder
364
+
365
+ # Load images from folder
366
+ loaded_images = load_images_from_folder(image_folder_path)
367
+
368
  # Define the Gradio interface
369
  with gr.Blocks() as demo:
370
  gr.Markdown("# Image Generation with Custom Prompts and Styles")
 
383
  controlnet_type = gr.Dropdown(choices=["Canny", "Depth", "OpenPose", "Reference"], label="ControlNet Type")
384
  controlnet_status = gr.Textbox(label="ControlNet Status", value="", interactive=False)
385
  mode = gr.Radio(choices=["Single Image", "Batch"], label="Mode", value="Single Image")
386
+ use_control_folder = gr.Checkbox(label="Use Control Folder for Batch Processing", value=False)
387
 
388
  with gr.Tabs() as tabs:
389
  with gr.TabItem("Single Image"):
 
396
  png_image = gr.Image(label="Upload PNG Image", type='pil')
397
  metadata_output = gr.JSON(label="PNG Metadata")
398
 
399
+ with gr.TabItem("Select from Folder"):
400
+ folder_images_gallery = gr.Gallery(
401
+ label="Images from Folder",
402
+ value=[img[1] for img in loaded_images],
403
+ interactive=True,
404
+ elem_id="folder-gallery",
405
+ columns=5,
406
+ object_fit="contain",
407
+ height=235,
408
+ allow_preview=False
409
+ )
410
+
411
  with gr.Column(scale=2):
412
  style_images_gallery = gr.Gallery(
413
  label="Choose a Style",
 
438
 
439
  controlnet_type.change(fn=update_controlnet, inputs=controlnet_type, outputs=controlnet_status)
440
 
441
+ selected_folder_images = gr.State(value=[])
442
+
443
+ def select_folder_image(evt: gr.SelectData):
444
+ folder_image_names = [img[0] for img in loaded_images]
445
+ if evt.index < 0 or evt.index >= len(folder_image_names):
446
+ raise ValueError(f"Invalid index: {evt.index}")
447
+ selected_image_name = folder_image_names[evt.index]
448
+ selected_image = next(img for img in loaded_images if img[0] == selected_image_name)
449
+ current_images = selected_folder_images.value
450
+ if selected_image not in current_images:
451
+ current_images.append(selected_image)
452
+ return current_images
453
+
454
+ folder_images_gallery.select(fn=select_folder_image, inputs=[], outputs=selected_folder_images)
455
+
456
+ def generate_images_with_folder_images(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, use_control_folder, selected_folder_images, batch_images_input, progress=gr.Progress(track_tqdm=True)):
457
+ if mode == "Batch":
458
+ if use_control_folder:
459
+ selected_images = [img[1] for img in loaded_images]
460
+ else:
461
+ if not batch_images_input:
462
+ raise ValueError("No images uploaded for batch processing.")
463
+ selected_images = [Image.open(img).convert("RGB") for img in batch_images_input]
464
+ else:
465
+ selected_images = [img[1] for img in selected_folder_images]
466
+ return generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, selected_images, progress)
467
+
468
  generate_button = gr.Button("Generate Images")
469
  generate_button.click(
470
+ generate_images_with_folder_images,
471
+ inputs=[prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, use_control_folder, selected_folder_images, batch_images_input],
472
  outputs=gallery
473
  )
474
 
 
486
  if __name__ == "__main__":
487
  # Your Gradio interface setup here
488
  demo.launch(auth=("roland", "roland"), debug=True)
489
+ clear_memory()