Spaces:
Build error
Build error
import os

import gradio as gr
class BasicTraining:
    """Gradio components for the basic training parameters.

    Constructing an instance creates the widgets in rows and exposes each
    one as an instance attribute (``self.train_batch_size``,
    ``self.learning_rate``, ``self.optimizer``, ...).

    NOTE(review): presumably constructed while a ``gr.Blocks`` layout is
    open so the rows render into it. Confirm against callers.

    Args:
        learning_rate_value: Initial value of the "Learning rate" field.
            The default is the string ``'1e-6'``, which is passed to Gradio
            as-is.
        lr_scheduler_value: Initially selected "LR Scheduler" choice.
        lr_warmup_value: Initial value of the "LR warmup (% of steps)"
            slider. The default is the string ``'0'``, passed through as-is.
        finetuning: When True, the row holding "Max resolution",
            "Stop text encoder training" and "Enable buckets" is hidden.
    """

    def __init__(
        self,
        learning_rate_value='1e-6',
        lr_scheduler_value='constant',
        lr_warmup_value='0',
        finetuning: bool = False,
    ):
        self.learning_rate_value = learning_rate_value
        self.lr_scheduler_value = lr_scheduler_value
        self.lr_warmup_value = lr_warmup_value
        self.finetuning = finetuning

        # Computed up front: os.cpu_count() may be None, and the preferred
        # default of 2 must not exceed the slider maximum.
        max_threads, default_threads = self._cpu_thread_bounds(os.cpu_count())

        with gr.Row():
            self.train_batch_size = gr.Slider(
                minimum=1,
                maximum=64,
                label='Train batch size',
                value=1,
                step=1,
            )
            self.epoch = gr.Number(label='Epoch', value=1, precision=0)
            self.save_every_n_epochs = gr.Number(
                label='Save every N epochs', value=1, precision=0
            )
            self.caption_extension = gr.Textbox(
                label='Caption Extension',
                placeholder='(Optional) Extension for caption files. default: .caption',
            )
        with gr.Row():
            self.mixed_precision = gr.Dropdown(
                label='Mixed precision',
                choices=[
                    'no',
                    'fp16',
                    'bf16',
                ],
                value='fp16',
            )
            self.save_precision = gr.Dropdown(
                label='Save precision',
                choices=[
                    'float',
                    'fp16',
                    'bf16',
                ],
                value='fp16',
            )
            self.num_cpu_threads_per_process = gr.Slider(
                minimum=1,
                maximum=max_threads,
                step=1,
                # NOTE(review): the label says "per core" while the attribute
                # name says per process. The label is left unchanged because
                # it is user-facing text.
                label='Number of CPU threads per core',
                value=default_threads,
            )
            self.seed = gr.Textbox(
                label='Seed', placeholder='(Optional) eg:1234'
            )
            self.cache_latents = gr.Checkbox(label='Cache latents', value=True)
            self.cache_latents_to_disk = gr.Checkbox(
                label='Cache latents to disk', value=False
            )
        with gr.Row():
            self.learning_rate = gr.Number(
                label='Learning rate', value=learning_rate_value
            )
            self.lr_scheduler = gr.Dropdown(
                label='LR Scheduler',
                choices=[
                    'adafactor',
                    'constant',
                    'constant_with_warmup',
                    'cosine',
                    'cosine_with_restarts',
                    'linear',
                    'polynomial',
                ],
                value=lr_scheduler_value,
            )
            self.lr_warmup = gr.Slider(
                label='LR warmup (% of steps)',
                value=lr_warmup_value,
                minimum=0,
                maximum=100,
                step=1,
            )
            self.optimizer = gr.Dropdown(
                label='Optimizer',
                choices=[
                    'AdamW',
                    'AdamW8bit',
                    'Adafactor',
                    'DAdaptation',
                    'DAdaptAdaGrad',
                    'DAdaptAdam',
                    'DAdaptAdan',
                    'DAdaptAdanIP',
                    'DAdaptAdamPreprint',
                    'DAdaptLion',
                    'DAdaptSGD',
                    'Lion',
                    'Lion8bit',
                    'Prodigy',
                    'SGDNesterov',
                    'SGDNesterov8bit',
                ],
                value='AdamW8bit',
                interactive=True,
            )
        with gr.Row():
            self.optimizer_args = gr.Textbox(
                label='Optimizer extra arguments',
                placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
            )
        # This whole row is hidden in fine-tuning mode.
        with gr.Row(visible=not finetuning):
            self.max_resolution = gr.Textbox(
                label='Max resolution',
                value='512,512',
                placeholder='512,512',
            )
            self.stop_text_encoder_training = gr.Slider(
                minimum=-1,
                maximum=100,
                value=0,
                step=1,
                label='Stop text encoder training',
            )
            self.enable_bucket = gr.Checkbox(label='Enable buckets', value=True)

    @staticmethod
    def _cpu_thread_bounds(cpu_count):
        """Return ``(maximum, default)`` for the CPU-threads slider.

        ``os.cpu_count()`` returns ``None`` when the count cannot be
        determined. In that case (and for any non-positive value) fall back
        to 1 so the slider always has a valid upper bound. The preferred
        default of 2 is clamped to that bound; otherwise it would exceed
        the maximum on a single-core machine.
        """
        maximum = cpu_count if cpu_count and cpu_count > 0 else 1
        return maximum, min(2, maximum)