multimodalart HF staff commited on
Commit
865fdcc
·
verified ·
1 Parent(s): 44ce793

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -9
app.py CHANGED
@@ -20,6 +20,8 @@ import zipfile
20
 
21
  MAX_IMAGES = 150
22
 
 
 
23
  training_script_url = "https://raw.githubusercontent.com/huggingface/diffusers/ba28006f8b2a0f7ec3b6784695790422b4f80a97/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py"
24
  subprocess.run(['wget', '-N', training_script_url])
25
  orchestrator_script_url = "https://huggingface.co/datasets/multimodalart/lora-ease-helper/raw/main/script.py"
@@ -114,7 +116,17 @@ def load_captioning(uploaded_images, option):
114
 
115
  def check_removed_and_restart(images):
116
  visible = len(images) > 1 if images is not None else False
117
- return [gr.update(visible=visible) for _ in range(3)]
 
 
 
 
 
 
 
 
 
 
118
 
119
  def make_options_visible(option):
120
  if (option == "object") or (option == "face"):
@@ -388,9 +400,11 @@ def start_training_og(
388
  enable_xformers_memory_efficient_attention,
389
  adam_beta1,
390
  adam_beta2,
 
391
  prodigy_beta3,
392
  prodigy_decouple,
393
  adam_weight_decay,
 
394
  adam_weight_decay_text_encoder,
395
  adam_epsilon,
396
  prodigy_use_bias_correction,
@@ -404,10 +418,14 @@ def start_training_og(
404
  dataloader_num_workers,
405
  local_rank,
406
  dataset_folder,
407
- progress = gr.Progress(track_tqdm=True)
 
408
  ):
 
 
409
  slugged_lora_name = slugify(lora_name)
410
- commands = ["--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
 
411
  "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
412
  f"--instance_prompt={concept_sentence}",
413
  f"--dataset_name=./{dataset_folder}",
@@ -433,9 +451,7 @@ def start_training_og(
433
  f"--prior_loss_weight={prior_loss_weight}",
434
  f"--num_new_tokens_per_abstraction={int(num_new_tokens_per_abstraction)}",
435
  f"--num_train_epochs={int(num_train_epochs)}",
436
- f"--prodigy_beta3={prodigy_beta3}",
437
  f"--adam_weight_decay={adam_weight_decay}",
438
- f"--adam_weight_decay_text_encoder={adam_weight_decay_text_encoder}",
439
  f"--adam_epsilon={adam_epsilon}",
440
  f"--prodigy_decouple={prodigy_decouple}",
441
  f"--prodigy_use_bias_correction={prodigy_use_bias_correction}",
@@ -474,11 +490,16 @@ def start_training_og(
474
  for image in class_images:
475
  shutil.copy(image, class_folder)
476
  commands.append(f"--class_data_dir={class_folder}")
477
-
 
 
 
478
  from train_dreambooth_lora_sdxl_advanced import main as train_main, parse_args as parse_train_args
479
  args = parse_train_args(commands)
 
480
  train_main(args)
481
- return "ok!"
 
482
 
483
  @spaces.GPU(enable_queue=True)
484
  def run_captioning(*inputs):
@@ -948,7 +969,7 @@ with gr.Blocks(css=css, theme=theme) as demo:
948
  images.change(
949
  check_removed_and_restart,
950
  inputs=[images],
951
- outputs=[captioning_area, advanced, cost_estimation],
952
  queue=False
953
  )
954
  training_option.change(
@@ -969,7 +990,7 @@ with gr.Blocks(css=css, theme=theme) as demo:
969
  outputs=dataset_folder,
970
  queue=False
971
  ).then(
972
- fn=start_training,
973
  inputs=[
974
  lora_name,
975
  training_option,
 
20
 
21
  MAX_IMAGES = 150
22
 
23
+ is_spaces = True if os.environ.get('SPACE_ID') else False
24
+
25
  training_script_url = "https://raw.githubusercontent.com/huggingface/diffusers/ba28006f8b2a0f7ec3b6784695790422b4f80a97/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py"
26
  subprocess.run(['wget', '-N', training_script_url])
27
  orchestrator_script_url = "https://huggingface.co/datasets/multimodalart/lora-ease-helper/raw/main/script.py"
 
116
 
117
  def check_removed_and_restart(images):
118
  visible = len(images) > 1 if images is not None else False
119
+ if(is_spaces):
120
+ captioning_area = gr.update(visible=visible)
121
+ advanced = gr.update(visible=visible)
122
+ cost_estimation = gr.update(visible=visible)
123
+ start = gr.update(visible=False)
124
+ else:
125
+ captioning_area = gr.update(visible=visible)
126
+ advanced = gr.update(visible=visible)
127
+ cost_estimation = gr.update(visible=False)
128
+ start = gr.update(visible=True)
129
+ return captioning_area, advanced,cost_estimation, start
130
 
131
  def make_options_visible(option):
132
  if (option == "object") or (option == "face"):
 
400
  enable_xformers_memory_efficient_attention,
401
  adam_beta1,
402
  adam_beta2,
403
+ use_prodigy_beta3,
404
  prodigy_beta3,
405
  prodigy_decouple,
406
  adam_weight_decay,
407
+ use_adam_weight_decay_text_encoder,
408
  adam_weight_decay_text_encoder,
409
  adam_epsilon,
410
  prodigy_use_bias_correction,
 
418
  dataloader_num_workers,
419
  local_rank,
420
  dataset_folder,
421
+ token,
422
+ #progress = gr.Progress(track_tqdm=True)
423
  ):
424
+ if not lora_name:
425
+ raise gr.Error("You forgot to insert your LoRA name!")
426
  slugged_lora_name = slugify(lora_name)
427
+ commands = [
428
+ "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
429
  "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
430
  f"--instance_prompt={concept_sentence}",
431
  f"--dataset_name=./{dataset_folder}",
 
451
  f"--prior_loss_weight={prior_loss_weight}",
452
  f"--num_new_tokens_per_abstraction={int(num_new_tokens_per_abstraction)}",
453
  f"--num_train_epochs={int(num_train_epochs)}",
 
454
  f"--adam_weight_decay={adam_weight_decay}",
 
455
  f"--adam_epsilon={adam_epsilon}",
456
  f"--prodigy_decouple={prodigy_decouple}",
457
  f"--prodigy_use_bias_correction={prodigy_use_bias_correction}",
 
490
  for image in class_images:
491
  shutil.copy(image, class_folder)
492
  commands.append(f"--class_data_dir={class_folder}")
493
+ if use_prodigy_beta3:
494
+ commands.append(f"--prodigy_beta3={prodigy_beta3}")
495
+ if use_adam_weight_decay_text_encoder:
496
+ commands.append(f"--adam_weight_decay_text_encoder={adam_weight_decay_text_encoder}")
497
  from train_dreambooth_lora_sdxl_advanced import main as train_main, parse_args as parse_train_args
498
  args = parse_train_args(commands)
499
+
500
  train_main(args)
501
+
502
+ return f"Your model has finished training and has been saved to the `{slugged_lora_name}` folder"
503
 
504
  @spaces.GPU(enable_queue=True)
505
  def run_captioning(*inputs):
 
969
  images.change(
970
  check_removed_and_restart,
971
  inputs=[images],
972
+ outputs=[captioning_area, advanced, cost_estimation, start],
973
  queue=False
974
  )
975
  training_option.change(
 
990
  outputs=dataset_folder,
991
  queue=False
992
  ).then(
993
+ fn=start_training if is_spaces else start_training_og,
994
  inputs=[
995
  lora_name,
996
  training_option,