import json

import gradio as gr

from .class_configuration_file import ConfigurationFile
from .class_source_model import SourceModel
from .class_folders import Folders
from .class_basic_training import BasicTraining
from .class_advanced_training import AdvancedTraining
from .class_sample_images import SampleImages
from library.dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
)
from .common_gui import color_aug_changed


class Dreambooth:
    def __init__(
        self,
        headless: bool = False,
    ):
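        """Build the Gradio UI for the kohya Dreambooth training tab.

        Args:
            headless: passed through to the child UI helper classes
                (ConfigurationFile, SourceModel, Folders, AdvancedTraining)
                to toggle their headless behaviour.
        """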
        self.headless = headless
        self.dummy_db_true = gr.Label(value=True, visible=False)
        self.dummy_db_false = gr.Label(value=False, visible=False)
        self.dummy_headless = gr.Label(value=headless, visible=False)

        gr.Markdown('Train a custom model using kohya dreambooth python code...')

        self.config = ConfigurationFile(headless)

        self.source_model = SourceModel(headless=headless)

        with gr.Tab('Folders'):
            self.folders = Folders(headless=headless)
        with gr.Tab('Parameters'):
            self.basic_training = BasicTraining(
                learning_rate_value='1e-5',
                lr_scheduler_value='cosine',
                lr_warmup_value='10',
            )
            self.full_bf16 = gr.Checkbox(
                label='Full bf16', value=False
            )
            with gr.Accordion('Advanced Configuration', open=False):
                self.advanced_training = AdvancedTraining(headless=headless)
                self.advanced_training.color_aug.change(
                    color_aug_changed,
                    inputs=[self.advanced_training.color_aug],
                    outputs=[self.basic_training.cache_latents],
                )

            self.sample = SampleImages()

        with gr.Tab('Tools'):
            gr.Markdown(
                'This section provides Dreambooth tools to help set up your dataset...'
            )
            gradio_dreambooth_folder_creation_tab(
                train_data_dir_input=self.folders.train_data_dir,
                reg_data_dir_input=self.folders.reg_data_dir,
                output_dir_input=self.folders.output_dir,
                logging_dir_input=self.folders.logging_dir,
                headless=headless,
            )

    def save_to_json(self, filepath):
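        """Serialize this tab's attributes to a JSON file at `filepath`.

        Gradio components are written as their stored value; nested helper
        objects are serialized recursively via their __dict__.
        """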
        def serialize(obj):
            # Gradio components are represented by their stored value.
            if isinstance(obj, gr.components.Component):
                return serialize(obj.value)
            if obj is None:
                # Keep None as JSON null instead of the string 'None'.
                return None
            if isinstance(obj, (bool, int, float, str)):
                return obj
            if isinstance(obj, dict):
                return {k: serialize(v) for k, v in obj.items()}
            if hasattr(obj, "__dict__"):
                # Nested helper objects (folders, basic_training, ...) are
                # serialized attribute by attribute.
                return serialize(vars(obj))
            return str(obj)

        try:
            data = serialize(vars(self))
            print(data)
            with open(filepath, 'w') as outfile:
                json.dump(data, outfile)
        except Exception as e:
            print(f"Error saving to JSON: {e}")

    def load_from_json(self, filepath):
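        """Restore attributes from a JSON file written by save_to_json.

        Saved values are applied back onto matching attributes; unknown keys
        are reported with a warning and skipped.
        """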
        def deserialize(target, key, value):
            # Apply a saved value onto the matching attribute of `target`.
            if not hasattr(target, key):
                print(f"Warning: {key} not found in the object's attributes.")
                return
            attr = getattr(target, key)
            if isinstance(attr, gr.components.Component):
                # Update the stored component value (this does not live-update
                # an already-rendered UI).
                attr.value = value
            elif hasattr(attr, "__dict__") and isinstance(value, dict):
                # Recurse into nested helper objects such as folders or
                # basic_training.
                for k, v in value.items():
                    deserialize(attr, k, v)
            else:
                setattr(target, key, value)

        try:
            with open(filepath) as json_file:
                data = json.load(json_file)
                for key, value in data.items():
                    deserialize(self, key, value)
        except FileNotFoundError:
            print(f"Error: The file {filepath} was not found.")
        except json.JSONDecodeError:
            print(f"Error: The file {filepath} could not be decoded as JSON.")
        except Exception as e:
            print(f"Error loading from JSON: {e}")