""" Default values taken from https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml when applicable. """ import argparse def _get_model_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--revision", type=str, default=None, required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) parser.add_argument( "--variant", type=str, default=None, help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", ) parser.add_argument( "--cache_dir", type=str, default=None, help="The directory where the downloaded models and datasets will be stored.", ) parser.add_argument( "--cast_dit", action="store_true", help="If we should cast DiT params to a lower precision.", ) parser.add_argument( "--compile_dit", action="store_true", help="If we should compile the DiT.", ) def _get_dataset_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--data_root", type=str, default=None, help=("A folder containing the training data."), ) parser.add_argument( "--caption_dropout", type=float, default=None, help=("Probability to drop out captions randomly."), ) parser.add_argument( "--dataloader_num_workers", type=int, default=0, help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", ) parser.add_argument( "--pin_memory", action="store_true", help="Whether or not to use the pinned memory setting in pytorch dataloader.", ) def _get_validation_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--validation_prompt", type=str, default=None, help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", ) parser.add_argument( "--validation_images", type=str, default=None, help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", ) parser.add_argument( "--validation_prompt_separator", type=str, default=":::", help="String that separates multiple validation prompts", ) parser.add_argument( "--num_validation_videos", type=int, default=1, help="Number of videos that should be generated during validation per `validation_prompt`.", ) parser.add_argument( "--validation_epochs", type=int, default=50, help="Run validation every X training steps. 
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )
    parser.add_argument(
        "--enable_model_cpu_offload",
        action="store_true",
        default=False,
        help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=30,
        help="FPS to use when serializing the output videos.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Height of the output videos.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=848,
        help="Width of the output videos.",
    )


def _get_training_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--rank", type=int, default=16, help="The rank of the LoRA matrices.")
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=16,
        help="The lora_alpha used to compute the scaling factor (lora_alpha / rank) for the LoRA matrices.",
    )
    parser.add_argument(
        "--target_modules",
        nargs="+",
        type=str,
        default=["to_k", "to_q", "to_v", "to_out.0"],
        help="Target modules to train LoRA for.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="mochi-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
    )
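    # Gradient checkpointing trades compute for memory: activations are re-computed
    # during the backward pass instead of being kept resident throughout the forward pass.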
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=200,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=None,
        help="Save a checkpoint of the training state every X training steps.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Path to a checkpoint saved by `--checkpointing_steps` from which to resume training.",
    )


def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--optimizer",
        type=lambda s: s.lower(),
        default="adam",
        choices=["adam", "adamw"],
        help="The optimizer type to use.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay to use for the optimizer.",
    )


def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help="The integration to report results and logs to (e.g., 'wandb').",
    )


def get_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.")
    _get_model_args(parser)
    _get_dataset_args(parser)
    _get_training_args(parser)
    _get_validation_args(parser)
    _get_optimizer_args(parser)
    _get_configuration_args(parser)
    return parser.parse_args()
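
# A minimal usage sketch (illustrative, not part of the upstream trainer): running
# `python args.py --pretrained_model_name_or_path genmo/mochi-1-preview` parses the
# flags above and echoes every resolved value, which is handy for checking defaults.
if __name__ == "__main__":
    parsed_args = get_args()
    for name, value in sorted(vars(parsed_args).items()):
        print(f"{name}: {value}")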