from dataclasses import dataclass from typing import Optional, List @dataclass class TrainingConfig: seed: int = 42 wandb_project: str = "univa-denoiser" wandb_name: str = "default_config" output_dir: str = "./output" logging_dir: str = "./logs" gradient_accumulation_steps: int = 1 learning_rate: float = 1e-5 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-8 adam_weight_decay: float = 1e-2 mixed_precision: str = "bf16" report_to: str = "wandb" gradient_checkpointing: bool = False num_train_epochs: int = 1 max_train_steps: Optional[int] = None lr_scheduler: str = "constant" lr_warmup_steps: int = 0 lr_num_cycles: int = 1 lr_power: float = 1.0 resume_from_checkpoint: Optional[str] = None weighting_scheme: Optional[str] = ( "logit_normal" # ["sigma_sqrt", "logit_normal", "mode", "cosmap", "null"] ) logit_mean: float = 0.0 logit_std: float = 1.0 mode_scale: float = 1.29 max_grad_norm: float = 1.0 checkpointing_steps: int = 100 checkpoints_total_limit: Optional[int] = 500 drop_condition_rate: float = 0.0 drop_t5_rate: float = 1.0 validation_steps: int = 100 num_validation_images: int = 1 noise_reference_images: bool = False mask_weight_type: Optional[str] = None # ['log', 'exp'] sigmas_as_weight: bool = False # Used in Flux discrete_timestep: bool = True # Used in Flux optimizer: str = 'adamw' # ['adamw', 'prodigy'] prodigy_use_bias_correction: bool = True prodigy_safeguard_warmup: bool = True prodigy_decouple: bool = True prodigy_beta3: Optional[float] = None prodigy_d_coef: float = 1.0 profile_out_dir: Optional[str] = None ema_deepspeed_config_file: Optional[str] = None ema_update_freq: int = 1 ema_decay: float = 0.99 @dataclass class DatasetConfig: dataset_type: str data_txt: str batch_size: int = 16 num_workers: int = 4 height: int = 512 width: int = 512 min_pixels: int = 448*448 max_pixels: int = 448*448 anyres: str = 'any_1ratio' ocr_enhancer: bool = False random_data: bool = False padding_side: str = 'right' validation_t2i_prompt: Optional[str] = None validation_it2i_prompt: Optional[str] = None validation_image_path: Optional[str] = None pin_memory: bool = True validation_iit2i_prompt: Optional[str] = None validation_iit2i_path: Optional[List[str]] = None validation_REFiit2i_prompt: Optional[str] = None validation_REFiit2i_path: Optional[List[str]] = None validation_cannyt2i_prompt: Optional[str] = None validation_cannyt2i_path: Optional[str] = None validation_poset2i_prompt: Optional[str] = None validation_poset2i_path: Optional[str] = None validation_it2pose_prompt: Optional[str] = None validation_it2pose_path: Optional[str] = None validation_it2canny_prompt: Optional[str] = None validation_it2canny_path: Optional[str] = None validation_NIKEit2i_prompt: Optional[str] = None validation_NIKEit2i_path: Optional[str] = None validation_TRANSFERit2i_prompt: Optional[str] = None validation_TRANSFERit2i_path: Optional[str] = None validation_EXTRACTit2i_prompt: Optional[str] = None validation_EXTRACTit2i_path: Optional[str] = None validation_TRYONit2i_prompt: Optional[str] = None validation_TRYONit2i_path: Optional[str] = None validation_REPLACEit2i_prompt: Optional[str] = None validation_REPLACEit2i_path: Optional[str] = None validation_DETit2i_prompt: Optional[str] = None validation_DETit2i_path: Optional[str] = None validation_SEGit2i_prompt: Optional[str] = None validation_SEGit2i_path: Optional[str] = None @dataclass class ModelConfig: pretrained_lvlm_name_or_path: str pretrained_denoiser_name_or_path: str pretrained_siglip_name_or_path: Optional[str] = None train_vision_tower_mm_projector: bool = False guidance_scale: float = 1.0 # Used in Flux tune_mlp1_only: bool = False pretrained_mlp1_path: Optional[str] = None with_tune_mlp2: bool = False only_tune_mlp2: bool = False pretrained_mlp2_path: Optional[str] = None only_tune_image_branch: bool = True # Used in SD3 with_tune_mlp3: bool = False only_tune_mlp3: bool = False pretrained_mlp3_path: Optional[str] = None flux_train_layer_idx: Optional[list] = None with_tune_siglip_mlp: bool = False only_tune_siglip_mlp: bool = False pretrained_siglip_mlp_path: Optional[str] = None joint_ref_feature: bool = False joint_ref_feature_as_condition: bool = False only_use_t5: bool = False vlm_residual_image_factor: float = 0.0 vae_fp32: bool = True compile_flux: bool = False compile_qwen2p5vl: bool = False ema_pretrained_lvlm_name_or_path: Optional[str] = None @dataclass class UnivaTrainingDenoiseConfig: training_config: TrainingConfig dataset_config: DatasetConfig model_config: ModelConfig