Spaces:
Build error
Build error
import gradio as gr | |
import json | |
from .class_configuration_file import ConfigurationFile | |
from .class_source_model import SourceModel | |
from .class_folders import Folders | |
from .class_basic_training import BasicTraining | |
from .class_advanced_training import AdvancedTraining | |
from .class_sample_images import SampleImages | |
from library.dreambooth_folder_creation_gui import ( | |
gradio_dreambooth_folder_creation_tab, | |
) | |
from .common_gui import color_aug_changed | |
class Dreambooth:
    """DreamBooth training section of the kohya Gradio GUI.

    Constructing an instance creates this section's Gradio components and
    layout: a markdown header, configuration-file controls, the source-model
    controls, and the 'Folders', 'Parameters' and 'Tools' tabs. The helper
    objects for each sub-section are kept as attributes so that
    ``save_to_json`` / ``load_from_json`` can persist and restore their
    component values.

    NOTE(review): the ``gr.Tab`` / ``gr.Accordion`` context managers suggest
    this must be constructed inside an open ``gr.Blocks`` context — confirm
    against the caller.
    """

    def __init__(
        self,
        headless: bool = False,
    ):
        """Build the DreamBooth UI.

        Args:
            headless: Forwarded to the sub-sections that accept it; also
                stored on ``self.headless`` and exposed through the hidden
                ``dummy_headless`` label.
        """
        self.headless = headless
        # Hidden labels holding constant values; presumably wired as fixed
        # inputs to event handlers elsewhere — verify against callers.
        self.dummy_db_true = gr.Label(value=True, visible=False)
        self.dummy_db_false = gr.Label(value=False, visible=False)
        self.dummy_headless = gr.Label(value=headless, visible=False)

        gr.Markdown('Train a custom model using kohya dreambooth python code...')

        # Setup Configuration Files Gradio
        self.config = ConfigurationFile(headless)

        self.source_model = SourceModel(headless=headless)

        with gr.Tab('Folders'):
            self.folders = Folders(headless=headless)
        with gr.Tab('Parameters'):
            self.basic_training = BasicTraining(
                learning_rate_value='1e-5',
                lr_scheduler_value='cosine',
                lr_warmup_value='10',
            )
            with gr.Accordion('Advanced Configuration', open=False):
                self.advanced_training = AdvancedTraining(headless=headless)
                # Whenever color_aug changes, color_aug_changed computes an
                # update for basic_training.cache_latents.
                self.advanced_training.color_aug.change(
                    color_aug_changed,
                    inputs=[self.advanced_training.color_aug],
                    outputs=[self.basic_training.cache_latents],
                )

            self.sample = SampleImages()

        with gr.Tab('Tools'):
            gr.Markdown(
                'This section provide Dreambooth tools to help setup your dataset...'
            )
            # The folder-creation tool writes into the same folder textboxes
            # used by the 'Folders' tab.
            gradio_dreambooth_folder_creation_tab(
                train_data_dir_input=self.folders.train_data_dir,
                reg_data_dir_input=self.folders.reg_data_dir,
                output_dir_input=self.folders.output_dir,
                logging_dir_input=self.folders.logging_dir,
                headless=headless,
            )

    def save_to_json(self, filepath):
        """Write the state of this UI section to ``filepath`` as JSON.

        Serialization rules:
          * ``None``, bool, int, float and str are stored unchanged.
          * Gradio components are stored as their ``value`` (serialized
            recursively).
          * dicts, lists and tuples are serialized element by element.
          * Other objects with a ``__dict__`` are serialized via their
            attributes.
          * Anything else, or an object already being serialized higher up
            the current path (a reference cycle), is stored as ``str(obj)``.

        The whole document is rendered to text before the file is opened, so
        a serialization failure leaves any existing file untouched. Errors are
        reported with ``print`` rather than raised.

        Args:
            filepath: Destination path; an existing file is overwritten.
        """

        def serialize(obj, in_progress):
            # JSON-native scalars pass through unchanged.
            if obj is None or isinstance(obj, (bool, int, float, str)):
                return obj
            # Store a component's value, not the widget object itself.
            if isinstance(obj, gr.components.Component):
                return serialize(getattr(obj, "value", None), in_progress)
            # Reference cycles (e.g. parent/child links between Gradio blocks)
            # would otherwise recurse forever; fall back to a string.
            if id(obj) in in_progress:
                return str(obj)
            in_progress.add(id(obj))
            try:
                if isinstance(obj, dict):
                    return {k: serialize(v, in_progress) for k, v in obj.items()}
                if isinstance(obj, (list, tuple)):
                    return [serialize(item, in_progress) for item in obj]
                if hasattr(obj, "__dict__"):
                    return serialize(vars(obj), in_progress)
                return str(obj)  # Fallback for objects that can't be serialized
            finally:
                in_progress.discard(id(obj))

        try:
            # Render first, then open: opening with 'w' truncates the file,
            # so nothing may fail after that point except the write itself.
            text = json.dumps(serialize(vars(self), set()))
            with open(filepath, 'w', encoding='utf-8') as outfile:
                outfile.write(text)
        except Exception as e:
            print(f"Error saving to JSON: {str(e)}")

    def load_from_json(self, filepath):
        """Restore values from a JSON file such as one written by ``save_to_json``.

        Top-level keys are matched against attributes of ``self``. A nested
        JSON object is applied to the attributes of the matching helper
        object. Gradio components receive the loaded value through their
        ``value`` attribute. Other attributes are replaced with ``setattr``.
        Unknown keys, and non-object values for helper objects, produce a
        warning and are skipped. Errors are reported with ``print`` rather
        than raised.

        Args:
            filepath: Path of the JSON file to read.
        """

        def apply(target, key, value):
            if not hasattr(target, key):
                print(f"Warning: {key} not found in the object's attributes.")
                return
            attr = getattr(target, key)
            if isinstance(attr, gr.components.Component):
                # NOTE(review): this updates the component's stored value; it
                # may not refresh a UI that is already rendered — confirm.
                attr.value = value
            elif hasattr(attr, "__dict__"):
                if not isinstance(value, dict):
                    print(f"Warning: {key} expects a JSON object; value ignored.")
                    return
                # Resolve nested keys against the nested object, not `self`.
                for sub_key, sub_value in value.items():
                    apply(attr, sub_key, sub_value)
            else:
                setattr(target, key, value)

        try:
            with open(filepath, encoding='utf-8') as json_file:
                data = json.load(json_file)
            for key, value in data.items():
                apply(self, key, value)
        except FileNotFoundError:
            print(f"Error: The file {filepath} was not found.")
        except json.JSONDecodeError:
            print(f"Error: The file {filepath} could not be decoded as JSON.")
        except Exception as e:
            print(f"Error loading from JSON: {str(e)}")