Vijish commited on
Commit
687bf5b
Β·
verified Β·
1 Parent(s): 62c119b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -1
app.py CHANGED
@@ -377,6 +377,31 @@ image_folder_path = "control" # Update this path to your folder
377
  # Load images from folder
378
  loaded_images = load_images_from_folder(image_folder_path)
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  # Define the Gradio interface
381
  with gr.Blocks() as demo:
382
  gr.Markdown("# Image Generation with Custom Prompts and Styles")
@@ -433,7 +458,7 @@ with gr.Blocks() as demo:
433
  height=235,
434
  allow_preview=False
435
  )
436
- gallery = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=820)
437
 
438
  selected_style = gr.State(value="Anime Studio Dance")
439
 
@@ -497,8 +522,16 @@ with gr.Blocks() as demo:
497
  outputs=metadata_output
498
  )
499
 
 
 
 
 
 
 
 
500
  with gr.Row():
501
  generate_button
 
502
 
503
  # At the end of your script:
504
  if __name__ == "__main__":
 
377
  # Load images from folder
378
  loaded_images = load_images_from_folder(image_folder_path)
379
 
380
+ # Restart function to clear memory and reinitialize models
381
+ def restart():
382
+ global controlnet_pipe, reference_pipe, pipe, current_controlnet_type, controlnet_models
383
+ clear_memory()
384
+ controlnet_pipe = None
385
+ reference_pipe = None
386
+ pipe = None
387
+ current_controlnet_type = None
388
+ controlnet_models = {
389
+ "Canny": None,
390
+ "Depth": None,
391
+ "OpenPose": None,
392
+ "Reference": None
393
+ }
394
+ load_base_model()
395
+ return "Restarted successfully!"
396
+
397
+ def load_base_model():
398
+ global pipe
399
+ model = "aicollective1/aicollective"
400
+ pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
401
+ pipe.to("cuda")
402
+
403
+ load_base_model()
404
+
405
  # Define the Gradio interface
406
  with gr.Blocks() as demo:
407
  gr.Markdown("# Image Generation with Custom Prompts and Styles")
 
458
  height=235,
459
  allow_preview=False
460
  )
461
+ gallery = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=850)
462
 
463
  selected_style = gr.State(value="Anime Studio Dance")
464
 
 
522
  outputs=metadata_output
523
  )
524
 
525
+ restart_button = gr.Button("Restart")
526
+ restart_button.click(
527
+ fn=restart,
528
+ inputs=[],
529
+ outputs=controlnet_status
530
+ )
531
+
532
  with gr.Row():
533
  generate_button
534
+ restart_button
535
 
536
  # At the end of your script:
537
  if __name__ == "__main__":