diff --git a/finetrainers/__init__.py b/finetrainers/__init__.py deleted file mode 100644 index 7da2391e864af71edf8b826d1f1263d5c8f1afe5..0000000000000000000000000000000000000000 --- a/finetrainers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .args import BaseArgs -from .config import ModelType, TrainingType -from .logging import get_logger -from .models import ModelSpecification -from .trainer import SFTTrainer diff --git a/finetrainers/args.py b/finetrainers/args.py deleted file mode 100644 index c01736c31bcfa8479fe33ab31a3995082c301d05..0000000000000000000000000000000000000000 --- a/finetrainers/args.py +++ /dev/null @@ -1,865 +0,0 @@ -import argparse -import os -import pathlib -import sys -from typing import Any, Callable, Dict, List, Optional - -import torch - -from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType -from .logging import get_logger -from .parallel import ParallelBackendEnum -from .utils import get_non_null_items - - -logger = get_logger() - - -class BaseArgs: - r""" - The arguments for the finetrainers training script. - - For helpful information about arguments, run `python train.py --help`. - - TODO(aryan): add `python train.py --recommend_configs --model_name ` to recommend - good training configs for a model after extensive testing. - TODO(aryan): add `python train.py --memory_requirements --model_name ` to show - memory requirements per model, per training type with sensible training settings. - - PARALLEL ARGUMENTS - ------------------ - parallel_backend (`str`, defaults to `accelerate`): - The parallel backend to use for training. Choose between ['accelerate', 'ptd']. - pp_degree (`int`, defaults to `1`): - The degree of pipeline parallelism. - dp_degree (`int`, defaults to `1`): - The degree of data parallelism (number of model replicas). - dp_shards (`int`, defaults to `-1`): - The number of data parallel shards (number of model partitions). - cp_degree (`int`, defaults to `1`): - The degree of context parallelism. - - MODEL ARGUMENTS - --------------- - model_name (`str`): - Name of model to train. To get a list of models, run `python train.py --list_models`. - pretrained_model_name_or_path (`str`): - Path to pretrained model or model identifier from https://huggingface.co/models. The model should be - loadable based on specified `model_name`. - revision (`str`, defaults to `None`): - If provided, the model will be loaded from a specific branch of the model repository. - variant (`str`, defaults to `None`): - Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk - storage requirements. - cache_dir (`str`, defaults to `None`): - The directory where the downloaded models and datasets will be stored, or loaded from. - tokenizer_id (`str`, defaults to `None`): - Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. - tokenizer_2_id (`str`, defaults to `None`): - Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. - tokenizer_3_id (`str`, defaults to `None`): - Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. - text_encoder_id (`str`, defaults to `None`): - Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. 
- text_encoder_2_id (`str`, defaults to `None`): - Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. - text_encoder_3_id (`str`, defaults to `None`): - Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. - transformer_id (`str`, defaults to `None`): - Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`. - vae_id (`str`, defaults to `None`): - Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`. - text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type for the text encoder when generating text embeddings. - text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type for the text encoder 2 when generating text embeddings. - text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type for the text encoder 3 when generating text embeddings. - transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type for the transformer model. - vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`): - Data type for the VAE model. - layerwise_upcasting_modules (`List[str]`, defaults to `[]`): - Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer']. - layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`): - Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2']. - layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`): - Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when casted to fp8 precision - naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers - by default, and recommend adding more layers to the default list based on the model architecture. - - DATASET ARGUMENTS - ----------------- - dataset_config (`str`): - File to a dataset file containing information about training data. This file can contain information about one or - more datasets in JSON format. The file must have a key called "datasets", which is a list of dictionaries. Each - dictionary must contain the following keys: - - "data_root": (`str`) - The root directory containing the dataset. This parameter must be provided if `dataset_file` is not provided. - - "dataset_file": (`str`) - Path to a CSV/JSON/JSONL/PARQUET/ARROW/HF_HUB_DATASET file containing metadata for training. This parameter - must be provided if `data_root` is not provided. - - "dataset_type": (`str`) - Type of dataset. Choose between ['image', 'video']. - - "id_token": (`str`) - Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training - for single subject/concept/style training, but is not necessary. - - "image_resolution_buckets": (`List[Tuple[int, int]]`) - Resolution buckets for image. This should be a list of tuples containing 2 values, where each tuple - represents the resolution (height, width). All images will be resized to the nearest bucket resolution. - This parameter must be provided if `dataset_type` is 'image'. 
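To make the dataset schema described here concrete, below is a hedged sketch of a `dataset_config` file for an image dataset, written from Python. All paths, the trigger token, and the bucket values are invented for illustration, and the exact accepted value types (e.g. lists vs. tuples for buckets) should be checked against the repository's dataset-loading code.

```python
import json

# Hypothetical config following the documented "datasets" schema; paths are placeholders.
config = {
    "datasets": [
        {
            "data_root": "/data/my-style-images",        # alternatively, provide "dataset_file"
            "dataset_type": "image",
            "id_token": "SKS",                           # optional trigger token for LoRA-style training
            "image_resolution_buckets": [[512, 512], [768, 1024]],
            "reshape_mode": "center_crop",
            "remove_common_llm_caption_prefixes": True,
        }
    ]
}

with open("dataset_config.json", "w") as f:
    json.dump(config, f, indent=2)
```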
- - "video_resolution_buckets": (`List[Tuple[int, int, int]]`) - Resolution buckets for video. This should be a list of tuples containing 3 values, where each tuple - represents the resolution (num_frames, height, width). All videos will be resized to the nearest bucket - resolution. This parameter must be provided if `dataset_type` is 'video'. - - "reshape_mode": (`str`) - All input images/videos are reshaped using this mode. Choose between the following: - ["center_crop", "random_crop", "bicubic"]. - - "remove_common_llm_caption_prefixes": (`boolean`) - Whether or not to remove common LLM caption prefixes. See `~constants.py` for the list of common prefixes. - dataset_shuffle_buffer_size (`int`, defaults to `1`): - The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default - value of `1` means that the dataset will not be shuffled. - precomputation_items (`int`, defaults to `512`): - Number of data samples to precompute at once for memory-efficient training. The higher this value, - the more disk memory will be used to save the precomputed samples (conditions and latents). - precomputation_dir (`str`, defaults to `None`): - The directory where the precomputed samples will be stored. If not provided, the precomputed samples - will be stored in a temporary directory of the output directory. - precomputation_once (`bool`, defaults to `False`): - Precompute embeddings from all datasets at once before training. This is useful to save time during training - with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during - training when required. Make sure to set `precomputation_items` to a reasonable value in line with the size - of your dataset(s). - - DATALOADER_ARGUMENTS - -------------------- - See https://pytorch.org/docs/stable/data.html for more information. - - dataloader_num_workers (`int`, defaults to `0`): - Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking manner - on the main process. - pin_memory (`bool`, defaults to `False`): - Whether or not to use the pinned memory setting in PyTorch dataloader. This is useful for faster data loading. - - DIFFUSION ARGUMENTS - ------------------- - flow_resolution_shifting (`bool`, defaults to `False`): - Resolution-dependent shifting of timestep schedules. - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206). - TODO(aryan): We don't support this yet. - flow_base_seq_len (`int`, defaults to `256`): - Base number of tokens for images/video when applying resolution-dependent shifting. - flow_max_seq_len (`int`, defaults to `4096`): - Maximum number of tokens for images/video when applying resolution-dependent shifting. - flow_base_shift (`float`, defaults to `0.5`): - Base shift for timestep schedules when applying resolution-dependent shifting. - flow_max_shift (`float`, defaults to `1.15`): - Maximum shift for timestep schedules when applying resolution-dependent shifting. - flow_shift (`float`, defaults to `1.0`): - Instead of training with uniform/logit-normal sigmas, shift them as (shift * sigma) / (1 + (shift - 1) * sigma). - Setting it higher is helpful when trying to train models for high-resolution generation or to produce better - samples in lower number of inference steps. - flow_weighting_scheme (`str`, defaults to `none`): - We default to the "none" weighting scheme for uniform sampling and uniform loss. 
- Choose between ['sigma_sqrt', 'logit_normal', 'mode', 'cosmap', 'none']. - flow_logit_mean (`float`, defaults to `0.0`): - Mean to use when using the `'logit_normal'` weighting scheme. - flow_logit_std (`float`, defaults to `1.0`): - Standard deviation to use when using the `'logit_normal'` weighting scheme. - flow_mode_scale (`float`, defaults to `1.29`): - Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. - - TRAINING ARGUMENTS - ------------------ - training_type (`str`, defaults to `None`): - Type of training to perform. Choose between ['lora']. - seed (`int`, defaults to `42`): - A seed for reproducible training. - batch_size (`int`, defaults to `1`): - Per-device batch size. - train_steps (`int`, defaults to `1000`): - Total number of training steps to perform. - max_data_samples (`int`, defaults to `2**64`): - Maximum number of data samples observed during training training. If lesser than that required by `train_steps`, - the training will stop early. - gradient_accumulation_steps (`int`, defaults to `1`): - Number of gradients steps to accumulate before performing an optimizer step. - gradient_checkpointing (`bool`, defaults to `False`): - Whether or not to use gradient/activation checkpointing to save memory at the expense of slower - backward pass. - checkpointing_steps (`int`, defaults to `500`): - Save a checkpoint of the training state every X training steps. These checkpoints can be used both - as final checkpoints in case they are better than the last checkpoint, and are also suitable for - resuming training using `resume_from_checkpoint`. - checkpointing_limit (`int`, defaults to `None`): - Max number of checkpoints to store. - resume_from_checkpoint (`str`, defaults to `None`): - Whether training should be resumed from a previous checkpoint. Use a path saved by `checkpointing_steps`, - or `"latest"` to automatically select the last available checkpoint. - - OPTIMIZER ARGUMENTS - ------------------- - optimizer (`str`, defaults to `adamw`): - The optimizer type to use. Choose between the following: - - Torch optimizers: ["adam", "adamw"] - - Bitsandbytes optimizers: ["adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"] - lr (`float`, defaults to `1e-4`): - Initial learning rate (after the potential warmup period) to use. - lr_scheduler (`str`, defaults to `cosine_with_restarts`): - The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', - 'constant', 'constant_with_warmup']. - lr_warmup_steps (`int`, defaults to `500`): - Number of steps for the warmup in the lr scheduler. - lr_num_cycles (`int`, defaults to `1`): - Number of hard resets of the lr in cosine_with_restarts scheduler. - lr_power (`float`, defaults to `1.0`): - Power factor of the polynomial scheduler. - beta1 (`float`, defaults to `0.9`): - beta2 (`float`, defaults to `0.95`): - beta3 (`float`, defaults to `0.999`): - weight_decay (`float`, defaults to `0.0001`): - Penalty for large weights in the model. - epsilon (`float`, defaults to `1e-8`): - Small value to avoid division by zero in the optimizer. - max_grad_norm (`float`, defaults to `1.0`): - Maximum gradient norm to clip the gradients. - - VALIDATION ARGUMENTS - -------------------- - validation_dataset_file (`str`, defaults to `None`): - Path to a CSV/JSON/PARQUET/ARROW file containing information for validation. The file must contain atleast the - "caption" column. Other columns such as "image_path" and "video_path" can be provided too. 
If provided, "image_path" - will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path" - will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary. - The validation dataset file may contain other attributes specific to inference/validation such as: - - "height" and "width" and "num_frames": Resolution - - "num_inference_steps": Number of inference steps - - "guidance_scale": Classifier-free Guidance Scale - - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be - invoked with the sample dictionary to validate the sample.) - validation_steps (`int`, defaults to `500`): - Number of training steps after which a validation step is performed. - enable_model_cpu_offload (`bool`, defaults to `False`): - Whether or not to offload different modeling components to CPU during validation. - - MISCELLANEOUS ARGUMENTS - ----------------------- - tracker_name (`str`, defaults to `finetrainers`): - Name of the tracker/project to use for logging training metrics. - push_to_hub (`bool`, defaults to `False`): - Whether or not to push the model to the Hugging Face Hub. - hub_token (`str`, defaults to `None`): - The API token to use for pushing the model to the Hugging Face Hub. - hub_model_id (`str`, defaults to `None`): - The model identifier to use for pushing the model to the Hugging Face Hub. - output_dir (`str`, defaults to `None`): - The directory where the model checkpoints and logs will be stored. - logging_dir (`str`, defaults to `logs`): - The directory where the logs will be stored. - logging_steps (`int`, defaults to `1`): - Training logs will be tracked every `logging_steps` steps. - allow_tf32 (`bool`, defaults to `False`): - Whether or not to allow the use of TF32 matmul on compatible hardware. - nccl_timeout (`int`, defaults to `1800`): - Timeout for the NCCL communication. - report_to (`str`, defaults to `wandb`): - The name of the logger to use for logging training metrics. Choose between ['wandb']. - verbose (`int`, defaults to `1`): - Whether or not to print verbose logs. 
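Circling back to `validation_dataset_file` documented above: a hedged sketch of such a file, with the required "caption" column plus a few of the optional inference attributes mentioned there. The on-disk layout (a flat JSON list of records here) and all values are assumptions for illustration only.

```python
import json

# Hypothetical validation records; only "caption" is required, the rest are optional per the docstring.
rows = [
    {
        "caption": "A red panda climbing a snowy tree",
        "height": 480,
        "width": 768,
        "num_frames": 49,
        "num_inference_steps": 30,
        "guidance_scale": 5.0,
    },
    {"caption": "A busy city street at dusk", "image_path": "validation/street.png"},
]

with open("validation_dataset.json", "w") as f:
    json.dump(rows, f, indent=2)
```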
- - 0: Diffusers/Transformers warning logging on local main process only - - 1: Diffusers/Transformers info logging on local main process only - - 2: Diffusers/Transformers debug logging on local main process only - - 3: Diffusers/Transformers debug logging on all processes - """ - - # Parallel arguments - parallel_backend = ParallelBackendEnum.ACCELERATE - pp_degree: int = 1 - dp_degree: int = 1 - dp_shards: int = 1 - cp_degree: int = 1 - tp_degree: int = 1 - - # Model arguments - model_name: str = None - pretrained_model_name_or_path: str = None - revision: Optional[str] = None - variant: Optional[str] = None - cache_dir: Optional[str] = None - tokenizer_id: Optional[str] = None - tokenizer_2_id: Optional[str] = None - tokenizer_3_id: Optional[str] = None - text_encoder_id: Optional[str] = None - text_encoder_2_id: Optional[str] = None - text_encoder_3_id: Optional[str] = None - transformer_id: Optional[str] = None - vae_id: Optional[str] = None - text_encoder_dtype: torch.dtype = torch.bfloat16 - text_encoder_2_dtype: torch.dtype = torch.bfloat16 - text_encoder_3_dtype: torch.dtype = torch.bfloat16 - transformer_dtype: torch.dtype = torch.bfloat16 - vae_dtype: torch.dtype = torch.bfloat16 - layerwise_upcasting_modules: List[str] = [] - layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn - layerwise_upcasting_skip_modules_pattern: List[str] = [ - "patch_embed", - "pos_embed", - "x_embedder", - "context_embedder", - "time_embed", - "^proj_in$", - "^proj_out$", - "norm", - ] - - # Dataset arguments - dataset_config: str = None - dataset_shuffle_buffer_size: int = 1 - enable_precomputation: bool = False - precomputation_items: int = 512 - precomputation_dir: Optional[str] = None - precomputation_once: bool = False - - # Dataloader arguments - dataloader_num_workers: int = 0 - pin_memory: bool = False - - # Diffusion arguments - flow_resolution_shifting: bool = False - flow_base_seq_len: int = 256 - flow_max_seq_len: int = 4096 - flow_base_shift: float = 0.5 - flow_max_shift: float = 1.15 - flow_shift: float = 1.0 - flow_weighting_scheme: str = "none" - flow_logit_mean: float = 0.0 - flow_logit_std: float = 1.0 - flow_mode_scale: float = 1.29 - - # Training arguments - training_type: str = None - seed: int = 42 - batch_size: int = 1 - train_steps: int = 1000 - max_data_samples: int = 2**64 - gradient_accumulation_steps: int = 1 - gradient_checkpointing: bool = False - checkpointing_steps: int = 500 - checkpointing_limit: Optional[int] = None - resume_from_checkpoint: Optional[str] = None - enable_slicing: bool = False - enable_tiling: bool = False - - # Optimizer arguments - optimizer: str = "adamw" - lr: float = 1e-4 - lr_scheduler: str = "cosine_with_restarts" - lr_warmup_steps: int = 0 - lr_num_cycles: int = 1 - lr_power: float = 1.0 - beta1: float = 0.9 - beta2: float = 0.95 - beta3: float = 0.999 - weight_decay: float = 0.0001 - epsilon: float = 1e-8 - max_grad_norm: float = 1.0 - - # Validation arguments - validation_dataset_file: Optional[str] = None - validation_steps: int = 500 - enable_model_cpu_offload: bool = False - - # Miscellaneous arguments - tracker_name: str = "finetrainers" - push_to_hub: bool = False - hub_token: Optional[str] = None - hub_model_id: Optional[str] = None - output_dir: str = None - logging_dir: Optional[str] = "logs" - logging_steps: int = 1 - allow_tf32: bool = False - init_timeout: int = 300 # 5 minutes - nccl_timeout: int = 600 # 10 minutes, considering that validation may be performed - report_to: str = "wandb" - verbose: int = 1 - - 
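The `flow_shift` behaviour quoted in the docstring above — "shift them as (shift * sigma) / (1 + (shift - 1) * sigma)" — is easy to sanity-check in isolation. A minimal sketch of that resampling (not the trainer's actual code path) follows; note that `shift == 1.0` is the identity and larger shifts push sigmas toward 1, i.e. toward noisier timesteps.

```python
import torch

def shift_sigmas(sigmas: torch.Tensor, shift: float) -> torch.Tensor:
    # (shift * sigma) / (1 + (shift - 1) * sigma): identity for shift == 1.0.
    return (shift * sigmas) / (1 + (shift - 1) * sigmas)

sigmas = torch.rand(4)                  # e.g. uniformly sampled sigmas in [0, 1)
print(shift_sigmas(sigmas, shift=3.0))  # values are pulled toward 1.0
```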
def to_dict(self) -> Dict[str, Any]: - parallel_arguments = { - "pp_degree": self.pp_degree, - "dp_degree": self.dp_degree, - "dp_shards": self.dp_shards, - "cp_degree": self.cp_degree, - "tp_degree": self.tp_degree, - } - - model_arguments = { - "model_name": self.model_name, - "pretrained_model_name_or_path": self.pretrained_model_name_or_path, - "revision": self.revision, - "variant": self.variant, - "cache_dir": self.cache_dir, - "tokenizer_id": self.tokenizer_id, - "tokenizer_2_id": self.tokenizer_2_id, - "tokenizer_3_id": self.tokenizer_3_id, - "text_encoder_id": self.text_encoder_id, - "text_encoder_2_id": self.text_encoder_2_id, - "text_encoder_3_id": self.text_encoder_3_id, - "transformer_id": self.transformer_id, - "vae_id": self.vae_id, - "text_encoder_dtype": self.text_encoder_dtype, - "text_encoder_2_dtype": self.text_encoder_2_dtype, - "text_encoder_3_dtype": self.text_encoder_3_dtype, - "transformer_dtype": self.transformer_dtype, - "vae_dtype": self.vae_dtype, - "layerwise_upcasting_modules": self.layerwise_upcasting_modules, - "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype, - "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern, - } - model_arguments = get_non_null_items(model_arguments) - - dataset_arguments = { - "dataset_config": self.dataset_config, - "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size, - "enable_precomputation": self.enable_precomputation, - "precomputation_items": self.precomputation_items, - "precomputation_dir": self.precomputation_dir, - "precomputation_once": self.precomputation_once, - } - dataset_arguments = get_non_null_items(dataset_arguments) - - dataloader_arguments = { - "dataloader_num_workers": self.dataloader_num_workers, - "pin_memory": self.pin_memory, - } - - diffusion_arguments = { - "flow_resolution_shifting": self.flow_resolution_shifting, - "flow_base_seq_len": self.flow_base_seq_len, - "flow_max_seq_len": self.flow_max_seq_len, - "flow_base_shift": self.flow_base_shift, - "flow_max_shift": self.flow_max_shift, - "flow_shift": self.flow_shift, - "flow_weighting_scheme": self.flow_weighting_scheme, - "flow_logit_mean": self.flow_logit_mean, - "flow_logit_std": self.flow_logit_std, - "flow_mode_scale": self.flow_mode_scale, - } - - training_arguments = { - "training_type": self.training_type, - "seed": self.seed, - "batch_size": self.batch_size, - "train_steps": self.train_steps, - "max_data_samples": self.max_data_samples, - "gradient_accumulation_steps": self.gradient_accumulation_steps, - "gradient_checkpointing": self.gradient_checkpointing, - "checkpointing_steps": self.checkpointing_steps, - "checkpointing_limit": self.checkpointing_limit, - "resume_from_checkpoint": self.resume_from_checkpoint, - "enable_slicing": self.enable_slicing, - "enable_tiling": self.enable_tiling, - } - training_arguments = get_non_null_items(training_arguments) - - optimizer_arguments = { - "optimizer": self.optimizer, - "lr": self.lr, - "lr_scheduler": self.lr_scheduler, - "lr_warmup_steps": self.lr_warmup_steps, - "lr_num_cycles": self.lr_num_cycles, - "lr_power": self.lr_power, - "beta1": self.beta1, - "beta2": self.beta2, - "beta3": self.beta3, - "weight_decay": self.weight_decay, - "epsilon": self.epsilon, - "max_grad_norm": self.max_grad_norm, - } - optimizer_arguments = get_non_null_items(optimizer_arguments) - - validation_arguments = { - "validation_dataset_file": self.validation_dataset_file, - "validation_steps": self.validation_steps, - 
"enable_model_cpu_offload": self.enable_model_cpu_offload, - } - validation_arguments = get_non_null_items(validation_arguments) - - miscellaneous_arguments = { - "tracker_name": self.tracker_name, - "push_to_hub": self.push_to_hub, - "hub_token": self.hub_token, - "hub_model_id": self.hub_model_id, - "output_dir": self.output_dir, - "logging_dir": self.logging_dir, - "logging_steps": self.logging_steps, - "allow_tf32": self.allow_tf32, - "init_timeout": self.init_timeout, - "nccl_timeout": self.nccl_timeout, - "report_to": self.report_to, - "verbose": self.verbose, - } - miscellaneous_arguments = get_non_null_items(miscellaneous_arguments) - - return { - "parallel_arguments": parallel_arguments, - "model_arguments": model_arguments, - "dataset_arguments": dataset_arguments, - "dataloader_arguments": dataloader_arguments, - "diffusion_arguments": diffusion_arguments, - "training_arguments": training_arguments, - "optimizer_arguments": optimizer_arguments, - "validation_arguments": validation_arguments, - "miscellaneous_arguments": miscellaneous_arguments, - } - - def extend_args( - self, - add_fn: Callable[[argparse.ArgumentParser], None], - map_fn: Callable[["BaseArgs"], None], - validate_fn: Callable[["BaseArgs"], None], - ) -> None: - if not hasattr(self, "_extended_add_arguments"): - self._extended_add_arguments = [] - self._extended_add_arguments.append((add_fn, validate_fn, map_fn)) - - def parse_args(self): - _LIST_MODELS = "--list_models" - - parser = argparse.ArgumentParser() - - special_args = [_LIST_MODELS] - if any(arg in sys.argv for arg in special_args): - _add_helper_arguments(parser) - args = parser.parse_args() - _display_helper_messages(args) - sys.exit(0) - else: - _add_args(parser) - for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): - add_fn, _, _ = extended_add_arg_fns - add_fn(parser) - - args, remaining_args = parser.parse_known_args() - logger.debug(f"Remaining unparsed arguments: {remaining_args}") - - mapped_args = _map_to_args_type(args) - for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): - _, _, map_fn = extended_add_arg_fns - map_fn(args, mapped_args) - - _validate_args(mapped_args) - for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): - _, validate_fn, _ = extended_add_arg_fns - validate_fn(mapped_args) - - return mapped_args - - -def _add_args(parser: argparse.ArgumentParser) -> None: - _add_parallel_arguments(parser) - _add_model_arguments(parser) - _add_dataset_arguments(parser) - _add_dataloader_arguments(parser) - _add_diffusion_arguments(parser) - _add_training_arguments(parser) - _add_optimizer_arguments(parser) - _add_validation_arguments(parser) - _add_miscellaneous_arguments(parser) - - -def _validate_args(args: BaseArgs): - _validate_model_args(args) - _validate_dataset_args(args) - _validate_validation_args(args) - - -def _add_parallel_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--parallel_backend", - type=str, - default=ParallelBackendEnum.ACCELERATE, - choices=[ParallelBackendEnum.ACCELERATE, ParallelBackendEnum.PTD], - ) - parser.add_argument("--pp_degree", type=int, default=1) - parser.add_argument("--dp_degree", type=int, default=1) - parser.add_argument("--dp_shards", type=int, default=1) - parser.add_argument("--cp_degree", type=int, default=1) - parser.add_argument("--tp_degree", type=int, default=1) - - -def _add_model_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--model_name", type=str, required=True, 
choices=[x.value for x in ModelType.__members__.values()] - ) - parser.add_argument("--pretrained_model_name_or_path", type=str, required=True) - parser.add_argument("--revision", type=str, default=None, required=False) - parser.add_argument("--variant", type=str, default=None) - parser.add_argument("--cache_dir", type=str, default=None) - parser.add_argument("--tokenizer_id", type=str, default=None) - parser.add_argument("--tokenizer_2_id", type=str, default=None) - parser.add_argument("--tokenizer_3_id", type=str, default=None) - parser.add_argument("--text_encoder_id", type=str, default=None) - parser.add_argument("--text_encoder_2_id", type=str, default=None) - parser.add_argument("--text_encoder_3_id", type=str, default=None) - parser.add_argument("--transformer_id", type=str, default=None) - parser.add_argument("--vae_id", type=str, default=None) - parser.add_argument("--text_encoder_dtype", type=str, default="bf16") - parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16") - parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16") - parser.add_argument("--transformer_dtype", type=str, default="bf16") - parser.add_argument("--vae_dtype", type=str, default="bf16") - parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"]) - parser.add_argument( - "--layerwise_upcasting_storage_dtype", - type=str, - default="float8_e4m3fn", - choices=["float8_e4m3fn", "float8_e5m2"], - ) - parser.add_argument( - "--layerwise_upcasting_skip_modules_pattern", - type=str, - default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"], - nargs="+", - ) - - -def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--dataset_config", type=str, required=True) - parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1) - parser.add_argument("--enable_precomputation", action="store_true") - parser.add_argument("--precomputation_items", type=int, default=512) - parser.add_argument("--precomputation_dir", type=str, default=None) - parser.add_argument("--precomputation_once", action="store_true") - - -def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--dataloader_num_workers", type=int, default=0) - parser.add_argument("--pin_memory", action="store_true") - - -def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--flow_resolution_shifting", action="store_true") - parser.add_argument("--flow_base_seq_len", type=int, default=256) - parser.add_argument("--flow_max_seq_len", type=int, default=4096) - parser.add_argument("--flow_base_shift", type=float, default=0.5) - parser.add_argument("--flow_max_shift", type=float, default=1.15) - parser.add_argument("--flow_shift", type=float, default=1.0) - parser.add_argument( - "--flow_weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - ) - parser.add_argument("--flow_logit_mean", type=float, default=0.0) - parser.add_argument("--flow_logit_std", type=float, default=1.0) - parser.add_argument("--flow_mode_scale", type=float, default=1.29) - - -def _add_training_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--training_type", type=str, choices=[x.value for x in TrainingType.__members__.values()], required=True - ) - parser.add_argument("--seed", type=int, default=None) - parser.add_argument("--batch_size", type=int, 
default=1) - parser.add_argument("--train_steps", type=int, default=1000) - parser.add_argument("--max_data_samples", type=int, default=2**64) - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--gradient_checkpointing", action="store_true") - parser.add_argument("--checkpointing_steps", type=int, default=500) - parser.add_argument("--checkpointing_limit", type=int, default=None) - parser.add_argument("--resume_from_checkpoint", type=str, default=None) - parser.add_argument("--enable_slicing", action="store_true") - parser.add_argument("--enable_tiling", action="store_true") - - -def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--lr_scheduler", type=str, default="constant") - parser.add_argument("--lr_warmup_steps", type=int, default=500) - parser.add_argument("--lr_num_cycles", type=int, default=1) - parser.add_argument("--lr_power", type=float, default=1.0) - parser.add_argument( - "--optimizer", - type=lambda s: s.lower(), - default="adam", - choices=["adam", "adamw", "adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"], - ) - parser.add_argument("--beta1", type=float, default=0.9) - parser.add_argument("--beta2", type=float, default=0.95) - parser.add_argument("--beta3", type=float, default=None) - parser.add_argument("--weight_decay", type=float, default=1e-04) - parser.add_argument("--epsilon", type=float, default=1e-8) - parser.add_argument("--max_grad_norm", default=1.0, type=float) - - -def _add_validation_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--validation_dataset_file", type=str, default=None) - parser.add_argument("--validation_steps", type=int, default=500) - parser.add_argument("--enable_model_cpu_offload", action="store_true") - - -def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--tracker_name", type=str, default="finetrainers") - parser.add_argument("--push_to_hub", action="store_true") - parser.add_argument("--hub_token", type=str, default=None) - parser.add_argument("--hub_model_id", type=str, default=None) - parser.add_argument("--output_dir", type=str, default="finetrainers-training") - parser.add_argument("--logging_dir", type=str, default="logs") - parser.add_argument("--logging_steps", type=int, default=1) - parser.add_argument("--allow_tf32", action="store_true") - parser.add_argument("--init_timeout", type=int, default=300) - parser.add_argument("--nccl_timeout", type=int, default=600) - parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"]) - parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3]) - - -def _add_helper_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--list_models", action="store_true") - - -_DTYPE_MAP = { - "bf16": torch.bfloat16, - "fp16": torch.float16, - "fp32": torch.float32, - "float8_e4m3fn": torch.float8_e4m3fn, - "float8_e5m2": torch.float8_e5m2, -} - - -def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs: - result_args = BaseArgs() - - # Parallel arguments - result_args.parallel_backend = args.parallel_backend - result_args.pp_degree = args.pp_degree - result_args.dp_degree = args.dp_degree - result_args.dp_shards = args.dp_shards - result_args.cp_degree = args.cp_degree - result_args.tp_degree = args.tp_degree - - # Model arguments - result_args.model_name = args.model_name - result_args.pretrained_model_name_or_path = 
args.pretrained_model_name_or_path - result_args.revision = args.revision - result_args.variant = args.variant - result_args.cache_dir = args.cache_dir - result_args.tokenizer_id = args.tokenizer_id - result_args.tokenizer_2_id = args.tokenizer_2_id - result_args.tokenizer_3_id = args.tokenizer_3_id - result_args.text_encoder_id = args.text_encoder_id - result_args.text_encoder_2_id = args.text_encoder_2_id - result_args.text_encoder_3_id = args.text_encoder_3_id - result_args.transformer_id = args.transformer_id - result_args.vae_id = args.vae_id - result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype] - result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype] - result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype] - result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype] - result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype] - result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules - result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype] - result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern - - # Dataset arguments - result_args.dataset_config = args.dataset_config - result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size - result_args.enable_precomputation = args.enable_precomputation - result_args.precomputation_items = args.precomputation_items - result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed") - result_args.precomputation_once = args.precomputation_once - - # Dataloader arguments - result_args.dataloader_num_workers = args.dataloader_num_workers - result_args.pin_memory = args.pin_memory - - # Diffusion arguments - result_args.flow_resolution_shifting = args.flow_resolution_shifting - result_args.flow_base_seq_len = args.flow_base_seq_len - result_args.flow_max_seq_len = args.flow_max_seq_len - result_args.flow_base_shift = args.flow_base_shift - result_args.flow_max_shift = args.flow_max_shift - result_args.flow_shift = args.flow_shift - result_args.flow_weighting_scheme = args.flow_weighting_scheme - result_args.flow_logit_mean = args.flow_logit_mean - result_args.flow_logit_std = args.flow_logit_std - result_args.flow_mode_scale = args.flow_mode_scale - - # Training arguments - result_args.training_type = args.training_type - result_args.seed = args.seed - result_args.batch_size = args.batch_size - result_args.train_steps = args.train_steps - result_args.max_data_samples = args.max_data_samples - result_args.gradient_accumulation_steps = args.gradient_accumulation_steps - result_args.gradient_checkpointing = args.gradient_checkpointing - result_args.checkpointing_steps = args.checkpointing_steps - result_args.checkpointing_limit = args.checkpointing_limit - result_args.resume_from_checkpoint = args.resume_from_checkpoint - result_args.enable_slicing = args.enable_slicing - result_args.enable_tiling = args.enable_tiling - - # Optimizer arguments - result_args.optimizer = args.optimizer or "adamw" - result_args.lr = args.lr or 1e-4 - result_args.lr_scheduler = args.lr_scheduler - result_args.lr_warmup_steps = args.lr_warmup_steps - result_args.lr_num_cycles = args.lr_num_cycles - result_args.lr_power = args.lr_power - result_args.beta1 = args.beta1 - result_args.beta2 = args.beta2 - result_args.beta3 = args.beta3 - result_args.weight_decay = args.weight_decay - result_args.epsilon = args.epsilon - result_args.max_grad_norm = args.max_grad_norm 
- - # Validation arguments - result_args.validation_dataset_file = args.validation_dataset_file - result_args.validation_steps = args.validation_steps - result_args.enable_model_cpu_offload = args.enable_model_cpu_offload - - # Miscellaneous arguments - result_args.tracker_name = args.tracker_name - result_args.push_to_hub = args.push_to_hub - result_args.hub_token = args.hub_token - result_args.hub_model_id = args.hub_model_id - result_args.output_dir = args.output_dir - result_args.logging_dir = args.logging_dir - result_args.logging_steps = args.logging_steps - result_args.allow_tf32 = args.allow_tf32 - result_args.init_timeout = args.init_timeout - result_args.nccl_timeout = args.nccl_timeout - result_args.report_to = args.report_to - result_args.verbose = args.verbose - - return result_args - - -def _validate_model_args(args: BaseArgs): - if args.training_type == "full-finetune": - assert ( - "transformer" not in args.layerwise_upcasting_modules - ), "Layerwise upcasting is not supported for full-finetune training" - - -def _validate_dataset_args(args: BaseArgs): - dataset_config = pathlib.Path(args.dataset_config) - if not dataset_config.exists(): - raise ValueError(f"Dataset config file {args.dataset_config} does not exist.") - if args.dataset_shuffle_buffer_size < 1: - raise ValueError("Dataset shuffle buffer size must be greater than 0.") - if args.precomputation_items < 1: - raise ValueError("Precomputation items must be greater than 0.") - - -def _validate_validation_args(args: BaseArgs): - if args.enable_model_cpu_offload: - if any(x > 1 for x in [args.pp_degree, args.dp_degree, args.dp_shards, args.cp_degree, args.tp_degree]): - raise ValueError("Model CPU offload is not supported on multi-GPU at the moment.") - - -def _display_helper_messages(args: argparse.Namespace): - if args.list_models: - print("Supported models:") - for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()): - print(f" {index + 1}. 
{model_name}") diff --git a/finetrainers/config.py b/finetrainers/config.py deleted file mode 100644 index 0c9d0f71c682d71413d3029edf40930382d3fcfa..0000000000000000000000000000000000000000 --- a/finetrainers/config.py +++ /dev/null @@ -1,58 +0,0 @@ -from enum import Enum -from typing import Type - -from .models import ModelSpecification -from .models.cogvideox import CogVideoXModelSpecification -from .models.cogview4 import CogView4ModelSpecification -from .models.hunyuan_video import HunyuanVideoModelSpecification -from .models.ltx_video import LTXVideoModelSpecification -from .models.wan import WanModelSpecification - - -class ModelType(str, Enum): - COGVIDEOX = "cogvideox" - COGVIEW4 = "cogview4" - HUNYUAN_VIDEO = "hunyuan_video" - LTX_VIDEO = "ltx_video" - WAN = "wan" - - -class TrainingType(str, Enum): - LORA = "lora" - FULL_FINETUNE = "full-finetune" - - -SUPPORTED_MODEL_CONFIGS = { - ModelType.COGVIDEOX: { - TrainingType.LORA: CogVideoXModelSpecification, - TrainingType.FULL_FINETUNE: CogVideoXModelSpecification, - }, - ModelType.COGVIEW4: { - TrainingType.LORA: CogView4ModelSpecification, - TrainingType.FULL_FINETUNE: CogView4ModelSpecification, - }, - ModelType.HUNYUAN_VIDEO: { - TrainingType.LORA: HunyuanVideoModelSpecification, - TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification, - }, - ModelType.LTX_VIDEO: { - TrainingType.LORA: LTXVideoModelSpecification, - TrainingType.FULL_FINETUNE: LTXVideoModelSpecification, - }, - ModelType.WAN: { - TrainingType.LORA: WanModelSpecification, - TrainingType.FULL_FINETUNE: WanModelSpecification, - }, -} - - -def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]: - if model_name not in SUPPORTED_MODEL_CONFIGS: - raise ValueError( - f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" - ) - if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]: - raise ValueError( - f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" - ) - return SUPPORTED_MODEL_CONFIGS[model_name][training_type] diff --git a/finetrainers/constants.py b/finetrainers/constants.py deleted file mode 100644 index 693495ca0e91617dce6c35583b2e1a18c9025708..0000000000000000000000000000000000000000 --- a/finetrainers/constants.py +++ /dev/null @@ -1,83 +0,0 @@ -import os - - -DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] -DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] -DEFAULT_FRAME_BUCKETS = [49] - -DEFAULT_IMAGE_RESOLUTION_BUCKETS = [] -for height in DEFAULT_HEIGHT_BUCKETS: - for width in DEFAULT_WIDTH_BUCKETS: - DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width)) - -DEFAULT_VIDEO_RESOLUTION_BUCKETS = [] -for frames in DEFAULT_FRAME_BUCKETS: - for height in DEFAULT_HEIGHT_BUCKETS: - for width in DEFAULT_WIDTH_BUCKETS: - DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width)) - - -FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO") - -PRECOMPUTED_DIR_NAME = "precomputed" -PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions" -PRECOMPUTED_LATENTS_DIR_NAME = "latents" - -MODEL_DESCRIPTION = r""" -\# {model_id} {training_type} finetune - - - -\#\# Model Description - -This model is a {training_type} of the `{model_id}` model. 
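For reference, resolving a model specification from the `SUPPORTED_MODEL_CONFIGS` table in `config.py` above is a plain nested-dictionary lookup. A hedged usage sketch against the module as shown in this diff (assuming `finetrainers` is importable at that revision):

```python
from finetrainers.config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType

# Specification class registered for Wan LoRA training; full finetuning maps to the same class.
spec_cls = SUPPORTED_MODEL_CONFIGS[ModelType.WAN][TrainingType.LORA]
print(spec_cls.__name__)  # WanModelSpecification
```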
- -This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers). - -\#\# Download model - -[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. - -\#\# Usage - -Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed. - -```python -{model_example} -``` - -For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. - -\#\# License - -Please adhere to the license of the base model. -""".strip() - -_COMMON_BEGINNING_PHRASES = ( - "This video", - "The video", - "This clip", - "The clip", - "The animation", - "This image", - "The image", - "This picture", - "The picture", -) -_COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents") - -COMMON_LLM_START_PHRASES = ( - "In the video,", - "In this video,", - "In this video clip,", - "In the clip,", - "Caption:", - *( - f"{beginning} {continuation}" - for beginning in _COMMON_BEGINNING_PHRASES - for continuation in _COMMON_CONTINUATION_WORDS - ), -) - -SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png") -SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov") diff --git a/finetrainers/data/__init__.py b/finetrainers/data/__init__.py deleted file mode 100644 index ba6e1f6df867f44e59c16035651b120ccaa027a7..0000000000000000000000000000000000000000 --- a/finetrainers/data/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from ._artifact import ImageArtifact, VideoArtifact -from .dataloader import DPDataLoader -from .dataset import ( - ImageCaptionFilePairDataset, - ImageFileCaptionFileListDataset, - ImageFolderDataset, - ImageWebDataset, - ValidationDataset, - VideoCaptionFilePairDataset, - VideoFileCaptionFileListDataset, - VideoFolderDataset, - VideoWebDataset, - combine_datasets, - initialize_dataset, - wrap_iterable_dataset_for_preprocessing, -) -from .precomputation import ( - InMemoryDataIterable, - InMemoryDistributedDataPreprocessor, - InMemoryOnceDataIterable, - PrecomputedDataIterable, - PrecomputedDistributedDataPreprocessor, - PrecomputedOnceDataIterable, - initialize_preprocessor, -) -from .sampler import ResolutionSampler -from .utils import find_files diff --git a/finetrainers/data/_artifact.py b/finetrainers/data/_artifact.py deleted file mode 100644 index 400f25d143f5062d77ed6391ca9862654d295de7..0000000000000000000000000000000000000000 --- a/finetrainers/data/_artifact.py +++ /dev/null @@ -1,29 +0,0 @@ -# ===== THIS FILE ONLY EXISTS FOR THE TIME BEING SINCE I DID NOT KNOW WHERE TO PUT IT ===== - -from dataclasses import dataclass -from typing import Any, List - -from PIL.Image import Image - - -@dataclass -class Artifact: - type: str - value: Any - file_extension: str - - -@dataclass -class ImageArtifact(Artifact): - value: Image - - def __init__(self, value: Image): - super().__init__(type="image", value=value, file_extension="png") - - -@dataclass -class VideoArtifact(Artifact): - value: List[Image] - - def __init__(self, value: List[Image]): - super().__init__(type="video", value=value, file_extension="mp4") diff --git a/finetrainers/data/dataloader.py b/finetrainers/data/dataloader.py deleted file mode 100644 index a8b0a4b1f6253bd943cf9fe7b9e31c06aa060b35..0000000000000000000000000000000000000000 --- a/finetrainers/data/dataloader.py +++ /dev/null @@ -1,40 +0,0 @@ -import pickle 
-from typing import Any, Dict - -import torch.distributed.checkpoint.stateful -import torchdata.stateful_dataloader - -from ..logging import get_logger - - -logger = get_logger() - - -class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful): - def __init__( - self, - rank: int, - dataset: torch.utils.data.IterableDataset, - batch_size: int = 1, - num_workers: int = 0, - collate_fn=None, - ) -> None: - super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn) - - self._dp_rank = rank - self._rank_id = f"dp_rank_{rank}" - - def state_dict(self) -> Dict[str, Any]: - # Store state only for dp rank to avoid replicating the same state across other dimensions - return {self._rank_id: pickle.dumps(super().state_dict())} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - # State being empty is valid - if not state_dict: - return - - if self._rank_id not in state_dict: - logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}") - return - - super().load_state_dict(pickle.loads(state_dict[self._rank_id])) diff --git a/finetrainers/data/dataset.py b/finetrainers/data/dataset.py deleted file mode 100644 index 3cd34e4c3f07f6423172524ac01209183bbb14fe..0000000000000000000000000000000000000000 --- a/finetrainers/data/dataset.py +++ /dev/null @@ -1,978 +0,0 @@ -import pathlib -import random -from typing import Any, Dict, List, Optional, Tuple, Union - -import datasets -import datasets.data_files -import datasets.distributed -import datasets.exceptions -import huggingface_hub -import huggingface_hub.errors -import numpy as np -import PIL.Image -import torch -import torch.distributed.checkpoint.stateful -from diffusers.utils import load_image, load_video -from huggingface_hub import list_repo_files, repo_exists, snapshot_download -from tqdm.auto import tqdm - -from .. import constants -from .. import functional as FF -from ..logging import get_logger -from . 
import utils - - -import decord # isort:skip - -decord.bridge.set_bridge("torch") - -logger = get_logger() - - -# fmt: off -MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024 -COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"] -COMMON_VIDEO_FILES = ["video.txt", "videos.txt"] -COMMON_IMAGE_FILES = ["image.txt", "images.txt"] -COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"] -# fmt: on - - -class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): - def __init__(self, root: str, infinite: bool = False) -> None: - super().__init__() - - self.root = pathlib.Path(root) - self.infinite = infinite - - data = [] - caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0)) - for caption_file in caption_files: - data_file = self._find_data_file(caption_file) - if data_file: - data.append( - { - "caption": (self.root / caption_file).as_posix(), - "image": (self.root / data_file).as_posix(), - } - ) - - data = datasets.Dataset.from_list(data) - data = data.cast_column("image", datasets.Image(mode="RGB")) - - self._data = data.to_iterable_dataset() - self._sample_index = 0 - self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - sample["caption"] = _read_caption_from_file(sample["caption"]) - sample["image"] = _preprocess_image(sample["image"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") - break - else: - self._sample_index = 0 - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - def _find_data_file(self, caption_file: str) -> str: - caption_file = pathlib.Path(caption_file) - data_file = None - found_data = 0 - - for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS: - image_filename = caption_file.with_suffix(f".{extension}") - if image_filename.exists(): - found_data += 1 - data_file = image_filename - - if found_data == 0: - return False - elif found_data > 1: - raise ValueError( - f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data " - f"file per caption file. 
The following extensions are supported:\n" - f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n" - ) - - return data_file.as_posix() - - -class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): - def __init__(self, root: str, infinite: bool = False) -> None: - super().__init__() - - self.root = pathlib.Path(root) - self.infinite = infinite - - data = [] - caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0)) - for caption_file in caption_files: - data_file = self._find_data_file(caption_file) - if data_file: - data.append( - { - "caption": (self.root / caption_file).as_posix(), - "video": (self.root / data_file).as_posix(), - } - ) - - data = datasets.Dataset.from_list(data) - data = data.cast_column("video", datasets.Video()) - - self._data = data.to_iterable_dataset() - self._sample_index = 0 - self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - sample["caption"] = _read_caption_from_file(sample["caption"]) - sample["video"] = _preprocess_video(sample["video"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") - break - else: - self._sample_index = 0 - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - def _find_data_file(self, caption_file: str) -> str: - caption_file = pathlib.Path(caption_file) - data_file = None - found_data = 0 - - for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS: - video_filename = caption_file.with_suffix(f".{extension}") - if video_filename.exists(): - found_data += 1 - data_file = video_filename - - if found_data == 0: - return False - elif found_data > 1: - raise ValueError( - f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data " - f"file per caption file. The following extensions are supported:\n" - f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n" - ) - - return data_file.as_posix() - - -class ImageFileCaptionFileListDataset( - torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful -): - def __init__(self, root: str, infinite: bool = False) -> None: - super().__init__() - - VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"] - VALID_IMAGE_FILES = ["image.txt", "images.txt"] - - self.root = pathlib.Path(root) - self.infinite = infinite - - data = [] - existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()] - existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()] - - if len(existing_caption_files) == 0: - raise FileNotFoundError( - f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" - ) - if len(existing_image_files) == 0: - raise FileNotFoundError( - f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}" - ) - if len(existing_caption_files) > 1: - raise ValueError( - f"Multiple caption files found in {self.root}. 
Must have exactly one of {VALID_CAPTION_FILES}" - ) - if len(existing_image_files) > 1: - raise ValueError( - f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}" - ) - - caption_file = existing_caption_files[0] - image_file = existing_image_files[0] - - with open((self.root / caption_file).as_posix(), "r") as f: - captions = f.read().splitlines() - with open((self.root / image_file).as_posix(), "r") as f: - images = f.read().splitlines() - images = [(self.root / image).as_posix() for image in images] - - if len(captions) != len(images): - raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})") - - for caption, image in zip(captions, images): - data.append({"caption": caption, "image": image}) - - data = datasets.Dataset.from_list(data) - data = data.cast_column("image", datasets.Image(mode="RGB")) - - self._data = data.to_iterable_dataset() - self._sample_index = 0 - self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - sample["image"] = _preprocess_image(sample["image"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") - break - else: - self._sample_index = 0 - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - -class VideoFileCaptionFileListDataset( - torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful -): - def __init__(self, root: str, infinite: bool = False) -> None: - super().__init__() - - VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"] - VALID_VIDEO_FILES = ["video.txt", "videos.txt"] - - self.root = pathlib.Path(root) - self.infinite = infinite - - data = [] - existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()] - existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()] - - if len(existing_caption_files) == 0: - raise FileNotFoundError( - f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" - ) - if len(existing_video_files) == 0: - raise FileNotFoundError( - f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}" - ) - if len(existing_caption_files) > 1: - raise ValueError( - f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" - ) - if len(existing_video_files) > 1: - raise ValueError( - f"Multiple video files found in {self.root}. 
Must have exactly one of {VALID_VIDEO_FILES}" - ) - - caption_file = existing_caption_files[0] - video_file = existing_video_files[0] - - with open((self.root / caption_file).as_posix(), "r") as f: - captions = f.read().splitlines() - with open((self.root / video_file).as_posix(), "r") as f: - videos = f.read().splitlines() - videos = [(self.root / video).as_posix() for video in videos] - - if len(captions) != len(videos): - raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})") - - for caption, video in zip(captions, videos): - data.append({"caption": caption, "video": video}) - - data = datasets.Dataset.from_list(data) - data = data.cast_column("video", datasets.Video()) - - self._data = data.to_iterable_dataset() - self._sample_index = 0 - self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - sample["video"] = _preprocess_video(sample["video"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") - break - else: - self._sample_index = 0 - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - -class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): - def __init__(self, root: str, infinite: bool = False) -> None: - super().__init__() - - self.root = pathlib.Path(root) - self.infinite = infinite - - data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train") - - self._data = data.to_iterable_dataset() - self._sample_index = 0 - self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - sample["image"] = _preprocess_image(sample["image"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") - break - else: - self._sample_index = 0 - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - -class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): - def __init__(self, root: str, infinite: bool = False) -> None: - super().__init__() - - self.root = pathlib.Path(root) - self.infinite = infinite - - data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train") - - self._data = data.to_iterable_dataset() - self._sample_index = 0 - self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - sample["video"] = _preprocess_video(sample["video"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") - 
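- # Finite datasets stop after a single pass; with `infinite=True` the sample index is reset below and iteration restarts from the beginning of the dataset.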
break - else: - self._sample_index = 0 - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - -class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): - def __init__( - self, - dataset_name: str, - infinite: bool = False, - column_names: Union[str, List[str]] = "__auto__", - weights: Dict[str, float] = -1, - **kwargs, - ) -> None: - super().__init__() - - assert weights == -1 or isinstance( - weights, dict - ), "`weights` must be a dictionary of probabilities for each caption column" - - self.dataset_name = dataset_name - self.infinite = infinite - - data = datasets.load_dataset(dataset_name, split="train", streaming=True) - - if column_names == "__auto__": - if weights == -1: - caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES] - if len(caption_columns) == 0: - raise ValueError( - f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}" - ) - weights = [1] * len(caption_columns) - else: - caption_columns = list(weights.keys()) - weights = list(weights.values()) - if not all(column in data.column_names for column in caption_columns): - raise ValueError( - f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}" - ) - else: - if isinstance(column_names, str): - if column_names not in data.column_names: - raise ValueError( - f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}" - ) - caption_columns = [column_names] - weights = [1] if weights == -1 else [weights.get(column_names)] - elif isinstance(column_names, list): - if not all(column in data.column_names for column in column_names): - raise ValueError( - f"Caption columns {column_names} not found in the dataset. 
Available columns are: {data.column_names}" - ) - caption_columns = column_names - weights = [1] if weights == -1 else [weights.get(column) for column in column_names] - else: - raise ValueError(f"Unsupported type for column_name: {type(column_names)}") - - for column_names in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS: - if column_names in data.column_names: - data = data.cast_column(column_names, datasets.Image(mode="RGB")) - data = data.rename_column(column_names, "image") - break - - self._data = data - self._sample_index = 0 - self._precomputable_once = False - self._caption_columns = caption_columns - self._weights = weights - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0] - sample["caption"] = sample[caption_column] - sample["image"] = _preprocess_image(sample["image"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset {self.dataset_name} has run out of data") - break - else: - # Reset offset for the next iteration - self._sample_index = 0 - logger.warning(f"Dataset {self.dataset_name} is being re-looped") - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - -class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): - def __init__( - self, - dataset_name: str, - infinite: bool = False, - column_names: Union[str, List[str]] = "__auto__", - weights: Dict[str, float] = -1, - **kwargs, - ) -> None: - super().__init__() - - assert weights == -1 or isinstance( - weights, dict - ), "`weights` must be a dictionary of probabilities for each caption column" - - self.dataset_name = dataset_name - self.infinite = infinite - - data = datasets.load_dataset(dataset_name, split="train", streaming=True) - - if column_names == "__auto__": - if weights == -1: - caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES] - if len(caption_columns) == 0: - raise ValueError( - f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}" - ) - weights = [1] * len(caption_columns) - else: - caption_columns = list(weights.keys()) - weights = list(weights.values()) - if not all(column in data.column_names for column in caption_columns): - raise ValueError( - f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}" - ) - else: - if isinstance(column_names, str): - if column_names not in data.column_names: - raise ValueError( - f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}" - ) - caption_columns = [column_names] - weights = [1] if weights == -1 else [weights.get(column_names)] - elif isinstance(column_names, list): - if not all(column in data.column_names for column in column_names): - raise ValueError( - f"Caption columns {column_names} not found in the dataset. 
Available columns are: {data.column_names}" - ) - caption_columns = column_names - weights = [1] if weights == -1 else [weights.get(column) for column in column_names] - else: - raise ValueError(f"Unsupported type for column_name: {type(column_names)}") - - for column_names in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS: - if column_names in data.column_names: - data = data.cast_column(column_names, datasets.Video()) - data = data.rename_column(column_names, "video") - break - - self._data = data - self._sample_index = 0 - self._precomputable_once = False - self._caption_columns = caption_columns - self._weights = weights - - def _get_data_iter(self): - if self._sample_index == 0: - return iter(self._data) - return iter(self._data.skip(self._sample_index)) - - def __iter__(self): - while True: - for sample in self._get_data_iter(): - self._sample_index += 1 - caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0] - sample["caption"] = sample[caption_column] - sample["video"] = _preprocess_video(sample["video"]) - yield sample - - if not self.infinite: - logger.warning(f"Dataset {self.dataset_name} has run out of data") - break - else: - # Reset offset for the next iteration - self._sample_index = 0 - logger.warning(f"Dataset {self.dataset_name} is being re-looped") - - def load_state_dict(self, state_dict): - self._sample_index = state_dict["sample_index"] - - def state_dict(self): - return {"sample_index": self._sample_index} - - -class ValidationDataset(torch.utils.data.IterableDataset): - def __init__(self, filename: str): - super().__init__() - - self.filename = pathlib.Path(filename) - - if not self.filename.exists(): - raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist") - - if self.filename.suffix == ".csv": - data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train") - elif self.filename.suffix == ".json": - data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data") - elif self.filename.suffix == ".parquet": - data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train") - elif self.filename.suffix == ".arrow": - data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train") - else: - _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"] - raise ValueError( - f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}" - ) - - self._data = data.to_iterable_dataset() - - def __iter__(self): - for sample in self._data: - # For consistency reasons, we mandate that "caption" is always present in the validation dataset. - # However, since the model specifications use "prompt", we create an alias here. 
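- # Illustrative example (hypothetical filenames): a CSV row with columns `caption,image_path` such as - #   "a photo of a cat",assets/cat.png - # is yielded as {"caption": ..., "prompt": ..., "image": <PIL image loaded from image_path>}; keys whose value is None are dropped.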
- sample["prompt"] = sample["caption"] - - # Load image or video if the path is provided - # TODO(aryan): need to handle custom columns here for control conditions - sample["image"] = None - sample["video"] = None - - if sample.get("image_path", None) is not None: - image_path = pathlib.Path(sample["image_path"]) - if not image_path.is_file(): - logger.warning(f"Image file {image_path.as_posix()} does not exist.") - else: - sample["image"] = load_image(sample["image_path"]) - - if sample.get("video_path", None) is not None: - video_path = pathlib.Path(sample["video_path"]) - if not video_path.is_file(): - logger.warning(f"Video file {video_path.as_posix()} does not exist.") - else: - sample["video"] = load_video(sample["video_path"]) - - sample = {k: v for k, v in sample.items() if v is not None} - yield sample - - -class IterableDatasetPreprocessingWrapper( - torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful -): - def __init__( - self, - dataset: torch.utils.data.IterableDataset, - dataset_type: str, - id_token: Optional[str] = None, - image_resolution_buckets: List[Tuple[int, int]] = None, - video_resolution_buckets: List[Tuple[int, int, int]] = None, - reshape_mode: str = "bicubic", - remove_common_llm_caption_prefixes: bool = False, - **kwargs, - ): - super().__init__() - - self.dataset = dataset - self.dataset_type = dataset_type - self.id_token = id_token - self.image_resolution_buckets = image_resolution_buckets - self.video_resolution_buckets = video_resolution_buckets - self.reshape_mode = reshape_mode - self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes - - logger.info( - f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n" - f" - Dataset Type: {dataset_type}\n" - f" - ID Token: {id_token}\n" - f" - Image Resolution Buckets: {image_resolution_buckets}\n" - f" - Video Resolution Buckets: {video_resolution_buckets}\n" - f" - Reshape Mode: {reshape_mode}\n" - f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n" - ) - - def __iter__(self): - logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset") - for sample in iter(self.dataset): - if self.dataset_type == "image": - if self.image_resolution_buckets: - sample["_original_num_frames"] = 1 - sample["_original_height"] = sample["image"].size(1) - sample["_original_width"] = sample["image"].size(2) - sample["image"] = FF.resize_to_nearest_bucket_image( - sample["image"], self.image_resolution_buckets, self.reshape_mode - ) - elif self.dataset_type == "video": - if self.video_resolution_buckets: - sample["_original_num_frames"] = sample["video"].size(0) - sample["_original_height"] = sample["video"].size(2) - sample["_original_width"] = sample["video"].size(3) - sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video( - sample["video"], self.video_resolution_buckets, self.reshape_mode - ) - if _first_frame_only: - msg = ( - "The number of frames in the video is less than the minimum bucket size " - "specified. The first frame is being used as a single frame video. This " - "message is logged at the first occurence and for every 128th occurence " - "after that." 
- ) - logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128) - sample["video"] = sample["video"][0] - - if self.remove_common_llm_caption_prefixes: - sample["caption"] = FF.remove_prefix(sample["caption"], constants.COMMON_LLM_START_PHRASES) - - if self.id_token is not None: - sample["caption"] = f"{self.id_token} {sample['caption']}" - - yield sample - - def load_state_dict(self, state_dict): - self.dataset.load_state_dict(state_dict["dataset"]) - - def state_dict(self): - return {"dataset": self.dataset.state_dict()} - - -class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): - def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False): - super().__init__() - - self.datasets = datasets - self.buffer_size = buffer_size - self.shuffle = shuffle - - logger.info( - f"Initializing IterableCombinedDataset with the following configuration:\n" - f" - Number of Datasets: {len(datasets)}\n" - f" - Buffer Size: {buffer_size}\n" - f" - Shuffle: {shuffle}\n" - ) - - def __iter__(self): - logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets") - iterators = [iter(dataset) for dataset in self.datasets] - buffer = [] - per_iter = max(1, self.buffer_size // len(iterators)) - - for index, it in enumerate(iterators): - for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"): - try: - buffer.append((it, next(it))) - except StopIteration: - continue - - while len(buffer) > 0: - idx = 0 - if self.shuffle: - idx = random.randint(0, len(buffer) - 1) - current_it, sample = buffer.pop(idx) - yield sample - try: - buffer.append((current_it, next(current_it))) - except StopIteration: - pass - - def load_state_dict(self, state_dict): - for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]): - dataset.load_state_dict(dataset_state_dict) - - def state_dict(self): - return {"datasets": [dataset.state_dict() for dataset in self.datasets]} - - -# TODO(aryan): maybe write a test for this -def initialize_dataset( - dataset_name_or_root: str, - dataset_type: str = "video", - streaming: bool = True, - infinite: bool = False, - *, - _caption_options: Optional[Dict[str, Any]] = None, -) -> torch.utils.data.IterableDataset: - assert dataset_type in ["image", "video"] - - try: - does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset") - except huggingface_hub.errors.HFValidationError: - does_repo_exist_on_hub = False - - if does_repo_exist_on_hub: - return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options) - else: - return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite) - - -def combine_datasets( - datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False -) -> torch.utils.data.IterableDataset: - return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle) - - -def wrap_iterable_dataset_for_preprocessing( - dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any] -) -> torch.utils.data.IterableDataset: - return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config) - - -def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False): - root = pathlib.Path(dataset_name_or_root) - supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"] - metadata_files = [root 
/ metadata_file for metadata_file in supported_metadata_files] - metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()] - - if len(metadata_files) > 1: - raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.") - - if len(metadata_files) == 1: - if dataset_type == "image": - dataset = ImageFolderDataset(root.as_posix(), infinite=infinite) - else: - dataset = VideoFolderDataset(root.as_posix(), infinite=infinite) - return dataset - - if _has_data_caption_file_pairs(root, remote=False): - if dataset_type == "image": - dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite) - else: - dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite) - elif _has_data_file_caption_file_lists(root, remote=False): - if dataset_type == "image": - dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite) - else: - dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite) - else: - raise ValueError( - f"Could not find any supported dataset structure in the directory {root}. Please open an issue at " - f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will " - f"help you set it up." - ) - - return dataset - - -def _initialize_hub_dataset( - dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None -): - repo_file_list = list_repo_files(dataset_name, repo_type="dataset") - if _has_data_caption_file_pairs(repo_file_list, remote=True): - return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) - elif _has_data_file_caption_file_lists(repo_file_list, remote=True): - return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) - - has_tar_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list) - if has_tar_files: - return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options) - - # TODO(aryan): This should be improved - caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")] - if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT: - try: - dataset_root = snapshot_download(dataset_name, repo_type="dataset") - if dataset_type == "image": - dataset = ImageFolderDataset(dataset_root, infinite=infinite) - else: - dataset = VideoFolderDataset(dataset_root, infinite=infinite) - return dataset - except Exception: - pass - - raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub") - - -def _initialize_data_caption_file_dataset_from_hub( - dataset_name: str, dataset_type: str, infinite: bool = False -) -> torch.utils.data.IterableDataset: - logger.info(f"Downloading dataset {dataset_name} from the HF Hub") - dataset_root = snapshot_download(dataset_name, repo_type="dataset") - if dataset_type == "image": - return ImageCaptionFilePairDataset(dataset_root, infinite=infinite) - else: - return VideoCaptionFilePairDataset(dataset_root, infinite=infinite) - - -def _initialize_data_file_caption_file_dataset_from_hub( - dataset_name: str, dataset_type: str, infinite: bool = False -) -> torch.utils.data.IterableDataset: - logger.info(f"Downloading dataset {dataset_name} from the HF Hub") - dataset_root = snapshot_download(dataset_name, repo_type="dataset") - if dataset_type == "image": - return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite) - else: - return 
VideoFileCaptionFileListDataset(dataset_root, infinite=infinite) - - -def _initialize_webdataset( - dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None -) -> torch.utils.data.IterableDataset: - logger.info(f"Streaming webdataset {dataset_name} from the HF Hub") - _caption_options = _caption_options or {} - if dataset_type == "image": - return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options) - else: - return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options) - - -def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool: - # TODO(aryan): this logic can be improved - if not remote: - caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0) - for caption_file in caption_files: - caption_file = pathlib.Path(caption_file) - for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]: - data_filename = caption_file.with_suffix(f".{extension}") - if data_filename.exists(): - return True - return False - else: - caption_files = [file for file in root if file.endswith(".txt")] - for caption_file in caption_files: - caption_file = pathlib.Path(caption_file) - for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]: - data_filename = caption_file.with_suffix(f".{extension}").name - if data_filename in root: - return True - return False - - -def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool: - # TODO(aryan): this logic can be improved - if not remote: - file_list = {x.name for x in root.iterdir()} - has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES) - has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES) - has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES) - return has_caption_files and (has_video_files or has_image_files) - else: - has_caption_files = any(file in root for file in COMMON_CAPTION_FILES) - has_video_files = any(file in root for file in COMMON_VIDEO_FILES) - has_image_files = any(file in root for file in COMMON_IMAGE_FILES) - return has_caption_files and (has_video_files or has_image_files) - - -def _read_caption_from_file(filename: str) -> str: - with open(filename, "r") as f: - return f.read().strip() - - -def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor: - image = image.convert("RGB") - image = np.array(image).astype(np.float32) - image = torch.from_numpy(image) - image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0 - return image - - -def _preprocess_video(video: decord.VideoReader) -> torch.Tensor: - video = video.get_batch(list(range(len(video)))) - video = video.permute(0, 3, 1, 2).contiguous() - video = video.float() / 127.5 - 1.0 - return video diff --git a/finetrainers/data/precomputation.py b/finetrainers/data/precomputation.py deleted file mode 100644 index b325a648a2f35ab68537409d02b7505ff23c7a06..0000000000000000000000000000000000000000 --- a/finetrainers/data/precomputation.py +++ /dev/null @@ -1,376 +0,0 @@ -import pathlib -from typing import Any, Callable, Dict, Iterable, List, Optional, Union - -import torch -from tqdm.auto import tqdm - -from .. 
import utils -from ..logging import get_logger - - -logger = get_logger() - - -def initialize_preprocessor( - rank: int, - num_items: int, - processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]], - save_dir: Optional[str] = None, - enable_precomputation: bool = False, -) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]: - if enable_precomputation: - return PrecomputedDistributedDataPreprocessor(rank, num_items, processor_fn, save_dir) - return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn) - - -class DistributedDataProcessorMixin: - def consume(self, *args, **kwargs): - raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.") - - def consume_once(self, *args, **kwargs): - raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.") - - @property - def requires_data(self): - raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.") - - -class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin): - def __init__( - self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] - ) -> None: - super().__init__() - - self._rank = rank - self._num_items = num_items - self._processor_fn = processor_fn - - self._cached_samples = [] - self._buffer = InMemoryDataBuffer(num_items) - self._preprocessed_iterator: Union["InMemoryDataIterable", "InMemoryOnceDataIterable"] = None - - def consume( - self, - data_type: str, - components: Dict[str, Any], - data_iterator, - generator: Optional[torch.Generator] = None, - cache_samples: bool = False, - use_cached_samples: bool = False, - drop_samples: bool = False, - ) -> Iterable[Dict[str, Any]]: - if data_type not in self._processor_fn.keys(): - raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}") - if cache_samples: - if use_cached_samples: - raise ValueError("Cannot cache and use cached samples at the same time.") - if drop_samples: - raise ValueError("Cannot cache and drop samples at the same time.") - - for i in range(self._num_items): - if use_cached_samples: - item = self._cached_samples[i] - else: - item = next(data_iterator) - if cache_samples: - self._cached_samples.append(item) - item = self._processor_fn[data_type](**item, **components, generator=generator) - self._buffer.add(data_type, item) - - if drop_samples: - del self._cached_samples - self._cached_samples = [] - - self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer) - return iter(self._preprocessed_iterator) - - def consume_once( - self, - data_type: str, - components: Dict[str, Any], - data_iterator, - generator: Optional[torch.Generator] = None, - cache_samples: bool = False, - use_cached_samples: bool = False, - drop_samples: bool = False, - ) -> Iterable[Dict[str, Any]]: - if data_type not in self._processor_fn.keys(): - raise ValueError(f"Invalid data type: {data_type}. 
Supported types: {list(self._processor_fn.keys())}") - if cache_samples: - if use_cached_samples: - raise ValueError("Cannot cache and use cached samples at the same time.") - if drop_samples: - raise ValueError("Cannot cache and drop samples at the same time.") - - for i in range(self._num_items): - if use_cached_samples: - item = self._cached_samples[i] - else: - item = next(data_iterator) - if cache_samples: - self._cached_samples.append(item) - item = self._processor_fn[data_type](**item, **components, generator=generator) - self._buffer.add(data_type, item) - - if drop_samples: - del self._cached_samples - self._cached_samples = [] - - self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer) - return iter(self._preprocessed_iterator) - - @property - def requires_data(self): - if self._preprocessed_iterator is None: - return True - return self._preprocessed_iterator.requires_data - - -class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin): - def __init__( - self, - rank: int, - num_items: int, - processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]], - save_dir: str, - ) -> None: - super().__init__() - - self._rank = rank - self._num_items = num_items - self._processor_fn = processor_fn - self._save_dir = pathlib.Path(save_dir) - - self._cached_samples = [] - self._preprocessed_iterator: Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"] = None - - self._save_dir.mkdir(parents=True, exist_ok=True) - - subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()] - utils.delete_files(subdirectories) - - def consume( - self, - data_type: str, - components: Dict[str, Any], - data_iterator, - generator: Optional[torch.Generator] = None, - cache_samples: bool = False, - use_cached_samples: bool = False, - drop_samples: bool = False, - ) -> Iterable[Dict[str, Any]]: - if data_type not in self._processor_fn.keys(): - raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}") - if cache_samples: - if use_cached_samples: - raise ValueError("Cannot cache and use cached samples at the same time.") - if drop_samples: - raise ValueError("Cannot cache and drop samples at the same time.") - - for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items): - if use_cached_samples: - item = self._cached_samples[i] - else: - item = next(data_iterator) - if cache_samples: - self._cached_samples.append(item) - item = self._processor_fn[data_type](**item, **components, generator=generator) - _save_item(self._rank, i, item, self._save_dir, data_type) - - if drop_samples: - del self._cached_samples - self._cached_samples = [] - - self._preprocessed_iterator = PrecomputedDataIterable(self._rank, self._save_dir, data_type) - return iter(self._preprocessed_iterator) - - def consume_once( - self, - data_type: str, - components: Dict[str, Any], - data_iterator, - generator: Optional[torch.Generator] = None, - cache_samples: bool = False, - use_cached_samples: bool = False, - drop_samples: bool = False, - ) -> Iterable[Dict[str, Any]]: - if data_type not in self._processor_fn.keys(): - raise ValueError(f"Invalid data type: {data_type}. 
Supported types: {list(self._processor_fn.keys())}") - if cache_samples: - if use_cached_samples: - raise ValueError("Cannot cache and use cached samples at the same time.") - if drop_samples: - raise ValueError("Cannot cache and drop samples at the same time.") - - for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items): - if use_cached_samples: - item = self._cached_samples[i] - else: - item = next(data_iterator) - if cache_samples: - self._cached_samples.append(item) - item = self._processor_fn[data_type](**item, **components, generator=generator) - _save_item(self._rank, i, item, self._save_dir, data_type) - - if drop_samples: - del self._cached_samples - self._cached_samples = [] - - self._preprocessed_iterator = PrecomputedOnceDataIterable(self._rank, self._save_dir, data_type) - return iter(self._preprocessed_iterator) - - @property - def requires_data(self): - if self._preprocessed_iterator is None: - return True - return self._preprocessed_iterator.requires_data - - -class InMemoryDataIterable: - """ - An iterator that loads data items from an in-memory buffer. Once all the data is consumed, - `requires_data` is set to True, indicating that the more data is required and the preprocessor's - consume method should be called again. - """ - - def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None: - self._rank = rank - self._data_type = data_type - self._buffer = buffer - - self._requires_data = False - - def __iter__(self) -> Iterable[Dict[str, Any]]: - while (length := self._buffer.get_length(self._data_type)) > 0: - if length <= 1: - self._requires_data = True - yield self._buffer.get(self._data_type) - - def __len__(self) -> int: - return self._buffer.get_length(self._data_type) - - @property - def requires_data(self): - return self._requires_data - - -class InMemoryOnceDataIterable: - """ - An iterator that loads data items from an in-memory buffer. This iterator will never set - `requires_data` to True, as it is assumed that all the data was configured to be preprocessed - by the user. The data will indefinitely be cycled from the buffer. - """ - - def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None: - self._rank = rank - self._data_type = data_type - self._buffer = buffer - - self._requires_data = False - - def __iter__(self) -> Iterable[Dict[str, Any]]: - assert len(self) > 0, "No data available in the buffer." - while True: - item = self._buffer.get(self._data_type) - yield item - self._buffer.add(self._data_type, item) - - def __len__(self) -> int: - return self._buffer.get_length(self._data_type) - - @property - def requires_data(self): - return self._requires_data - - -class PrecomputedDataIterable: - """ - An iterator that loads preconfigured number of data items from disk. Once all the data is - loaded, `requires_data` is set to True, indicating that the more data is required and - the preprocessor's consume method should be called again. 
- """ - - def __init__(self, rank: int, save_dir: str, data_type: str) -> None: - self._rank = rank - self._save_dir = pathlib.Path(save_dir) - self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt"))) - self._data_type = data_type - - self._requires_data = False - - def __iter__(self) -> Iterable[Dict[str, Any]]: - for i in range(self._num_items): - if i == self._num_items - 1: - self._requires_data = True - yield _load_item(self._rank, i, self._save_dir, self._data_type) - - def __len__(self) -> int: - return self._num_items - - @property - def requires_data(self): - return self._requires_data - - -class PrecomputedOnceDataIterable: - """ - An infinite iterator that loads preprocessed data from disk. Once initialized, this iterator - will never set `requires_data` to True, as it is assumed that all the data was configured to - be preprocessed by the user. - """ - - def __init__(self, rank: int, save_dir: str, data_type: str) -> None: - self._rank = rank - self._save_dir = pathlib.Path(save_dir) - self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt"))) - self._data_type = data_type - - self._requires_data = False - - def __iter__(self) -> Iterable[Dict[str, Any]]: - index = 0 - while True: - yield _load_item(self._rank, index, self._save_dir, self._data_type) - index = (index + 1) % self._num_items - - def __len__(self) -> int: - return self._num_items - - @property - def requires_data(self): - return self._requires_data - - -class InMemoryDataBuffer: - def __init__(self, max_limit: int = -1) -> None: - self.max_limit = max_limit - self.buffer: Dict[str, List[str]] = {} - - def add(self, data_type: str, item: Dict[str, Any]) -> None: - if data_type not in self.buffer: - self.buffer[data_type] = [] - if self.max_limit != -1 and len(self.buffer[data_type]) >= self.max_limit: - logger.log_freq( - "WARN", - "IN_MEMORY_DATA_BUFFER_FULL", - "Buffer is full. Dropping the oldest item. 
This message will be logged every 64th time this happens.", - 64, - ) - self.buffer[data_type].pop(0) - self.buffer[data_type].append(item) - - def get(self, data_type: str) -> Dict[str, Any]: - return self.buffer[data_type].pop(0) - - def get_length(self, data_type: str) -> int: - return len(self.buffer[data_type]) - - -def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None: - filename = directory / f"{data_type}-{rank}-{index}.pt" - torch.save(item, filename.as_posix()) - - -def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]: - filename = directory / f"{data_type}-{rank}-{index}.pt" - return torch.load(filename.as_posix(), weights_only=True) diff --git a/finetrainers/data/sampler.py b/finetrainers/data/sampler.py deleted file mode 100644 index 5d9d650e1d610e8ce91b4168a9960479cfcfe8f7..0000000000000000000000000000000000000000 --- a/finetrainers/data/sampler.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Any, Dict, List, Tuple - -import torch - - -class ResolutionSampler: - def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None: - self.batch_size = batch_size - self.dim_keys = dim_keys - assert dim_keys is not None, "dim_keys must be provided" - - self._chosen_leader_key = None - self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {} - self._satisfied_buckets: List[Dict[Any, Any]] = [] - - def consume(self, *dict_items: Dict[Any, Any]) -> None: - if self._chosen_leader_key is None: - self._determine_leader_item(*dict_items) - self._update_buckets(*dict_items) - - def get_batch(self) -> List[Dict[str, Any]]: - return list(zip(*self._satisfied_buckets.pop(-1))) - - @property - def is_ready(self) -> bool: - return len(self._satisfied_buckets) > 0 - - def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None: - num_observed = 0 - for dict_item in dict_items: - for key in self.dim_keys.keys(): - if key in dict_item.keys(): - self._chosen_leader_key = key - if not torch.is_tensor(dict_item[key]): - raise ValueError(f"Leader key {key} must be a tensor") - num_observed += 1 - if num_observed > 1: - raise ValueError( - f"Only one leader key is allowed in provided list of data dictionaries. 
Found {num_observed} leader keys" - ) - if self._chosen_leader_key is None: - raise ValueError("No leader key found in provided list of data dictionaries") - - def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None: - chosen_value = [ - dict_item[self._chosen_leader_key] - for dict_item in dict_items - if self._chosen_leader_key in dict_item.keys() - ] - if len(chosen_value) == 0: - raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries") - chosen_value = chosen_value[0] - dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key]) - if dims not in self._unsatisfied_buckets: - self._unsatisfied_buckets[dims] = [] - self._unsatisfied_buckets[dims].append(dict_items) - if len(self._unsatisfied_buckets[dims]) == self.batch_size: - self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims)) diff --git a/finetrainers/data/utils.py b/finetrainers/data/utils.py deleted file mode 100644 index 4bd507f348efc0e123532e8082502c7cfad956ed..0000000000000000000000000000000000000000 --- a/finetrainers/data/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import pathlib -from typing import List - - -def find_files(root: str, pattern: str, depth: int = 0) -> List[str]: - root_path = pathlib.Path(root) - result_files = [] - - def within_depth(path: pathlib.Path) -> bool: - return len(path.relative_to(root_path).parts) <= depth - - if depth == 0: - result_files.extend([str(file) for file in root_path.glob(pattern)]) - else: - # rglob matches all levels, but we filter by depth - for file in root_path.rglob(pattern): - if file.is_file() and within_depth(file.parent): - result_files.append(str(file)) - - return result_files diff --git a/finetrainers/functional/__init__.py b/finetrainers/functional/__init__.py deleted file mode 100644 index a62a87847ac0e61521a2331ec3b9ea08cbd49abb..0000000000000000000000000000000000000000 --- a/finetrainers/functional/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .diffusion import flow_match_target, flow_match_xt -from .image import ( - bicubic_resize_image, - center_crop_image, - find_nearest_resolution_image, - resize_crop_image, - resize_to_nearest_bucket_image, -) -from .text import dropout_caption, dropout_embeddings_to_zero, remove_prefix -from .video import ( - bicubic_resize_video, - center_crop_video, - find_nearest_video_resolution, - resize_crop_video, - resize_to_nearest_bucket_video, -) diff --git a/finetrainers/functional/diffusion.py b/finetrainers/functional/diffusion.py deleted file mode 100644 index f9d553895c2fb251abf80f01f284049acf84f87d..0000000000000000000000000000000000000000 --- a/finetrainers/functional/diffusion.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - - -def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - r"""Forward process of flow matching.""" - return (1.0 - t) * x0 + t * n - - -def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: - r"""Loss target for flow matching.""" - return n - x0 diff --git a/finetrainers/functional/image.py b/finetrainers/functional/image.py deleted file mode 100644 index 8d966625ad50a8ca8f5d55e5c522ae5a31652113..0000000000000000000000000000000000000000 --- a/finetrainers/functional/image.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import List, Literal, Tuple - -import torch -import torch.nn.functional as F - - -def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - num_channels, height, width = image.shape - crop_h, crop_w = size - top = (height - 
crop_h) // 2 - left = (width - crop_w) // 2 - return image[:, top : top + crop_h, left : left + crop_w] - - -def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - num_channels, height, width = image.shape - target_h, target_w = size - scale = max(target_h / height, target_w / width) - new_h, new_w = int(height * scale), int(width * scale) - image = F.interpolate(image, size=(new_h, new_w), mode="bilinear", align_corners=False) - return center_crop_image(image, size) - - -def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0] - - -def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]: - num_channels, height, width = image.shape - aspect_ratio = width / height - - def aspect_ratio_diff(bucket): - return abs((bucket[1] / bucket[0]) - aspect_ratio) - - return min(resolution_buckets, key=aspect_ratio_diff) - - -def resize_to_nearest_bucket_image( - image: torch.Tensor, - resolution_buckets: List[Tuple[int, int]], - resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic", -) -> torch.Tensor: - target_size = find_nearest_resolution_image(image, resolution_buckets) - - if resize_mode == "center_crop": - return center_crop_image(image, target_size) - elif resize_mode == "resize_crop": - return resize_crop_image(image, target_size) - elif resize_mode == "bicubic": - return bicubic_resize_image(image, target_size) - else: - raise ValueError( - f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'." - ) diff --git a/finetrainers/functional/text.py b/finetrainers/functional/text.py deleted file mode 100644 index 6e823edfc2e3f4a93d2afddf6df71a4198f05219..0000000000000000000000000000000000000000 --- a/finetrainers/functional/text.py +++ /dev/null @@ -1,26 +0,0 @@ -import random -from typing import List, Union - -import torch - - -def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]: - if random.random() >= dropout_p: - return caption - if isinstance(caption, str): - return "" - return [""] * len(caption) - - -def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor: - if random.random() >= dropout_p: - return embed - embed = torch.zeros_like(embed) - return embed - - -def remove_prefix(text: str, prefixes: List[str]) -> str: - for prefix in prefixes: - if text.startswith(prefix): - return text.removeprefix(prefix).strip() - return text diff --git a/finetrainers/functional/video.py b/finetrainers/functional/video.py deleted file mode 100644 index fcbc382b0615e53270f3b17746fec14f438ddd16..0000000000000000000000000000000000000000 --- a/finetrainers/functional/video.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import List, Literal, Tuple - -import torch -import torch.nn.functional as F - - -def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - num_frames, num_channels, height, width = video.shape - crop_h, crop_w = size - top = (height - crop_h) // 2 - left = (width - crop_w) // 2 - return video[:, :, top : top + crop_h, left : left + crop_w] - - -def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - num_frames, num_channels, height, width = video.shape - target_h, target_w = size - scale = max(target_h / height, target_w / width) - new_h, new_w = int(height * scale), int(width * scale) - video = 
F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False) - return center_crop_video(video, size) - - - def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - num_frames, num_channels, height, width = video.shape - video = F.interpolate(video, size=size, mode="bicubic", align_corners=False) - return video - - - def find_nearest_video_resolution( - video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]] - ) -> Tuple[int, int, int]: - num_frames, num_channels, height, width = video.shape - aspect_ratio = width / height - possible_buckets = [b for b in resolution_buckets if b[0] <= num_frames] - - if not possible_buckets: - best_frame_match = min(resolution_buckets, key=lambda b: abs(b[0] - num_frames)) - else: - best_frame_match = max(possible_buckets, key=lambda b: b[0]) - - frame_filtered_buckets = [b for b in resolution_buckets if b[0] == best_frame_match[0]] - - def aspect_ratio_diff(bucket): - return abs((bucket[2] / bucket[1]) - aspect_ratio) - - return min(frame_filtered_buckets, key=aspect_ratio_diff) - - - def resize_to_nearest_bucket_video( - video: torch.Tensor, - resolution_buckets: List[Tuple[int, int, int]], - resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic", - ) -> Tuple[torch.Tensor, bool]: - """ - Resizes a video tensor to the nearest resolution bucket using the specified mode. - - It first finds a frame match with <= T frames. - - Then, it selects the closest height/width bucket. - - Args: - video (`torch.Tensor`): - Input video tensor of shape `(T, C, H, W)`. - resolution_buckets (`List[Tuple[int, int, int]]`): - Available (num_frames, height, width) resolution buckets. - resize_mode (`str`): - One of ["center_crop", "resize_crop", "bicubic"]. - - Returns: - `Tuple[torch.Tensor, bool]`: - Resized video tensor of the nearest bucket resolution, and a flag indicating that only the first frame should be kept (set when the video has fewer frames than the smallest bucket). - """ - target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets) - - # Adjust frame count: subsample evenly if the video has more frames than the target bucket; otherwise flag first-frame-only - num_frames, num_channels, height, width = video.shape - _first_frame_only = False - if num_frames > target_frames: - # Downsample: Select frames evenly - indices = torch.linspace(0, num_frames - 1, target_frames).long() - video = video[indices, :, :, :] - elif num_frames < target_frames: - _first_frame_only = True - - # Resize spatial resolution - if resize_mode == "center_crop": - return center_crop_video(video, (target_h, target_w)), _first_frame_only - elif resize_mode == "resize_crop": - return resize_crop_video(video, (target_h, target_w)), _first_frame_only - elif resize_mode == "bicubic": - return bicubic_resize_video(video, (target_h, target_w)), _first_frame_only - else: - raise ValueError( - f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'." 
- ) diff --git a/finetrainers/logging.py b/finetrainers/logging.py deleted file mode 100644 index 29d3597db47a91c1f3ec353a677c7309ceffc19d..0000000000000000000000000000000000000000 --- a/finetrainers/logging.py +++ /dev/null @@ -1,111 +0,0 @@ -import logging -import os -from typing import TYPE_CHECKING, Union - -from .constants import FINETRAINERS_LOG_LEVEL - - -if TYPE_CHECKING: - from .parallel import ParallelBackendType - - -class FinetrainersLoggerAdapter(logging.LoggerAdapter): - def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None: - super().__init__(logger, {}) - self.parallel_backend = parallel_backend - self._log_freq = {} - self._log_freq_counter = {} - - def log( - self, - level, - msg, - *args, - main_process_only: bool = False, - local_main_process_only: bool = True, - in_order: bool = False, - **kwargs, - ): - # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice - kwargs.setdefault("stacklevel", 2) - - if not self.isEnabledFor(level): - return - - if self.parallel_backend is None: - if int(os.environ.get("RANK", 0)) == 0: - msg, kwargs = self.process(msg, kwargs) - self.logger.log(level, msg, *args, **kwargs) - return - - if (main_process_only or local_main_process_only) and in_order: - raise ValueError( - "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True." - ) - - if (main_process_only and self.parallel_backend.is_main_process) or ( - local_main_process_only and self.parallel_backend.is_local_main_process - ): - msg, kwargs = self.process(msg, kwargs) - self.logger.log(level, msg, *args, **kwargs) - return - - if in_order: - for i in range(self.parallel_backend.world_size): - if self.rank == i: - msg, kwargs = self.process(msg, kwargs) - self.logger.log(level, msg, *args, **kwargs) - self.parallel_backend.wait_for_everyone() - return - - if not main_process_only and not local_main_process_only: - msg, kwargs = self.process(msg, kwargs) - self.logger.log(level, msg, *args, **kwargs) - return - - def log_freq( - self, - level: str, - name: str, - msg: str, - frequency: int, - *, - main_process_only: bool = False, - local_main_process_only: bool = True, - in_order: bool = False, - **kwargs, - ) -> None: - if frequency <= 0: - return - if name not in self._log_freq_counter: - self._log_freq[name] = frequency - self._log_freq_counter[name] = 0 - if self._log_freq_counter[name] % self._log_freq[name] == 0: - self.log( - level, - msg, - main_process_only=main_process_only, - local_main_process_only=local_main_process_only, - in_order=in_order, - **kwargs, - ) - self._log_freq_counter[name] += 1 - - -def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]: - global _logger - return _logger - - -def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter: - _logger.parallel_backend = parallel_backend - - -_logger = logging.getLogger("finetrainers") -_logger.setLevel(FINETRAINERS_LOG_LEVEL) -_console_handler = logging.StreamHandler() -_console_handler.setLevel(FINETRAINERS_LOG_LEVEL) -_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -_console_handler.setFormatter(_formatter) -_logger.addHandler(_console_handler) -_logger = FinetrainersLoggerAdapter(_logger) diff --git a/finetrainers/models/__init__.py b/finetrainers/models/__init__.py deleted file mode 100644 index fb7091a5e1650715591fdd7377e7c2850c0e3bb3..0000000000000000000000000000000000000000 --- 
a/finetrainers/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .modeling_utils import ModelSpecification diff --git a/finetrainers/models/cogvideox/__init__.py b/finetrainers/models/cogvideox/__init__.py deleted file mode 100644 index e1f9a84073541b0e764877bac0335637f03d32ca..0000000000000000000000000000000000000000 --- a/finetrainers/models/cogvideox/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_specification import CogVideoXModelSpecification diff --git a/finetrainers/models/cogvideox/base_specification.py b/finetrainers/models/cogvideox/base_specification.py deleted file mode 100644 index 9580f6b0d12d4a7613b72ed7523dd712a99d255e..0000000000000000000000000000000000000000 --- a/finetrainers/models/cogvideox/base_specification.py +++ /dev/null @@ -1,423 +0,0 @@ -import os -from typing import Any, Dict, List, Optional, Tuple - -import torch -from accelerate import init_empty_weights -from diffusers import ( - AutoencoderKLCogVideoX, - CogVideoXDDIMScheduler, - CogVideoXImageToVideoPipeline, - CogVideoXPipeline, - CogVideoXTransformer3DModel, -) -from PIL.Image import Image -from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer - -from ... import data -from ...logging import get_logger -from ...processors import ProcessorMixin, T5Processor -from ...typing import ArtifactType, SchedulerType -from ...utils import get_non_null_items -from ..modeling_utils import ModelSpecification -from ..utils import DiagonalGaussianDistribution -from .utils import prepare_rotary_positional_embeddings - - -logger = get_logger() - - -class CogVideoXLatentEncodeProcessor(ProcessorMixin): - r""" - Processor to encode image/video into latents using the CogVideoX VAE. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor returns. The outputs are in the following order: - - latents: The latents of the input image/video. 
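- Latents are returned in `[B, F, C, H, W]` layout (the `forward` below permutes back from the VAE's `[B, C, F, H, W]` output).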
- """ - - def __init__(self, output_names: List[str]): - super().__init__() - self.output_names = output_names - assert len(self.output_names) == 1 - - def forward( - self, - vae: AutoencoderKLCogVideoX, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - ) -> Dict[str, torch.Tensor]: - device = vae.device - dtype = vae.dtype - - if image is not None: - video = image.unsqueeze(1) - - assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" - video = video.to(device=device, dtype=vae.dtype) - video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] - - if compute_posterior: - latents = vae.encode(video).latent_dist.sample(generator=generator) - latents = latents.to(dtype=dtype) - else: - if vae.use_slicing and video.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] - moments = torch.cat(encoded_slices) - else: - moments = vae._encode(video) - latents = moments.to(dtype=dtype) - - latents = latents.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] -> [B, F, C, H, W] - return {self.output_names[0]: latents} - - -class CogVideoXModelSpecification(ModelSpecification): - def __init__( - self, - pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b", - tokenizer_id: Optional[str] = None, - text_encoder_id: Optional[str] = None, - transformer_id: Optional[str] = None, - vae_id: Optional[str] = None, - text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - condition_model_processors: List[ProcessorMixin] = None, - latent_model_processors: List[ProcessorMixin] = None, - **kwargs, - ) -> None: - super().__init__( - pretrained_model_name_or_path=pretrained_model_name_or_path, - tokenizer_id=tokenizer_id, - text_encoder_id=text_encoder_id, - transformer_id=transformer_id, - vae_id=vae_id, - text_encoder_dtype=text_encoder_dtype, - transformer_dtype=transformer_dtype, - vae_dtype=vae_dtype, - revision=revision, - cache_dir=cache_dir, - ) - - if condition_model_processors is None: - condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])] - if latent_model_processors is None: - latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])] - - self.condition_model_processors = condition_model_processors - self.latent_model_processors = latent_model_processors - - @property - def _resolution_dim_keys(self): - return {"latents": (1, 3, 4)} - - def load_condition_models(self) -> Dict[str, torch.nn.Module]: - if self.tokenizer_id is not None: - tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir - ) - else: - tokenizer = T5Tokenizer.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=self.revision, - cache_dir=self.cache_dir, - ) - - if self.text_encoder_id is not None: - text_encoder = AutoModel.from_pretrained( - self.text_encoder_id, - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - text_encoder = T5EncoderModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="text_encoder", - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"tokenizer": tokenizer, "text_encoder": text_encoder} - - def 
load_latent_models(self) -> Dict[str, torch.nn.Module]: - if self.vae_id is not None: - vae = AutoencoderKLCogVideoX.from_pretrained( - self.vae_id, - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - vae = AutoencoderKLCogVideoX.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="vae", - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"vae": vae} - - def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: - if self.transformer_id is not None: - transformer = CogVideoXTransformer3DModel.from_pretrained( - self.transformer_id, - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - transformer = CogVideoXTransformer3DModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="transformer", - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - scheduler = CogVideoXDDIMScheduler.from_pretrained( - self.pretrained_model_name_or_path, subfolder="scheduler", revision=self.revision, cache_dir=self.cache_dir - ) - - return {"transformer": transformer, "scheduler": scheduler} - - def load_pipeline( - self, - tokenizer: Optional[T5Tokenizer] = None, - text_encoder: Optional[T5EncoderModel] = None, - transformer: Optional[CogVideoXTransformer3DModel] = None, - vae: Optional[AutoencoderKLCogVideoX] = None, - scheduler: Optional[CogVideoXDDIMScheduler] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - training: bool = False, - **kwargs, - ) -> CogVideoXPipeline: - components = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - } - components = get_non_null_items(components) - - pipe = CogVideoXPipeline.from_pretrained( - self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir - ) - pipe.text_encoder.to(self.text_encoder_dtype) - pipe.vae.to(self.vae_dtype) - - if not training: - pipe.transformer.to(self.transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload() - - return pipe - - @torch.no_grad() - def prepare_conditions( - self, - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - caption: str, - max_sequence_length: int = 226, - **kwargs, - ) -> Dict[str, Any]: - conditions = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "caption": caption, - "max_sequence_length": max_sequence_length, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_conditions(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - conditions.pop("prompt_attention_mask", None) - return conditions - - @torch.no_grad() - def prepare_latents( - self, - vae: AutoencoderKLCogVideoX, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Dict[str, torch.Tensor]: - conditions = { - "vae": vae, - "image": image, - "video": video, - "generator": generator, - "compute_posterior": compute_posterior, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_latents(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - return 
conditions - - def forward( - self, - transformer: CogVideoXTransformer3DModel, - scheduler: CogVideoXDDIMScheduler, - condition_model_conditions: Dict[str, torch.Tensor], - latent_model_conditions: Dict[str, torch.Tensor], - sigmas: torch.Tensor, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Tuple[torch.Tensor, ...]: - # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself. - VAE_SPATIAL_SCALE_FACTOR = 8 - rope_base_height = self.transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR - rope_base_width = self.transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR - patch_size = self.transformer_config.patch_size - patch_size_t = getattr(self.transformer_config, "patch_size_t", None) - - if compute_posterior: - latents = latent_model_conditions.pop("latents") - else: - posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"), _dim=2) - latents = posterior.sample(generator=generator) - del posterior - - if not getattr(self.vae_config, "invert_scale_latents", False): - latents = latents * self.vae_config.scaling_factor - - if patch_size_t is not None: - latents = self._pad_frames(latents, patch_size_t) - - timesteps = (sigmas.flatten() * 1000.0).long() - - noise = torch.zeros_like(latents).normal_(generator=generator) - noisy_latents = scheduler.add_noise(latents, noise, timesteps) - - batch_size, num_frames, num_channels, height, width = latents.shape - ofs_emb = ( - None - if getattr(self.transformer_config, "ofs_embed_dim", None) is None - else latents.new_full((batch_size,), fill_value=2.0) - ) - - image_rotary_emb = ( - prepare_rotary_positional_embeddings( - height=height * VAE_SPATIAL_SCALE_FACTOR, - width=width * VAE_SPATIAL_SCALE_FACTOR, - num_frames=num_frames, - vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR, - patch_size=patch_size, - patch_size_t=patch_size_t, - attention_head_dim=self.transformer_config.attention_head_dim, - device=transformer.device, - base_height=rope_base_height, - base_width=rope_base_width, - ) - if self.transformer_config.use_rotary_positional_embeddings - else None - ) - - latent_model_conditions["hidden_states"] = noisy_latents.to(latents) - latent_model_conditions["image_rotary_emb"] = image_rotary_emb - latent_model_conditions["ofs"] = ofs_emb - - velocity = transformer( - **latent_model_conditions, - **condition_model_conditions, - timestep=timesteps, - return_dict=False, - )[0] - # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same - # code paths as scheduler.get_velocity(), which can be confusing to understand. 
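# To make the intent explicit: the scheduler's get_velocity(sample, noise, t) follows the
# standard DDIM formula sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample, and under
# v-prediction the clean sample is recovered as x0_hat = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v_hat.
# Calling get_velocity(velocity, noisy_latents, timesteps) therefore returns exactly that x0_hat.
# An equivalent explicit computation (illustrative sketch only, assuming the scheduler exposes
# the usual `alphas_cumprod` buffer) would be:
#
#     alpha_bar = scheduler.alphas_cumprod.to(latents.device)[timesteps].view(-1, 1, 1, 1, 1)
#     pred_x0 = alpha_bar.sqrt() * noisy_latents - (1.0 - alpha_bar).sqrt() * velocity
#
# which is then regressed against the clean `latents` below.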
- pred = scheduler.get_velocity(velocity, noisy_latents, timesteps) - target = latents - - return pred, target, sigmas - - def validation( - self, - pipeline: CogVideoXPipeline, - prompt: str, - image: Optional[Image] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, - **kwargs, - ) -> List[ArtifactType]: - # TODO(aryan): add support for more parameters - if image is not None: - pipeline = CogVideoXImageToVideoPipeline.from_pipe(pipeline) - - generation_kwargs = { - "prompt": prompt, - "image": image, - "height": height, - "width": width, - "num_frames": num_frames, - "num_inference_steps": num_inference_steps, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = get_non_null_items(generation_kwargs) - video = pipeline(**generation_kwargs).frames[0] - return [data.VideoArtifact(value=video)] - - def _save_lora_weights( - self, - directory: str, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - *args, - **kwargs, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - CogVideoXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - def _save_model( - self, - directory: str, - transformer: CogVideoXTransformer3DModel, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - with init_empty_weights(): - transformer_copy = CogVideoXTransformer3DModel.from_config(transformer.config) - transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) - transformer_copy.save_pretrained(os.path.join(directory, "transformer")) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - @staticmethod - def _pad_frames(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor: - num_frames = latents.size(1) - additional_frames = patch_size_t - (num_frames % patch_size_t) - if additional_frames > 0: - last_frame = latents[:, -1:] - padding_frames = last_frame.expand(-1, additional_frames, -1, -1, -1) - latents = torch.cat([latents, padding_frames], dim=1) - return latents diff --git a/finetrainers/models/cogvideox/utils.py b/finetrainers/models/cogvideox/utils.py deleted file mode 100644 index bd98c1f3653dbe23a6f53fa54dfe3e7073ea9b99..0000000000000000000000000000000000000000 --- a/finetrainers/models/cogvideox/utils.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional, Tuple - -import torch -from diffusers.models.embeddings import get_3d_rotary_pos_embed -from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid - - -def prepare_rotary_positional_embeddings( - height: int, - width: int, - num_frames: int, - vae_scale_factor_spatial: int = 8, - patch_size: int = 2, - patch_size_t: int = None, - attention_head_dim: int = 64, - device: Optional[torch.device] = None, - base_height: int = 480, - base_width: int = 720, -) -> Tuple[torch.Tensor, torch.Tensor]: - grid_height = height // (vae_scale_factor_spatial * patch_size) - grid_width = width // (vae_scale_factor_spatial * patch_size) - base_size_width = base_width // (vae_scale_factor_spatial * 
patch_size) - base_size_height = base_height // (vae_scale_factor_spatial * patch_size) - - if patch_size_t is None: - # CogVideoX 1.0 - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=num_frames, - ) - else: - # CogVideoX 1.5 - base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t - - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=attention_head_dim, - crops_coords=None, - grid_size=(grid_height, grid_width), - temporal_size=base_num_frames, - grid_type="slice", - max_size=(base_size_height, base_size_width), - ) - - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) - return freqs_cos, freqs_sin diff --git a/finetrainers/models/cogview4/__init__.py b/finetrainers/models/cogview4/__init__.py deleted file mode 100644 index d3e63dd7174fc4692d89acfaa2e119ac104c43df..0000000000000000000000000000000000000000 --- a/finetrainers/models/cogview4/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_specification import CogView4ModelSpecification diff --git a/finetrainers/models/cogview4/base_specification.py b/finetrainers/models/cogview4/base_specification.py deleted file mode 100644 index 7f311f15af46adf9f7db3daa745028824e964e5a..0000000000000000000000000000000000000000 --- a/finetrainers/models/cogview4/base_specification.py +++ /dev/null @@ -1,395 +0,0 @@ -import os -from typing import Any, Dict, List, Optional, Tuple - -import torch -from accelerate import init_empty_weights -from diffusers import ( - AutoencoderKL, - CogView4Pipeline, - CogView4Transformer2DModel, - FlowMatchEulerDiscreteScheduler, -) -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from transformers import AutoTokenizer, GlmModel - -from ... import data -from ... import functional as FF -from ...logging import get_logger -from ...processors import CogView4GLMProcessor, ProcessorMixin -from ...typing import ArtifactType, SchedulerType -from ...utils import get_non_null_items -from ..modeling_utils import ModelSpecification - - -logger = get_logger() - - -class CogView4LatentEncodeProcessor(ProcessorMixin): - r""" - Processor to encode image/video into latents using the LTX VAE. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor returns. The outputs are in the following order: - - latents: The latents of the input image/video. - - original_size: The original size of the input image/video. - - target_size: The target size of the input image/video. - - crop_coords: The top-left crop coordinates of the input image/video. 
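Note: `original_size`, `target_size`, and `crop_coords` are [B, 2] size/crop conditioning tensors. They are not consumed during latent preparation; they remain in `latent_model_conditions` and are forwarded unchanged to the CogView4 transformer during training.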
- """ - - def __init__(self, output_names: List[str]): - super().__init__() - - self.output_names = output_names - assert len(self.output_names) == 4 - - def forward( - self, - vae: AutoencoderKL, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - _original_height: Optional[int] = None, - _original_width: Optional[int] = None, - ) -> Dict[str, torch.Tensor]: - device = vae.device - dtype = vae.dtype - - if video is not None: - # TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly - image = video[:, 0] # [B, F, C, H, W] -> [B, 1, C, H, W] - - assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor" - image = image.to(device=device, dtype=vae.dtype) - - if compute_posterior: - latents = vae.encode(image).latent_dist.sample(generator=generator) - latents = latents.to(dtype=dtype) - else: - if vae.use_slicing and image.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)] - moments = torch.cat(encoded_slices) - else: - moments = vae._encode(image) - latents = moments.to(dtype=dtype) - - batch_size = latents.size(0) - target_height = image.size(2) - target_width = image.size(3) - original_size = torch.tensor([(_original_height, _original_width)], device=device, dtype=dtype).repeat( - batch_size, 1 - ) - target_size = torch.tensor([(target_height, target_width)], device=device, dtype=dtype).repeat(batch_size, 1) - crop_coords = torch.tensor([(0, 0)], device=device, dtype=dtype).repeat(batch_size, 1) - - return { - self.output_names[0]: latents, - self.output_names[1]: original_size, - self.output_names[2]: target_size, - self.output_names[3]: crop_coords, - } - - -class CogView4ModelSpecification(ModelSpecification): - def __init__( - self, - pretrained_model_name_or_path: str = "THUDM/CogView4-6B", - tokenizer_id: Optional[str] = None, - text_encoder_id: Optional[str] = None, - transformer_id: Optional[str] = None, - vae_id: Optional[str] = None, - text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - condition_model_processors: List[ProcessorMixin] = None, - latent_model_processors: List[ProcessorMixin] = None, - **kwargs, - ) -> None: - super().__init__( - pretrained_model_name_or_path=pretrained_model_name_or_path, - tokenizer_id=tokenizer_id, - text_encoder_id=text_encoder_id, - transformer_id=transformer_id, - vae_id=vae_id, - text_encoder_dtype=text_encoder_dtype, - transformer_dtype=transformer_dtype, - vae_dtype=vae_dtype, - revision=revision, - cache_dir=cache_dir, - ) - - if condition_model_processors is None: - condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])] - if latent_model_processors is None: - latent_model_processors = [ - CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"]) - ] - - self.condition_model_processors = condition_model_processors - self.latent_model_processors = latent_model_processors - - @property - def _resolution_dim_keys(self): - return {"latents": (2, 3)} - - def load_condition_models(self) -> Dict[str, torch.nn.Module]: - if self.tokenizer_id is not None: - tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir - ) - else: - tokenizer = AutoTokenizer.from_pretrained( 
- self.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=self.revision, - cache_dir=self.cache_dir, - ) - - if self.text_encoder_id is not None: - text_encoder = GlmModel.from_pretrained( - self.text_encoder_id, - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - text_encoder = GlmModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="text_encoder", - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"tokenizer": tokenizer, "text_encoder": text_encoder} - - def load_latent_models(self) -> Dict[str, torch.nn.Module]: - if self.vae_id is not None: - vae = AutoencoderKL.from_pretrained( - self.vae_id, - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - vae = AutoencoderKL.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="vae", - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"vae": vae} - - def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: - if self.transformer_id is not None: - transformer = CogView4Transformer2DModel.from_pretrained( - self.transformer_id, - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - transformer = CogView4Transformer2DModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="transformer", - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return {"transformer": transformer, "scheduler": scheduler} - - def load_pipeline( - self, - tokenizer: Optional[AutoTokenizer] = None, - text_encoder: Optional[GlmModel] = None, - transformer: Optional[CogView4Transformer2DModel] = None, - vae: Optional[AutoencoderKL] = None, - scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - training: bool = False, - **kwargs, - ) -> CogView4Pipeline: - components = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "transformer": transformer, - "vae": vae, - # Load the scheduler based on CogView4's config instead of using the default initialization being used for training - # "scheduler": scheduler, - } - components = get_non_null_items(components) - - pipe = CogView4Pipeline.from_pretrained( - self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir - ) - pipe.text_encoder.to(self.text_encoder_dtype) - pipe.vae.to(self.vae_dtype) - - if not training: - pipe.transformer.to(self.transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload() - - return pipe - - @torch.no_grad() - def prepare_conditions( - self, - tokenizer: AutoTokenizer, - text_encoder: GlmModel, - caption: str, - max_sequence_length: int = 1024, - **kwargs, - ) -> Dict[str, Any]: - conditions = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "caption": caption, - "max_sequence_length": max_sequence_length, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_conditions(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - return conditions - - @torch.no_grad() - def prepare_latents( - self, - vae: AutoencoderKL, - image: 
Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - _original_height: Optional[int] = None, - _original_width: Optional[int] = None, - **kwargs, - ) -> Dict[str, torch.Tensor]: - conditions = { - "vae": vae, - "image": image, - "video": video, - "generator": generator, - "compute_posterior": compute_posterior, - "_original_height": _original_height, - "_original_width": _original_width, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_latents(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - return conditions - - def forward( - self, - transformer: CogView4Transformer2DModel, - condition_model_conditions: Dict[str, torch.Tensor], - latent_model_conditions: Dict[str, torch.Tensor], - sigmas: torch.Tensor, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Tuple[torch.Tensor, ...]: - if compute_posterior: - latents = latent_model_conditions.pop("latents") - else: - posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) - latents = posterior.sample(generator=generator) - del posterior - - latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor - noise = torch.zeros_like(latents).normal_(generator=generator) - timesteps = (sigmas.flatten() * 1000.0).long() - - base_image_sequence_length = 256 - base_shift = 0.25 - max_shift = 0.75 - - image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2 - mu = (image_sequence_length / base_image_sequence_length) ** 0.5 - mu = mu * max_shift + base_shift - shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0) - noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas) - - latent_model_conditions["hidden_states"] = noisy_latents.to(latents) - - pred = transformer( - **latent_model_conditions, - **condition_model_conditions, - timestep=timesteps, - return_dict=False, - )[0] - target = FF.flow_match_target(noise, latents) - - # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation - # but let's keep it this way for now. Longer training runs should reveal more insights. 
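# The shift above mirrors the resolution-dependent timestep shifting used by flow-matching
# pipelines: the larger the image token count, the further sigmas are pushed towards 1, i.e.
# the same sampled sigma corresponds to a noisier training point for larger images. A worked
# example with illustrative numbers (assuming an 8x VAE downsample and patch_size = 2): a
# 1024x1024 image gives 128x128 latents and image_sequence_length = 128 * 128 / 2**2 = 4096, so
#
#     mu = (4096 / 256) ** 0.5 * 0.75 + 0.25 = 3.25
#     sigma = 0.5  ->  shifted_sigma = 3.25 / (3.25 + (1 / 0.5 - 1)) ~= 0.765
#
# The shifted sigmas are used both for constructing the noisy latents above and, per the note
# above, as the loss-weighting sigmas returned below.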
- # return pred, target, sigmas - return pred, target, shifted_sigmas - - def validation( - self, - pipeline: CogView4Pipeline, - prompt: str, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, - **kwargs, - ) -> List[ArtifactType]: - generation_kwargs = { - "prompt": prompt, - "height": height, - "width": width, - "num_inference_steps": num_inference_steps, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = get_non_null_items(generation_kwargs) - image = pipeline(**generation_kwargs).images[0] - return [data.ImageArtifact(value=image)] - - def _save_lora_weights( - self, - directory: str, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - *args, - **kwargs, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - CogView4Pipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - def _save_model( - self, - directory: str, - transformer: CogView4Transformer2DModel, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - with init_empty_weights(): - transformer_copy = CogView4Transformer2DModel.from_config(transformer.config) - transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) - transformer_copy.save_pretrained(os.path.join(directory, "transformer")) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/finetrainers/models/hunyuan_video/__init__.py b/finetrainers/models/hunyuan_video/__init__.py deleted file mode 100644 index 518a42865f0cee30a534da458ec63b08c1a8d7e4..0000000000000000000000000000000000000000 --- a/finetrainers/models/hunyuan_video/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_specification import HunyuanVideoModelSpecification diff --git a/finetrainers/models/hunyuan_video/base_specification.py b/finetrainers/models/hunyuan_video/base_specification.py deleted file mode 100644 index e72a060bf4deb8f8eec6c1112526149303ca33fe..0000000000000000000000000000000000000000 --- a/finetrainers/models/hunyuan_video/base_specification.py +++ /dev/null @@ -1,410 +0,0 @@ -import os -from typing import Any, Dict, List, Optional, Tuple - -import torch -from accelerate import init_empty_weights -from diffusers import ( - AutoencoderKLHunyuanVideo, - FlowMatchEulerDiscreteScheduler, - HunyuanVideoPipeline, - HunyuanVideoTransformer3DModel, -) -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel - -from ... import data -from ... import functional as FF -from ...logging import get_logger -from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin -from ...typing import ArtifactType, SchedulerType -from ...utils import get_non_null_items -from ..modeling_utils import ModelSpecification - - -logger = get_logger() - - -class HunyuanLatentEncodeProcessor(ProcessorMixin): - r""" - Processor to encode image/video into latents using the HunyuanVideo VAE. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor returns. 
The outputs are in the following order: - - latents: The latents of the input image/video. - """ - - def __init__(self, output_names: List[str]): - super().__init__() - self.output_names = output_names - assert len(self.output_names) == 1 - - def forward( - self, - vae: AutoencoderKLHunyuanVideo, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - ) -> Dict[str, torch.Tensor]: - device = vae.device - dtype = vae.dtype - - if image is not None: - video = image.unsqueeze(1) - - assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" - video = video.to(device=device, dtype=vae.dtype) - video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] - - if compute_posterior: - latents = vae.encode(video).latent_dist.sample(generator=generator) - latents = latents.to(dtype=dtype) - else: - if vae.use_slicing and video.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] - moments = torch.cat(encoded_slices) - else: - moments = vae._encode(video) - latents = moments.to(dtype=dtype) - - return {self.output_names[0]: latents} - - -class HunyuanVideoModelSpecification(ModelSpecification): - def __init__( - self, - pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo", - tokenizer_id: Optional[str] = None, - text_encoder_id: Optional[str] = None, - transformer_id: Optional[str] = None, - vae_id: Optional[str] = None, - text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - condition_model_processors: List[ProcessorMixin] = None, - latent_model_processors: List[ProcessorMixin] = None, - **kwargs, - ) -> None: - super().__init__( - pretrained_model_name_or_path=pretrained_model_name_or_path, - tokenizer_id=tokenizer_id, - text_encoder_id=text_encoder_id, - transformer_id=transformer_id, - vae_id=vae_id, - text_encoder_dtype=text_encoder_dtype, - transformer_dtype=transformer_dtype, - vae_dtype=vae_dtype, - revision=revision, - cache_dir=cache_dir, - ) - - if condition_model_processors is None: - condition_model_processors = [ - LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]), - CLIPPooledProcessor( - ["pooled_projections"], - input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"}, - ), - ] - if latent_model_processors is None: - latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])] - - self.condition_model_processors = condition_model_processors - self.latent_model_processors = latent_model_processors - - @property - def _resolution_dim_keys(self): - return {"latents": (2, 3, 4)} - - def load_condition_models(self) -> Dict[str, torch.nn.Module]: - if self.tokenizer_id is not None: - tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir - ) - else: - tokenizer = AutoTokenizer.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=self.revision, - cache_dir=self.cache_dir, - ) - - if self.tokenizer_2_id is not None: - tokenizer_2 = CLIPTokenizer.from_pretrained( - self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir - ) - else: - tokenizer_2 = CLIPTokenizer.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="tokenizer_2", - revision=self.revision, - 
cache_dir=self.cache_dir, - ) - - if self.text_encoder_id is not None: - text_encoder = LlamaModel.from_pretrained( - self.text_encoder_id, - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - text_encoder = LlamaModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="text_encoder", - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - if self.text_encoder_2_id is not None: - text_encoder_2 = CLIPTextModel.from_pretrained( - self.text_encoder_2_id, - torch_dtype=self.text_encoder_2_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - text_encoder_2 = CLIPTextModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="text_encoder_2", - torch_dtype=self.text_encoder_2_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return { - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - } - - def load_latent_models(self) -> Dict[str, torch.nn.Module]: - if self.vae_id is not None: - vae = AutoencoderKLHunyuanVideo.from_pretrained( - self.vae_id, - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - vae = AutoencoderKLHunyuanVideo.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="vae", - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"vae": vae} - - def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: - if self.transformer_id is not None: - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - self.transformer_id, - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="transformer", - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return {"transformer": transformer, "scheduler": scheduler} - - def load_pipeline( - self, - tokenizer: Optional[AutoTokenizer] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, - text_encoder: Optional[LlamaModel] = None, - text_encoder_2: Optional[CLIPTextModel] = None, - transformer: Optional[HunyuanVideoTransformer3DModel] = None, - vae: Optional[AutoencoderKLHunyuanVideo] = None, - scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - training: bool = False, - **kwargs, - ) -> HunyuanVideoPipeline: - components = { - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - } - components = get_non_null_items(components) - - pipe = HunyuanVideoPipeline.from_pretrained( - self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir - ) - pipe.text_encoder.to(self.text_encoder_dtype) - pipe.text_encoder_2.to(self.text_encoder_2_dtype) - pipe.vae.to(self.vae_dtype) - - if not training: - pipe.transformer.to(self.transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload() - - return pipe - - @torch.no_grad() - def prepare_conditions( 
- self, - tokenizer: AutoTokenizer, - tokenizer_2: CLIPTokenizer, - text_encoder: LlamaModel, - text_encoder_2: CLIPTextModel, - caption: str, - max_sequence_length: int = 256, - **kwargs, - ) -> Dict[str, Any]: - conditions = { - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - "caption": caption, - "max_sequence_length": max_sequence_length, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_conditions(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - return conditions - - @torch.no_grad() - def prepare_latents( - self, - vae: AutoencoderKLHunyuanVideo, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Dict[str, torch.Tensor]: - conditions = { - "vae": vae, - "image": image, - "video": video, - "generator": generator, - "compute_posterior": compute_posterior, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_latents(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - return conditions - - def forward( - self, - transformer: HunyuanVideoTransformer3DModel, - condition_model_conditions: Dict[str, torch.Tensor], - latent_model_conditions: Dict[str, torch.Tensor], - sigmas: torch.Tensor, - guidance: float = 1.0, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Tuple[torch.Tensor, ...]: - if compute_posterior: - latents = latent_model_conditions.pop("latents") - else: - posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) - latents = posterior.sample(generator=generator) - del posterior - - latents = latents * self.vae_config.scaling_factor - noise = torch.zeros_like(latents).normal_(generator=generator) - noisy_latents = FF.flow_match_xt(latents, noise, sigmas) - - timesteps = (sigmas.flatten() * 1000.0).long() - guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0 - - latent_model_conditions["hidden_states"] = noisy_latents.to(latents) - latent_model_conditions["guidance"] = guidance - - pred = transformer( - **latent_model_conditions, - **condition_model_conditions, - timestep=timesteps, - return_dict=False, - )[0] - target = FF.flow_match_target(noise, latents) - - return pred, target, sigmas - - def validation( - self, - pipeline: HunyuanVideoPipeline, - prompt: str, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, - **kwargs, - ) -> List[ArtifactType]: - generation_kwargs = { - "prompt": prompt, - "height": height, - "width": width, - "num_frames": num_frames, - "num_inference_steps": num_inference_steps, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = get_non_null_items(generation_kwargs) - video = pipeline(**generation_kwargs).frames[0] - return [data.VideoArtifact(value=video)] - - def _save_lora_weights( - self, - directory: str, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - *args, - **kwargs, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - HunyuanVideoPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) - 
if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - def _save_model( - self, - directory: str, - transformer: HunyuanVideoTransformer3DModel, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - with init_empty_weights(): - transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config) - transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) - transformer_copy.save_pretrained(os.path.join(directory, "transformer")) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/finetrainers/models/ltx_video/__init__.py b/finetrainers/models/ltx_video/__init__.py deleted file mode 100644 index ff4e3550d54bb33fac80dd2d075ad2846eeeed46..0000000000000000000000000000000000000000 --- a/finetrainers/models/ltx_video/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_specification import LTXVideoModelSpecification diff --git a/finetrainers/models/ltx_video/base_specification.py b/finetrainers/models/ltx_video/base_specification.py deleted file mode 100644 index d6a6be82298e156f290344eb62644af5de465b2a..0000000000000000000000000000000000000000 --- a/finetrainers/models/ltx_video/base_specification.py +++ /dev/null @@ -1,517 +0,0 @@ -import os -import random -from typing import Any, Dict, List, Optional, Tuple - -import torch -from accelerate import init_empty_weights -from diffusers import ( - AutoencoderKLLTXVideo, - FlowMatchEulerDiscreteScheduler, - LTXImageToVideoPipeline, - LTXPipeline, - LTXVideoTransformer3DModel, -) -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from PIL.Image import Image -from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer - -from ... import data -from ... import functional as FF -from ...logging import get_logger -from ...parallel import ParallelBackendEnum -from ...processors import ProcessorMixin, T5Processor -from ...typing import ArtifactType, SchedulerType -from ...utils import get_non_null_items -from ..modeling_utils import ModelSpecification - - -logger = get_logger() - - -class LTXLatentEncodeProcessor(ProcessorMixin): - r""" - Processor to encode image/video into latents using the LTX VAE. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor returns. The outputs are in the following order: - - latents: The latents of the input image/video. - - num_frames: The number of frames in the input video. - - height: The height of the input image/video. - - width: The width of the input image/video. - - latents_mean: The latent channel means from the VAE state dict. - - latents_std: The latent channel standard deviations from the VAE state dict. 
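Note: unlike the CogVideoX processor, the latents are returned in the VAE's [B, C, F, H, W] layout, and `num_frames`, `height`, and `width` are latent-space (not pixel-space) sizes. `latents_mean` and `latents_std` are surfaced so per-channel normalization can be deferred to the training-time forward pass. A minimal usage sketch (hypothetical shapes, assuming the stock LTX VAE with 128 latent channels, 32x spatial and 8x temporal compression):

    processor = LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
    video = torch.randn(1, 49, 3, 512, 768, dtype=vae.dtype, device=vae.device)
    out = processor.forward(vae=vae, video=video, compute_posterior=True)
    # out["latents"].shape == torch.Size([1, 128, 7, 16, 24]); out["num_frames"], out["height"], out["width"] == 7, 16, 24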
- """ - - def __init__(self, output_names: List[str]): - super().__init__() - self.output_names = output_names - assert len(self.output_names) == 6 - - def forward( - self, - vae: AutoencoderKLLTXVideo, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - ) -> Dict[str, torch.Tensor]: - device = vae.device - dtype = vae.dtype - - if image is not None: - video = image.unsqueeze(1) - - assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" - video = video.to(device=device, dtype=vae.dtype) - video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] - - if compute_posterior: - latents = vae.encode(video).latent_dist.sample(generator=generator) - latents = latents.to(dtype=dtype) - else: - if vae.use_slicing and video.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] - moments = torch.cat(encoded_slices) - else: - moments = vae._encode(video) - latents = moments.to(dtype=dtype) - - _, _, num_frames, height, width = latents.shape - - return { - self.output_names[0]: latents, - self.output_names[1]: num_frames, - self.output_names[2]: height, - self.output_names[3]: width, - self.output_names[4]: vae.latents_mean, - self.output_names[5]: vae.latents_std, - } - - -class LTXVideoModelSpecification(ModelSpecification): - def __init__( - self, - pretrained_model_name_or_path: str = "Lightricks/LTX-Video", - tokenizer_id: Optional[str] = None, - text_encoder_id: Optional[str] = None, - transformer_id: Optional[str] = None, - vae_id: Optional[str] = None, - text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - condition_model_processors: List[ProcessorMixin] = None, - latent_model_processors: List[ProcessorMixin] = None, - **kwargs, - ) -> None: - super().__init__( - pretrained_model_name_or_path=pretrained_model_name_or_path, - tokenizer_id=tokenizer_id, - text_encoder_id=text_encoder_id, - transformer_id=transformer_id, - vae_id=vae_id, - text_encoder_dtype=text_encoder_dtype, - transformer_dtype=transformer_dtype, - vae_dtype=vae_dtype, - revision=revision, - cache_dir=cache_dir, - ) - - if condition_model_processors is None: - condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])] - if latent_model_processors is None: - latent_model_processors = [ - LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"]) - ] - - self.condition_model_processors = condition_model_processors - self.latent_model_processors = latent_model_processors - - @property - def _resolution_dim_keys(self): - return {"latents": (2, 3, 4)} - - def load_condition_models(self) -> Dict[str, torch.nn.Module]: - if self.tokenizer_id is not None: - tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir - ) - else: - tokenizer = T5Tokenizer.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=self.revision, - cache_dir=self.cache_dir, - ) - - if self.text_encoder_id is not None: - text_encoder = AutoModel.from_pretrained( - self.text_encoder_id, - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - text_encoder = T5EncoderModel.from_pretrained( - 
self.pretrained_model_name_or_path, - subfolder="text_encoder", - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"tokenizer": tokenizer, "text_encoder": text_encoder} - - def load_latent_models(self) -> Dict[str, torch.nn.Module]: - if self.vae_id is not None: - vae = AutoencoderKLLTXVideo.from_pretrained( - self.vae_id, - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - vae = AutoencoderKLLTXVideo.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="vae", - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"vae": vae} - - def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: - if self.transformer_id is not None: - transformer = LTXVideoTransformer3DModel.from_pretrained( - self.transformer_id, - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - transformer = LTXVideoTransformer3DModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="transformer", - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return {"transformer": transformer, "scheduler": scheduler} - - def load_pipeline( - self, - tokenizer: Optional[T5Tokenizer] = None, - text_encoder: Optional[T5EncoderModel] = None, - transformer: Optional[LTXVideoTransformer3DModel] = None, - vae: Optional[AutoencoderKLLTXVideo] = None, - scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - training: bool = False, - **kwargs, - ) -> LTXPipeline: - components = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - } - components = get_non_null_items(components) - - pipe = LTXPipeline.from_pretrained( - self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir - ) - pipe.text_encoder.to(self.text_encoder_dtype) - pipe.vae.to(self.vae_dtype) - - if not training: - pipe.transformer.to(self.transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload() - - return pipe - - @torch.no_grad() - def prepare_conditions( - self, - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - caption: str, - max_sequence_length: int = 128, - **kwargs, - ) -> Dict[str, Any]: - conditions = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "caption": caption, - "max_sequence_length": max_sequence_length, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_conditions(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - return conditions - - @torch.no_grad() - def prepare_latents( - self, - vae: AutoencoderKLLTXVideo, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Dict[str, torch.Tensor]: - conditions = { - "vae": vae, - "image": image, - "video": video, - "generator": generator, - "compute_posterior": compute_posterior, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_latents(**conditions) - conditions = {k: v for k, v in 
conditions.items() if k not in input_keys} - return conditions - - def forward( - self, - transformer: LTXVideoTransformer3DModel, - condition_model_conditions: Dict[str, torch.Tensor], - latent_model_conditions: Dict[str, torch.Tensor], - sigmas: torch.Tensor, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Tuple[torch.Tensor, ...]: - # TODO(aryan): make this configurable? Should it be? - first_frame_conditioning_p = 0.1 - min_first_frame_sigma = 0.25 - - if compute_posterior: - latents = latent_model_conditions.pop("latents") - else: - posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) - latents = posterior.sample(generator=generator) - del posterior - - latents_mean = latent_model_conditions.pop("latents_mean") - latents_std = latent_model_conditions.pop("latents_std") - - latents = self._normalize_latents(latents, latents_mean, latents_std) - noise = torch.zeros_like(latents).normal_(generator=generator) - - if random.random() < first_frame_conditioning_p: - # Based on Section 2.4 of the paper, it mentions that the first frame timesteps should be a small random value. - # Making as estimated guess, we limit the sigmas to be at least 0.2. - # torch.rand_like returns values in [0, 1). We want to make sure that the first frame sigma is <= actual sigmas - # for image conditioning. In order to do this, we rescale by multiplying with sigmas so the range is [0, sigmas). - first_frame_sigma = torch.rand_like(sigmas) * sigmas - first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma)) - - latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:] - noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma) - noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas) - noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2) - else: - noisy_latents = FF.flow_match_xt(latents, noise, sigmas) - - patch_size = self.transformer_config.patch_size - patch_size_t = self.transformer_config.patch_size_t - - latents = self._pack_latents(latents, patch_size, patch_size_t) - noise = self._pack_latents(noise, patch_size, patch_size_t) - noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t) - sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1) - timesteps = (sigmas * 1000.0).long() - - latent_model_conditions["hidden_states"] = noisy_latents.to(latents) - - # TODO(aryan): make this configurable - frame_rate = 25 - temporal_compression_ratio = 8 - vae_spatial_compression_ratio = 32 - latent_frame_rate = frame_rate / temporal_compression_ratio - - rope_interpolation_scale = [ - 1 / latent_frame_rate, - vae_spatial_compression_ratio, - vae_spatial_compression_ratio, - ] - - pred = transformer( - **latent_model_conditions, - **condition_model_conditions, - timestep=timesteps, - rope_interpolation_scale=rope_interpolation_scale, - return_dict=False, - )[0] - target = FF.flow_match_target(noise, latents) - - return pred, target, sigmas - - def validation( - self, - pipeline: LTXPipeline, - prompt: str, - image: Optional[Image] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - frame_rate: int = 25, - num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, - **kwargs, - ) -> List[ArtifactType]: - if image is not None: - pipeline = 
LTXImageToVideoPipeline.from_pipe(pipeline) - - generation_kwargs = { - "prompt": prompt, - "image": image, - "height": height, - "width": width, - "num_frames": num_frames, - "frame_rate": frame_rate, - "num_inference_steps": num_inference_steps, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = get_non_null_items(generation_kwargs) - video = pipeline(**generation_kwargs).frames[0] - return [data.VideoArtifact(value=video)] - - def _save_lora_weights( - self, - directory: str, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - *args, - **kwargs, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - LTXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - def _save_model( - self, - directory: str, - transformer: LTXVideoTransformer3DModel, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - with init_empty_weights(): - transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config) - transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) - transformer_copy.save_pretrained(os.path.join(directory, "transformer")) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - def apply_tensor_parallel( - self, - backend: ParallelBackendEnum, - device_mesh: torch.distributed.DeviceMesh, - transformer: LTXVideoTransformer3DModel, - **kwargs, - ) -> None: - if backend == ParallelBackendEnum.PTD: - _apply_tensor_parallel_ptd(device_mesh, transformer) - else: - raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification") - - @staticmethod - def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Normalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) - latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents) - return latents - - @staticmethod - def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: - # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. - # The patch dimensions are then permuted and collapsed into the channel dimension of shape: - # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). 
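# For example (illustrative shapes only): with patch_size=1 and patch_size_t=1, latents of shape
# [1, 128, 9, 16, 16] are reshaped and permuted into [1, 9 * 16 * 16, 128] = [1, 2304, 128].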
- # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features - batch_size, num_channels, num_frames, height, width = latents.shape - post_patch_num_frames = num_frames // patch_size_t - post_patch_height = height // patch_size - post_patch_width = width // patch_size - latents = latents.reshape( - batch_size, - -1, - post_patch_num_frames, - patch_size_t, - post_patch_height, - patch_size, - post_patch_width, - patch_size, - ) - latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - return latents - - -def _apply_tensor_parallel_ptd( - device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel -) -> None: - from torch.distributed.tensor.parallel import parallelize_module - from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel - - transformer_plan = { - # ===== Condition embeddings ===== - # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(), - # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)), - # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()), - # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())), - # "caption_projection.linear_1": ColwiseParallel(), - # "caption_projection.linear_2": RowwiseParallel(), - # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False), - # ===== ===== - } - - for block in transformer.transformer_blocks: - block_plan = {} - - # ===== Attention ===== - # 8 all-to-all, 3 all-reduce - # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False) - # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False) - # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False) - # block_plan["attn1.norm_q"] = SequenceParallel() - # block_plan["attn1.norm_k"] = SequenceParallel() - # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) - # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False) - # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False) - # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False) - # block_plan["attn2.norm_q"] = SequenceParallel() - # block_plan["attn2.norm_k"] = SequenceParallel() - # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) - # ===== ===== - - block_plan["ff.net.0.proj"] = ColwiseParallel() - block_plan["ff.net.2"] = RowwiseParallel() - - parallelize_module(block, device_mesh, block_plan) - - parallelize_module(transformer, device_mesh, transformer_plan) diff --git a/finetrainers/models/modeling_utils.py b/finetrainers/models/modeling_utils.py deleted file mode 100644 index b9adcf3bd528ef3205438044c54551ba954388fa..0000000000000000000000000000000000000000 --- a/finetrainers/models/modeling_utils.py +++ /dev/null @@ -1,289 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from diffusers import DiffusionPipeline -from diffusers.configuration_utils import FrozenDict -from PIL.Image import Image - -from ..logging import get_logger -from ..parallel import ParallelBackendEnum -from ..processors import ProcessorMixin -from ..typing import ArtifactType, SchedulerType, TokenizerType -from ..utils import resolve_component_cls - - -logger = get_logger() - -# TODO(aryan): we most likely don't need this. 
take a look after refactoring more -# fmt: off -IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"} -# fmt: on - - -class ModelSpecification: - r""" - The ModelSpecification class is an interface to be used for Diffusion training recipes. It provides - loose structure about how to organize the code for training. The trainer implementations will - make use of this interface to load models, prepare conditions, prepare latents, forward pass, etc. - """ - - def __init__( - self, - pretrained_model_name_or_path: Optional[str] = None, - tokenizer_id: Optional[str] = None, - tokenizer_2_id: Optional[str] = None, - tokenizer_3_id: Optional[str] = None, - text_encoder_id: Optional[str] = None, - text_encoder_2_id: Optional[str] = None, - text_encoder_3_id: Optional[str] = None, - transformer_id: Optional[str] = None, - vae_id: Optional[str] = None, - text_encoder_dtype: torch.dtype = torch.bfloat16, - text_encoder_2_dtype: torch.dtype = torch.bfloat16, - text_encoder_3_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: str = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - condition_model_processors: List[ProcessorMixin] = None, - latent_model_processors: List[ProcessorMixin] = None, - ) -> None: - self.pretrained_model_name_or_path = pretrained_model_name_or_path - self.tokenizer_id = tokenizer_id - self.tokenizer_2_id = tokenizer_2_id - self.tokenizer_3_id = tokenizer_3_id - self.text_encoder_id = text_encoder_id - self.text_encoder_2_id = text_encoder_2_id - self.text_encoder_3_id = text_encoder_3_id - self.transformer_id = transformer_id - self.vae_id = vae_id - self.text_encoder_dtype = text_encoder_dtype - self.text_encoder_2_dtype = text_encoder_2_dtype - self.text_encoder_3_dtype = text_encoder_3_dtype - self.transformer_dtype = transformer_dtype - self.vae_dtype = vae_dtype - self.revision = revision - self.cache_dir = cache_dir - self.condition_model_processors = condition_model_processors or [] - self.latent_model_processors = latent_model_processors or [] - - self.transformer_config: Dict[str, Any] = None - self.vae_config: Dict[str, Any] = None - - self._load_configs() - - # TODO(aryan): revisit how to do this better without user having to worry about it - @property - def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]: - raise NotImplementedError( - f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}" - ) - - def load_condition_models(self) -> Dict[str, torch.nn.Module]: - raise NotImplementedError( - f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}" - ) - - def load_latent_models(self) -> Dict[str, torch.nn.Module]: - raise NotImplementedError( - f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}" - ) - - def load_diffusion_models(self) -> Dict[str, Union[torch.nn.Module]]: - raise NotImplementedError( - f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}" - ) - - def load_pipeline( - self, - tokenizer: Optional[TokenizerType] = None, - tokenizer_2: Optional[TokenizerType] = None, - tokenizer_3: Optional[TokenizerType] = None, - text_encoder: Optional[torch.nn.Module] = None, - text_encoder_2: Optional[torch.nn.Module] = None, - text_encoder_3: 
Optional[torch.nn.Module] = None, - transformer: Optional[torch.nn.Module] = None, - vae: Optional[torch.nn.Module] = None, - scheduler: Optional[SchedulerType] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - training: bool = False, - **kwargs, - ) -> DiffusionPipeline: - raise NotImplementedError( - f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}" - ) - - def prepare_conditions(self, **kwargs) -> Dict[str, Any]: - for processor in self.condition_model_processors: - result = processor(**kwargs) - result_keys = set(result.keys()) - repeat_keys = result_keys.intersection(kwargs.keys()) - if repeat_keys: - logger.warning( - f"Processor {processor.__class__.__name__} returned keys that already exist in " - f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " - f"be intended. Please rename the keys in the processor to avoid conflicts." - ) - kwargs.update(result) - return kwargs - - def prepare_latents(self, **kwargs) -> Dict[str, Any]: - for processor in self.latent_model_processors: - result = processor(**kwargs) - result_keys = set(result.keys()) - repeat_keys = result_keys.intersection(kwargs.keys()) - if repeat_keys: - logger.warning( - f"Processor {processor.__class__.__name__} returned keys that already exist in " - f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " - f"be intended. Please rename the keys in the processor to avoid conflicts." - ) - kwargs.update(result) - return kwargs - - def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: - keys = list(data[0].keys()) - collated_data = {} - for key in keys: - if key in IGNORE_KEYS_FOR_COLLATION: - collated_data[key] = data[0][key] - continue - collated_d = [d[key] for d in data] - if isinstance(collated_d[0], torch.Tensor): - collated_d = torch.cat(collated_d) - collated_data[key] = collated_d - return collated_data - - def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: - keys = list(data[0].keys()) - collated_data = {} - for key in keys: - if key in IGNORE_KEYS_FOR_COLLATION: - collated_data[key] = data[0][key] - continue - collated_d = [d[key] for d in data] - # TODO(aryan): Support multi-resolution collation - if isinstance(collated_d[0], torch.Tensor): - collated_d = torch.cat(collated_d) - collated_data[key] = collated_d - return collated_data - - def forward( - self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs - ) -> Dict[str, torch.Tensor]: - raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}") - - def validation( - self, - pipeline: DiffusionPipeline, - prompt: Optional[str] = None, - image: Optional[Image] = None, - video: Optional[List[Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - frame_rate: Optional[int] = None, - generator: Optional[torch.Generator] = None, - ) -> List[ArtifactType]: - raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}") - - def _save_lora_weights( - self, - directory: str, - transformer: torch.nn.Module, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - ) -> None: - r""" - Save the lora state dicts of the model to the given directory. - - This API is not backwards compatible and will be changed in near future. 
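Example (a sketch only): a concrete subclass typically delegates to its model family's pipeline class, exactly as the Wan specification later in this diff does with `WanPipeline.save_lora_weights`; `MyPipeline` below is a placeholder and `os` must be imported by the subclass module:

    def _save_lora_weights(self, directory, transformer, transformer_state_dict=None, scheduler=None):
        if transformer_state_dict is not None:
            MyPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))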
- """ - raise NotImplementedError( - f"ModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}" - ) - - def _save_model( - self, - directory: str, - transformer: torch.nn.Module, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - ) -> None: - r""" - Save the state dicts to the given directory. - - This API is not backwards compatible and will be changed in near future. - """ - raise NotImplementedError(f"ModelSpecification::save_model is not implemented for {self.__class__.__name__}") - - def apply_tensor_parallel( - self, - backend: ParallelBackendEnum, - device_mesh: torch.distributed.DeviceMesh, - text_encoder: torch.nn.Module, - text_encoder_2: torch.nn.Module, - text_encoder_3: torch.nn.Module, - transformer: torch.nn.Module, - vae: torch.nn.Module, - ) -> None: - raise NotImplementedError( - f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}" - ) - - def _load_configs(self) -> None: - self._load_transformer_config() - self._load_vae_config() - - def _load_transformer_config(self) -> None: - if self.transformer_id is not None: - transformer_cls = resolve_component_cls( - self.transformer_id, - component_name="_class_name", - filename="config.json", - revision=self.revision, - cache_dir=self.cache_dir, - ) - self.transformer_config = transformer_cls.load_config( - self.transformer_id, revision=self.revision, cache_dir=self.cache_dir - ) - else: - transformer_cls = resolve_component_cls( - self.pretrained_model_name_or_path, - component_name="transformer", - filename="model_index.json", - revision=self.revision, - cache_dir=self.cache_dir, - ) - self.transformer_config = transformer_cls.load_config( - self.pretrained_model_name_or_path, - subfolder="transformer", - revision=self.revision, - cache_dir=self.cache_dir, - ) - self.transformer_config = FrozenDict(**self.transformer_config) - - def _load_vae_config(self) -> None: - if self.vae_id is not None: - vae_cls = resolve_component_cls( - self.vae_id, - component_name="_class_name", - filename="config.json", - revision=self.revision, - cache_dir=self.cache_dir, - ) - self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir) - else: - vae_cls = resolve_component_cls( - self.pretrained_model_name_or_path, - component_name="vae", - filename="model_index.json", - revision=self.revision, - cache_dir=self.cache_dir, - ) - self.vae_config = vae_cls.load_config( - self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir - ) - self.vae_config = FrozenDict(**self.vae_config) diff --git a/finetrainers/models/utils.py b/finetrainers/models/utils.py deleted file mode 100644 index aeda1e4379dfc6ab1d7baba1807f6e0ac71d779b..0000000000000000000000000000000000000000 --- a/finetrainers/models/utils.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Optional, Tuple - -import numpy as np -import torch -from diffusers.utils.torch_utils import randn_tensor - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters: torch.Tensor, deterministic: bool = False, _dim: int = 1): - # Note: _dim is the new argument added here after copying from diffusers - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=_dim) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if 
self.deterministic: - self.var = self.std = torch.zeros_like( - self.mean, device=self.parameters.device, dtype=self.parameters.dtype - ) - - def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: - # make sure sample is on the same device as the parameters and has same dtype - sample = randn_tensor( - self.mean.shape, - generator=generator, - device=self.parameters.device, - dtype=self.parameters.dtype, - ) - x = self.mean + self.std * sample - return x - - def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self) -> torch.Tensor: - return self.mean diff --git a/finetrainers/models/wan/__init__.py b/finetrainers/models/wan/__init__.py deleted file mode 100644 index 2bfeae2994e6b83b8fcb1337602e5cb39c73fdc7..0000000000000000000000000000000000000000 --- a/finetrainers/models/wan/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_specification import WanModelSpecification diff --git a/finetrainers/models/wan/base_specification.py b/finetrainers/models/wan/base_specification.py deleted file mode 100644 index c908c43d44d386f4e18f632fd058a72cd5bcce5d..0000000000000000000000000000000000000000 --- a/finetrainers/models/wan/base_specification.py +++ /dev/null @@ -1,393 +0,0 @@ -import os -from typing import Any, Dict, List, Optional, Tuple - -import torch -from accelerate import init_empty_weights -from diffusers import ( - AutoencoderKLWan, - FlowMatchEulerDiscreteScheduler, - WanImageToVideoPipeline, - WanPipeline, - WanTransformer3DModel, -) -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from PIL.Image import Image -from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel - -from ... import data -from ... import functional as FF -from ...logging import get_logger -from ...processors import ProcessorMixin, T5Processor -from ...typing import ArtifactType, SchedulerType -from ...utils import get_non_null_items -from ..modeling_utils import ModelSpecification - - -logger = get_logger() - - -class WanLatentEncodeProcessor(ProcessorMixin): - r""" - Processor to encode image/video into latents using the Wan VAE. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor returns. The outputs are in the following order: - - latents: The latents of the input image/video. 
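- latents_mean: The per-channel latent means taken from `vae.config.latents_mean`.
- latents_std: The reciprocal of the per-channel latent standard deviations from `vae.config.latents_std`, pre-inverted so that normalization reduces to `(latents - latents_mean) * latents_std`.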
- """ - - def __init__(self, output_names: List[str]): - super().__init__() - self.output_names = output_names - assert len(self.output_names) == 3 - - def forward( - self, - vae: AutoencoderKLWan, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - ) -> Dict[str, torch.Tensor]: - device = vae.device - dtype = vae.dtype - - if image is not None: - video = image.unsqueeze(1) - - assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" - video = video.to(device=device, dtype=vae.dtype) - video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] - - if compute_posterior: - latents = vae.encode(video).latent_dist.sample(generator=generator) - latents = latents.to(dtype=dtype) - else: - # TODO(aryan): refactor in diffusers to have use_slicing attribute - # if vae.use_slicing and video.shape[0] > 1: - # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] - # moments = torch.cat(encoded_slices) - # else: - # moments = vae._encode(video) - moments = vae._encode(video) - latents = moments.to(dtype=dtype) - - latents_mean = torch.tensor(vae.config.latents_mean) - latents_std = 1.0 / torch.tensor(vae.config.latents_std) - - return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std} - - -class WanModelSpecification(ModelSpecification): - def __init__( - self, - pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", - tokenizer_id: Optional[str] = None, - text_encoder_id: Optional[str] = None, - transformer_id: Optional[str] = None, - vae_id: Optional[str] = None, - text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - condition_model_processors: List[ProcessorMixin] = None, - latent_model_processors: List[ProcessorMixin] = None, - **kwargs, - ) -> None: - super().__init__( - pretrained_model_name_or_path=pretrained_model_name_or_path, - tokenizer_id=tokenizer_id, - text_encoder_id=text_encoder_id, - transformer_id=transformer_id, - vae_id=vae_id, - text_encoder_dtype=text_encoder_dtype, - transformer_dtype=transformer_dtype, - vae_dtype=vae_dtype, - revision=revision, - cache_dir=cache_dir, - ) - - if condition_model_processors is None: - condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])] - if latent_model_processors is None: - latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])] - - self.condition_model_processors = condition_model_processors - self.latent_model_processors = latent_model_processors - - @property - def _resolution_dim_keys(self): - return {"latents": (2, 3, 4)} - - def load_condition_models(self) -> Dict[str, torch.nn.Module]: - if self.tokenizer_id is not None: - tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir - ) - else: - tokenizer = AutoTokenizer.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=self.revision, - cache_dir=self.cache_dir, - ) - - if self.text_encoder_id is not None: - text_encoder = AutoModel.from_pretrained( - self.text_encoder_id, - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - text_encoder = 
UMT5EncoderModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="text_encoder", - torch_dtype=self.text_encoder_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"tokenizer": tokenizer, "text_encoder": text_encoder} - - def load_latent_models(self) -> Dict[str, torch.nn.Module]: - if self.vae_id is not None: - vae = AutoencoderKLWan.from_pretrained( - self.vae_id, - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - vae = AutoencoderKLWan.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="vae", - torch_dtype=self.vae_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - return {"vae": vae} - - def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: - if self.transformer_id is not None: - transformer = WanTransformer3DModel.from_pretrained( - self.transformer_id, - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - else: - transformer = WanTransformer3DModel.from_pretrained( - self.pretrained_model_name_or_path, - subfolder="transformer", - torch_dtype=self.transformer_dtype, - revision=self.revision, - cache_dir=self.cache_dir, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return {"transformer": transformer, "scheduler": scheduler} - - def load_pipeline( - self, - tokenizer: Optional[AutoTokenizer] = None, - text_encoder: Optional[UMT5EncoderModel] = None, - transformer: Optional[WanTransformer3DModel] = None, - vae: Optional[AutoencoderKLWan] = None, - scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - training: bool = False, - **kwargs, - ) -> WanPipeline: - components = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - } - components = get_non_null_items(components) - - pipe = WanPipeline.from_pretrained( - self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir - ) - pipe.text_encoder.to(self.text_encoder_dtype) - pipe.vae.to(self.vae_dtype) - - if not training: - pipe.transformer.to(self.transformer_dtype) - - # TODO(aryan): add support in diffusers - # if enable_slicing: - # pipe.vae.enable_slicing() - # if enable_tiling: - # pipe.vae.enable_tiling() - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload() - - return pipe - - @torch.no_grad() - def prepare_conditions( - self, - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - caption: str, - max_sequence_length: int = 512, - **kwargs, - ) -> Dict[str, Any]: - conditions = { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "caption": caption, - "max_sequence_length": max_sequence_length, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_conditions(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - conditions.pop("prompt_attention_mask", None) - return conditions - - @torch.no_grad() - def prepare_latents( - self, - vae: AutoencoderKLWan, - image: Optional[torch.Tensor] = None, - video: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Dict[str, torch.Tensor]: - conditions = { - "vae": vae, - "image": image, - "video": video, - "generator": generator, - # We must force this to False because the latent normalization should be 
done before - # the posterior is computed. The VAE does not handle this any more: - # https://github.com/huggingface/diffusers/pull/10998 - "compute_posterior": False, - **kwargs, - } - input_keys = set(conditions.keys()) - conditions = super().prepare_latents(**conditions) - conditions = {k: v for k, v in conditions.items() if k not in input_keys} - return conditions - - def forward( - self, - transformer: WanTransformer3DModel, - condition_model_conditions: Dict[str, torch.Tensor], - latent_model_conditions: Dict[str, torch.Tensor], - sigmas: torch.Tensor, - generator: Optional[torch.Generator] = None, - compute_posterior: bool = True, - **kwargs, - ) -> Tuple[torch.Tensor, ...]: - compute_posterior = False # See explanation in prepare_latents - if compute_posterior: - latents = latent_model_conditions.pop("latents") - else: - latents = latent_model_conditions.pop("latents") - latents_mean = latent_model_conditions.pop("latents_mean") - latents_std = latent_model_conditions.pop("latents_std") - - mu, logvar = torch.chunk(latents, 2, dim=1) - mu = self._normalize_latents(mu, latents_mean, latents_std) - logvar = self._normalize_latents(logvar, latents_mean, latents_std) - latents = torch.cat([mu, logvar], dim=1) - - posterior = DiagonalGaussianDistribution(latents) - latents = posterior.sample(generator=generator) - del posterior - - noise = torch.zeros_like(latents).normal_(generator=generator) - noisy_latents = FF.flow_match_xt(latents, noise, sigmas) - timesteps = (sigmas.flatten() * 1000.0).long() - - latent_model_conditions["hidden_states"] = noisy_latents.to(latents) - - pred = transformer( - **latent_model_conditions, - **condition_model_conditions, - timestep=timesteps, - return_dict=False, - )[0] - target = FF.flow_match_target(noise, latents) - - return pred, target, sigmas - - def validation( - self, - pipeline: WanPipeline, - prompt: str, - image: Optional[Image] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_inference_steps: int = 50, - generator: Optional[torch.Generator] = None, - **kwargs, - ) -> List[ArtifactType]: - if image is not None: - pipeline = WanImageToVideoPipeline.from_pipe(pipeline) - - generation_kwargs = { - "prompt": prompt, - "image": image, - "height": height, - "width": width, - "num_frames": num_frames, - "num_inference_steps": num_inference_steps, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = get_non_null_items(generation_kwargs) - video = pipeline(**generation_kwargs).frames[0] - return [data.VideoArtifact(value=video)] - - def _save_lora_weights( - self, - directory: str, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - *args, - **kwargs, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - WanPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - def _save_model( - self, - directory: str, - transformer: WanTransformer3DModel, - transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, - scheduler: Optional[SchedulerType] = None, - ) -> None: - # TODO(aryan): this needs refactoring - if transformer_state_dict is not None: - with init_empty_weights(): - transformer_copy = WanTransformer3DModel.from_config(transformer.config) - transformer_copy.load_state_dict(transformer_state_dict, 
strict=True, assign=True) - transformer_copy.save_pretrained(os.path.join(directory, "transformer")) - if scheduler is not None: - scheduler.save_pretrained(os.path.join(directory, "scheduler")) - - @staticmethod - def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor - ) -> torch.Tensor: - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) - latents = ((latents.float() - latents_mean) * latents_std).to(latents) - return latents diff --git a/finetrainers/optimizer.py b/finetrainers/optimizer.py deleted file mode 100644 index 57da28e9377f2bf82b5307fae83338ad0b9ec385..0000000000000000000000000000000000000000 --- a/finetrainers/optimizer.py +++ /dev/null @@ -1,449 +0,0 @@ -import functools -import math -from typing import Any, Callable, Dict, List, Optional, Type, Union - -import torch -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_optimizer_state_dict, - set_optimizer_state_dict, -) -from torch.distributed.checkpoint.stateful import Stateful - -from .parallel import ParallelBackendEnum -from .utils.import_utils import is_bitsandbytes_available - - -class OptimizerWrapper(Stateful): - r""" - Optimizer wrapper that: - - allows step/zero_grad on multiple optimizers needed for virtual pipeline stages - - saves/loading optimizer state_dict at checkpoint - """ - - def __init__( - self, - model_parts: List[torch.nn.Module], - optimizer_cls: Type[torch.optim.Optimizer], - optimizer_kwargs: Dict[str, Any], - ) -> None: - self.optimizer_cls = optimizer_cls - self.optimizer_kwargs = optimizer_kwargs - - self.optimizers = [] - self.model_parts = model_parts - - for model in self.model_parts: - optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) - self.optimizers.append(optimizer) - - def step(self) -> None: - for optimizer in self.optimizers: - optimizer.step() - - def zero_grad(self) -> None: - for optimizer in self.optimizers: - optimizer.zero_grad() - - def state_dict(self) -> Dict[str, Any]: - func = functools.partial( - get_optimizer_state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - func = functools.partial( - set_optimizer_state_dict, - optim_state_dict=state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - list(map(func, self.model_parts, self.optimizers)) - - -class SchedulerWrapper: - def __init__( - self, optimizers, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int - ) -> None: - self.schedulers = [] - for optimizer in optimizers: - self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)) - - def step(self) -> None: - for scheduler in self.schedulers: - scheduler.step() - - def get_last_lr(self) -> List[float]: - # TODO(aryan): look into this later. Currently calling it leads to NCCL hang????? - return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)} - - def get_lr_scheduler_state(self) -> Dict[str, Any]: - state_dict = {} - if len(self.schedulers) == 1: - state_dict["lr_scheduler"] = self.schedulers[0] - else: - # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. 
- # It should only support saving and loading a distributed checkpoint with the same number of pp ranks - for idx, lr_scheduler in enumerate(self.schedulers): - state_dict[f"lr_scheduler_{idx}"] = lr_scheduler - return state_dict - - -def get_optimizer( - parallel_backend: ParallelBackendEnum, - name: str, - model_parts: List[torch.nn.Module], - learning_rate: float = 1e-3, - beta1: float = 0.9, - beta2: float = 0.95, - beta3: float = 0.999, - epsilon: float = 1e-8, - weight_decay: float = 1e-4, - fused: bool = False, -) -> Union[torch.optim.Optimizer, OptimizerWrapper]: - name = name.lower() - - _raise_errors_if_packages_not_available(name) - - if name == "adam": - optimizer_cls = torch.optim.Adam - optimizer_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - "fused": fused, - } - elif name == "adamw": - optimizer_cls = torch.optim.AdamW - optimizer_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - "fused": fused, - } - elif name == "adam-bnb": - from bitsandbytes.optim import Adam - - optimizer_cls = Adam - optimizer_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - elif name == "adamw-bnb": - from bitsandbytes.optim import AdamW - - optimizer_cls = AdamW - optimizer_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - elif name == "adam-bnb-8bit": - from bitsandbytes.optim import Adam8bit - - optimizer_cls = Adam8bit - optimizer_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - elif name == "adamw-bnb-8bit": - from bitsandbytes.optim import AdamW8bit - - optimizer_cls = AdamW8bit - optimizer_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - - # TODO(aryan): handle bitsandbytes and torchao - else: - raise ValueError(f"Unsupported optimizer: {name}") - - if parallel_backend == ParallelBackendEnum.ACCELERATE: - return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs) - elif parallel_backend == ParallelBackendEnum.PTD: - return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs) - - -def get_optimizer_accelerate( - model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any] -) -> torch.optim.Optimizer: - params = [param for model in model_parts for param in model.parameters() if param.requires_grad] - optimizer = optimizer_cls(params, **optimizer_kwargs) - return optimizer - - -def get_optimizer_ptd( - model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any] -) -> OptimizerWrapper: - return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs) - - -def get_lr_scheduler( - parallel_backend: ParallelBackendEnum, - name: str, - optimizer: Union[torch.optim.Optimizer, OptimizerWrapper], - step_rules: Optional[str] = None, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - num_cycles: int = 1, - power: float = 1.0, - lr_init: float = 1e-3, - lr_end: float = 1e-7, - last_epoch: int = -1, -) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]: - name = name.lower() - if name == "constant": - scheduler_lambda_fn = get_constant_schedule() - elif name == "constant_with_warmup": - scheduler_lambda_fn = 
get_constant_schedule_with_warmup(num_warmup_steps) - elif name == "piecewise_constant": - scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules) - elif name == "linear": - scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps) - elif name == "cosine": - scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles) - elif name == "cosine_with_restarts": - scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup( - num_warmup_steps, num_training_steps, num_cycles - ) - elif name == "polynomial": - scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup( - num_warmup_steps, num_training_steps, lr_init, lr_end, power - ) - else: - raise ValueError(f"Unsupported scheduler: {name}") - - if parallel_backend == ParallelBackendEnum.ACCELERATE: - return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch) - elif parallel_backend == ParallelBackendEnum.PTD: - return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch) - - -def get_lr_scheduler_accelerate( - optimizer: torch.optim.Optimizer, - scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], - last_epoch: int = -1, -) -> torch.optim.lr_scheduler.LambdaLR: - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch) - return scheduler - - -def get_lr_scheduler_ptd( - optimizer: OptimizerWrapper, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int = -1 -) -> SchedulerWrapper: - return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch) - - -# ============================== -# Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py -# ============================== - - -def get_constant_schedule() -> Callable[[int], float]: - r""" - Create a schedule with a constant learning rate, using the learning rate set in optimizer. - """ - - def lr_lambda(current_step: int): - return 1.0 - - return lr_lambda - - -def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]: - r""" - Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate - increases linearly between 0 and the initial lr set in the optimizer. - - Args: - num_warmup_steps (`int`): - The number of steps for the warmup phase. - """ - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1.0, num_warmup_steps)) - return 1.0 - - return lr_lambda - - -def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]: - r""" - Create a schedule with a constant learning rate, using the learning rate set in optimizer. - - Args: - step_rules (`string`): - The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate - if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 - steps and multiple 0.005 for the other steps. 
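For example, as parsed by the implementation below, `step_rules="1:10,0.1:20,0.01:30,0.005"` yields a multiplier of 1.0 while `step < 10`, 0.1 while `10 <= step < 20`, 0.01 while `20 <= step < 30`, and 0.005 for every later step (illustrative check):

    lr_lambda = get_piecewise_constant_schedule("1:10,0.1:20,0.01:30,0.005")
    assert lr_lambda(0) == 1.0
    assert lr_lambda(15) == 0.1
    assert lr_lambda(25) == 0.01
    assert lr_lambda(50) == 0.005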
- """ - - rules_dict = {} - rule_list = step_rules.split(",") - for rule_str in rule_list[:-1]: - value_str, steps_str = rule_str.split(":") - steps = int(steps_str) - value = float(value_str) - rules_dict[steps] = value - last_lr_multiple = float(rule_list[-1]) - - def create_rules_function(rules_dict, last_lr_multiple): - def rule_func(steps: int) -> float: - sorted_steps = sorted(rules_dict.keys()) - for i, sorted_step in enumerate(sorted_steps): - if steps < sorted_step: - return rules_dict[sorted_steps[i]] - return last_lr_multiple - - return rule_func - - rules_func = create_rules_function(rules_dict, last_lr_multiple) - return rules_func - - -def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]: - r""" - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - - Args: - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - """ - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) - ) - - return lr_lambda - - -def get_cosine_schedule_with_warmup( - num_warmup_steps: int, - num_training_steps: int, - num_cycles: float = 0.5, -) -> Callable[[int], float]: - r""" - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the - initial lr set in the optimizer. - - Args: - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_periods (`float`, *optional*, defaults to 0.5): - The number of periods of the cosine function in a schedule (the default is to just decrease from the max - value to 0 following a half-cosine). - """ - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) - - return lr_lambda - - -def get_cosine_with_hard_restarts_schedule_with_warmup( - num_warmup_steps: int, - num_training_steps: int, - num_cycles: int = 1, -) -> Callable[[int], float]: - r""" - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases - linearly between 0 and the initial lr set in the optimizer. - - Args: - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_cycles (`int`, *optional*, defaults to 1): - The number of hard restarts to use. 
- """ - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - if progress >= 1.0: - return 0.0 - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) - - return lr_lambda - - -def get_polynomial_decay_schedule_with_warmup( - num_warmup_steps: int, - num_training_steps: int, - lr_init: float, - lr_end: float = 1e-7, - power: float = 1.0, -) -> Callable[[int], float]: - r""" - Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the - optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the - initial lr set in the optimizer. - - Args: - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - lr_end (`float`, *optional*, defaults to 1e-7): - The end LR. - power (`float`, *optional*, defaults to 1.0): - Power factor. - - Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at - https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 - """ - - if not (lr_init > lr_end): - raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})") - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - elif current_step > num_training_steps: - return lr_end / lr_init # as LambdaLR multiplies by lr_init - else: - lr_range = lr_init - lr_end - decay_steps = num_training_steps - num_warmup_steps - pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining**power + lr_end - return decay / lr_init # as LambdaLR multiplies by lr_init - - return lr_lambda - - -def _raise_errors_if_packages_not_available(name: str) -> None: - name_split = name.split("-") - if len(name_split) < 2: - return - package_name = name_split[1] - if package_name == "bnb": - if not is_bitsandbytes_available(): - raise ImportError( - f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer." 
- ) diff --git a/finetrainers/parallel/__init__.py b/finetrainers/parallel/__init__.py deleted file mode 100644 index 5fbce75fadfc33312f6d655fc581ac90cfb6c577..0000000000000000000000000000000000000000 --- a/finetrainers/parallel/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from enum import Enum -from typing import Union - -from .accelerate import AccelerateParallelBackend -from .ptd import PytorchDTensorParallelBackend -from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean - - -ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend] - - -class ParallelBackendEnum(str, Enum): - ACCELERATE = "accelerate" - PTD = "ptd" - - -def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType: - if backend == ParallelBackendEnum.ACCELERATE: - return AccelerateParallelBackend - if backend == ParallelBackendEnum.PTD: - return PytorchDTensorParallelBackend - raise ValueError(f"Unknown parallel backend: {backend}") diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py deleted file mode 100644 index 9a523321f0c333af54949cb2274fb2d60cf014ff..0000000000000000000000000000000000000000 --- a/finetrainers/parallel/accelerate.py +++ /dev/null @@ -1,218 +0,0 @@ -import datetime -import pathlib -from typing import Optional - -import torch -from diffusers.utils import is_accelerate_available - -from ..logging import get_logger -from ..utils import get_device_info -from .base import BaseParallelBackend -from .utils import apply_ddp_accelerate - - -if not is_accelerate_available(): - raise ImportError( - "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend." - ) - -from accelerate import Accelerator -from accelerate.data_loader import DataLoader -from accelerate.utils import ( - DataLoaderConfiguration, - DistributedDataParallelKwargs, - InitProcessGroupKwargs, - ProjectConfiguration, -) - - -logger = get_logger() -_device_type, _device_module = get_device_info() - - -class AccelerateParallelBackend(BaseParallelBackend): - def __init__( - self, - world_size: int, - pp_degree: int = 1, - dp_degree: int = 1, - dp_shards: int = -1, - cp_degree: int = 1, - tp_degree: int = 1, - backend: str = "nccl", - timeout: int = 180, - logging_dir: Optional[str] = None, - output_dir: Optional[str] = None, - gradient_accumulation_steps: Optional[int] = None, - ) -> None: - super().__init__() - - self._world_size = world_size - self._pp_degree = pp_degree - self._dp_degree = dp_degree - self._dp_shards = dp_shards - self._cp_degree = cp_degree - self._tp_degree = tp_degree - self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None - self._logging_dir = ( - self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None - ) - self._backend = backend - self._timeout = timeout - self._gradient_accumulation_steps = gradient_accumulation_steps - - if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1: - raise ValueError( - "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment." 
- ) - if dp_degree != world_size: - raise ValueError("Data parallel degree must be equal to world size.") - - self._accelerator: Accelerator = None - self._mesh: torch.distributed.DeviceMesh = None - - def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module: - project_config = None - ddp_kwargs = None - init_process_group_kwargs = None - if self._accelerator is None: - project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir) - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) - dataloader_config = DataLoaderConfiguration( - split_batches=False, dispatch_batches=False, use_stateful_dataloader=True - ) - init_process_group_kwargs = InitProcessGroupKwargs( - backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout) - ) - self._accelerator, model = apply_ddp_accelerate( - model, - project_config, - ddp_kwargs, - init_process_group_kwargs, - dataloader_config, - self._gradient_accumulation_steps, - accelerator=self._accelerator, - ) - logger.debug("Applied AccelerateParallel::apply_ddp to model.") - return model - - def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset: - logger.debug("AccelerateParallelBackend::prepare_dataset completed!") - return dataset - - def prepare_dataloader( - self, - dataset: torch.utils.data.IterableDataset, - batch_size: int = 1, - num_workers: int = 0, - pin_memory: bool = False, - ) -> DataLoader: - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory - ) - dataloader = self._accelerator.prepare_data_loader(dataloader) - logger.debug("AccelerateParallelBackend::prepare_dataloader completed!") - return dataloader - - def prepare_optimizer(self, optimizer, lr_scheduler): - optimizer = self._accelerator.prepare_optimizer(optimizer) - lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler) - return optimizer, lr_scheduler - - def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: - def _get_mesh(): - if name is None: - return self._mesh - try: - return self._mesh[name] - except (KeyError, RuntimeError): - return self._mesh - - if self._mesh is not None: - return _get_mesh() - - mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)] - mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1] - names = [x[0] for x in mesh_list] - degrees = [x[1] for x in mesh_list] - mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names) - - dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], [] - - if self.data_replication_enabled: - dp_mesh_names.append("dp_replicate") - dp_cp_mesh_names.append("dp_replicate") - if self.data_sharding_enabled: - dp_mesh_names.append("dp_shard") - dp_cp_mesh_names.append("dp_shard") - dp_shard_cp_mesh_names.append("dp_shard") - if self.context_parallel_enabled: - dp_cp_mesh_names.append("cp") - dp_shard_cp_mesh_names.append("cp") - - if len(dp_mesh_names) > 0: - mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp") - if len(dp_cp_mesh_names) > 0: - mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp") - if len(dp_shard_cp_mesh_names) > 0: - mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp") - - logger.debug(f"Device mesh: {mesh}") - self._mesh = mesh - return _get_mesh() - - @property - def world_size(self): - return self._accelerator.num_processes - - 
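The `get_mesh` implementation above (repeated almost verbatim in `PytorchDTensorParallelBackend` later in this diff) builds a named device mesh and then flattens dimension groups into composite `dp`, `dp_cp` and `dp_shard_cp` meshes. Below is a minimal sketch of that pattern on 8 ranks with 2-way replication and 4-way sharding (a combination only reachable through the DTensor backend, since this class rejects `dp_shards > 1`); it assumes the process group is already initialized and relies on the private `DeviceMesh._flatten` helper exactly as the code above does:

    from torch.distributed.device_mesh import init_device_mesh

    # 8 ranks = 2 data-parallel replicas x 4 shards, mirroring
    # mesh_list = [("dp_replicate", 2), ("dp_shard", 4)] above.
    mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))

    # Flatten the replicate + shard dimensions into a composite "dp" mesh,
    # the same call get_mesh issues when both degrees are greater than 1.
    mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp")

    # Sub-meshes are then looked up by name, as _get_mesh does above.
    dp_shard_mesh = mesh["dp_shard"]  # this rank's 4-way shard group
    dp_mesh = mesh["dp"]              # all 8 ranks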
@property - def rank(self): - return self._accelerator.process_index - - @property - def local_rank(self): - return self._accelerator.local_process_index - - @property - def is_main_process(self): - r"""Returns `True` if the current process is the main process on the master node.""" - return self._accelerator.is_main_process - - @property - def is_local_main_process(self): - r"""Returns `True` if the current process is the main process on local node.""" - return self._accelerator.is_local_main_process - - @property - def device(self): - return self._accelerator.device - - def wait_for_everyone(self): - self._accelerator.wait_for_everyone() - - def destroy(self): - self._accelerator.end_training() - - @property - def pipeline_parallel_enabled(self): - return self._pp_degree > 1 - - @property - def data_parallel_enabled(self): - return self._dp_degree > 1 or self._dp_shards > 1 - - @property - def data_replication_enabled(self): - return self._dp_degree > 1 - - @property - def data_sharding_enabled(self): - return self._dp_shards > 1 - - @property - def context_parallel_enabled(self): - return self._cp_degree > 1 - - @property - def tensor_parallel_enabled(self): - return self._tp_degree > 1 diff --git a/finetrainers/parallel/base.py b/finetrainers/parallel/base.py deleted file mode 100644 index eb982ca252fdcd03b644bf5a3215857fada7119c..0000000000000000000000000000000000000000 --- a/finetrainers/parallel/base.py +++ /dev/null @@ -1,96 +0,0 @@ -from contextlib import contextmanager -from typing import Any, Dict, List, Optional - -import torch - -from ..trackers import TrackerType, initialize_trackers - - -class BaseParallelBackend: - r""" - Base class that contains properties and methods that should be implemented by different parallel backends. - """ - - def apply_ddp(self, *args, **kwargs) -> torch.nn.Module: - raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.") - - def prepare_dataset(self, *args, **kwargs) -> Any: - raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.") - - def prepare_dataloader(self, *args, **kwargs) -> Any: - raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.") - - def prepare_optimizer(self, *args, **kwargs) -> Any: - raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.") - - def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: - raise NotImplementedError("Method `get_mesh` must be implemented by subclass.") - - def initialize_trackers( - self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str - ) -> TrackerType: - self.tracker = None - if self.is_main_process: - self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir) - - def log(self, metrics: Dict[str, Any], step: int) -> None: - if self.is_main_process: - self.tracker.log(metrics, step) - - def wait_for_everyone(self): - raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.") - - @contextmanager - def main_process_first(self): - raise NotImplementedError("Method `main_process_first` must be implemented by subclass.") - - def destroy(self): - raise NotImplementedError("Method `destroy` must be implemented by subclass.") - - @property - def world_size(self): - raise NotImplementedError("Method `world_size` must be implemented by subclass.") - - @property - def rank(self): - raise NotImplementedError("Method `rank` must be implemented by subclass.") - - @property - def 
local_rank(self): - raise NotImplementedError("Method `local_rank` must be implemented by subclass.") - - @property - def is_main_process(self): - raise NotImplementedError("Method `is_main_process` must be implemented by subclass.") - - @property - def is_local_main_process(self): - raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.") - - @property - def device(self): - raise NotImplementedError("Method `device` must be implemented by subclass.") - - @property - def pipeline_parallel_enabled(self): - raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.") - - @property - def data_parallel_enabled(self): - raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.") - - @property - def data_replication_enabled(self): - raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.") - - @property - def data_sharding_enabled(self): - raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.") - - @property - def context_parallel_enabled(self): - raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.") - - @property - def tensor_parallel_enabled(self): - raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.") diff --git a/finetrainers/parallel/deepspeed.py b/finetrainers/parallel/deepspeed.py deleted file mode 100644 index 8b9f54d66ec1941ffc44d6239b305cc397ce61d4..0000000000000000000000000000000000000000 --- a/finetrainers/parallel/deepspeed.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base import BaseParallelBackend - - -class DeepspeedParallelBackend(BaseParallelBackend): - def __init__(self): - # TODO(aryan) - raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.") diff --git a/finetrainers/parallel/ptd.py b/finetrainers/parallel/ptd.py deleted file mode 100644 index 352273b4eff7f4cb21424962748a84ff29d96426..0000000000000000000000000000000000000000 --- a/finetrainers/parallel/ptd.py +++ /dev/null @@ -1,228 +0,0 @@ -import datetime -import os -import pathlib -from typing import Optional - -import datasets.distributed -import torch - -from ..data import DPDataLoader -from ..logging import get_logger -from ..utils import get_device_info -from .base import BaseParallelBackend -from .utils import apply_ddp_ptd - - -_device_type, _device_module = get_device_info() -logger = get_logger() - - -class PytorchDTensorParallelBackend(BaseParallelBackend): - def __init__( - self, - world_size: int, - pp_degree: int = 1, - dp_degree: int = 1, - dp_shards: int = -1, - cp_degree: int = 1, - tp_degree: int = 1, - backend: str = "nccl", - timeout: int = 180, - logging_dir: Optional[str] = None, - output_dir: Optional[str] = None, - gradient_accumulation_steps: Optional[int] = None, - ) -> None: - super().__init__() - - self._world_size = world_size - self._pp_degree = pp_degree - self._dp_degree = dp_degree - self._dp_shards = dp_shards - self._cp_degree = cp_degree - self._tp_degree = tp_degree - self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None - self._logging_dir = ( - self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None - ) - self._backend = backend - self._timeout = timeout - - for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]: - if degree < 1: - raise ValueError(f"Parallel degree must be at least 1, got {degree}.") - - if dp_shards 
* pp_degree * dp_degree * cp_degree * tp_degree != world_size: - raise ValueError( - f"World size {world_size} must be divisible by the product of all parallel degrees and data parallel shards." - ) - - torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)) - _device_module.set_device(self.local_rank) - - logger.info( - f"Initialized parallel state with:\n" - f" - World size: {world_size}\n" - f" - Pipeline parallel degree: {pp_degree}\n" - f" - Data parallel degree: {dp_degree}\n" - f" - Context parallel degree: {cp_degree}\n" - f" - Tensor parallel degree: {tp_degree}\n" - f" - Data parallel shards: {dp_shards}\n" - ) - - self._mesh: torch.distributed.DeviceMesh = None - - def apply_ddp( - self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None - ) -> torch.nn.Module: - if device_mesh is None: - device_mesh = self.get_mesh() - apply_ddp_ptd(model, device_mesh) - logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.") - return model - - def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset: - dp_mesh = self.get_mesh("dp_replicate") - if dp_mesh is None: - dp_mesh = self.get_mesh() - if self.world_size > 1: - dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() - else: - dp_local_rank, dp_world_size = 0, 1 - dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size) - logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!") - return dataset - - def prepare_dataloader( - self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool - ) -> DPDataLoader: - dp_mesh = self.get_mesh("dp_replicate") - if dp_mesh is None: - dp_mesh = self.get_mesh() - if self.world_size > 1: - dp_local_rank = dp_mesh.get_local_rank() - else: - dp_local_rank = 0 - dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers) - logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!") - return dataloader - - def prepare_optimizer(self, optimizer, lr_scheduler): - logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!") - return optimizer, lr_scheduler - - def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: - def _get_mesh(): - if name is None: - return self._mesh - try: - return self._mesh[name] - except (KeyError, RuntimeError): - if self._mesh.ndim == 0: - return None - return self._mesh - - if self._mesh is not None: - return _get_mesh() - - mesh_list = [ - ("pp", self._pp_degree), - ("dp_replicate", self._dp_degree), - ("dp_shard", self._dp_shards), - ("cp", self._cp_degree), - ("tp", self._tp_degree), - ] - mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1] - names = [x[0] for x in mesh_list] - degrees = [x[1] for x in mesh_list] - mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names) - - dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], [] - - if self.data_replication_enabled: - dp_mesh_names.append("dp_replicate") - dp_cp_mesh_names.append("dp_replicate") - if self.data_sharding_enabled: - dp_mesh_names.append("dp_shard") - dp_cp_mesh_names.append("dp_shard") - dp_shard_cp_mesh_names.append("dp_shard") - if self.context_parallel_enabled: - dp_cp_mesh_names.append("cp") - dp_shard_cp_mesh_names.append("cp") - - if len(dp_mesh_names) > 0: - 
mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp") - if len(dp_cp_mesh_names) > 0: - mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp") - if len(dp_shard_cp_mesh_names) > 0: - mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp") - - logger.debug(f"Device mesh: {mesh}") - self._mesh = mesh - return _get_mesh() - - @property - def world_size(self): - return torch.distributed.get_world_size() - - @property - def rank(self): - return torch.distributed.get_rank() - - @property - def local_rank(self): - return int(os.environ.get("LOCAL_RANK", 0)) - - @property - def is_main_process(self): - r"""Returns `True` if the current process is the main process on the master node.""" - return self.rank == 0 - - @property - def is_local_main_process(self): - r"""Returns `True` if the current process is the main process on local node.""" - return self.local_rank == 0 - - @property - def device(self): - return torch.device(_device_type, self.local_rank) - - def wait_for_everyone(self): - return torch.distributed.barrier() - - # @contextmanager - # def main_process_first(self): - # if self.is_main_process: - # yield - # self.wait_for_everyone() - # else: - # self.wait_for_everyone() - # yield - - def destroy(self): - if self.is_main_process: - self.tracker.finish() - return torch.distributed.destroy_process_group() - - @property - def pipeline_parallel_enabled(self): - return self._pp_degree > 1 - - @property - def data_parallel_enabled(self): - return self._dp_degree > 1 or self._dp_shards > 1 - - @property - def data_replication_enabled(self): - return self._dp_degree > 1 - - @property - def data_sharding_enabled(self): - return self._dp_shards > 1 - - @property - def context_parallel_enabled(self): - return self._cp_degree > 1 - - @property - def tensor_parallel_enabled(self): - return self._tp_degree > 1 diff --git a/finetrainers/parallel/utils.py b/finetrainers/parallel/utils.py deleted file mode 100644 index fe434e1a192504998260118e9bd4b615113b2cd8..0000000000000000000000000000000000000000 --- a/finetrainers/parallel/utils.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed._functional_collectives as funcol -import torch.distributed.tensor -from diffusers.utils import is_accelerate_available -from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard -from torch.distributed._composable.replicate import replicate - -from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES - - -if is_accelerate_available(): - from accelerate import Accelerator - from accelerate.utils import ( - DataLoaderConfiguration, - DistributedDataParallelKwargs, - InitProcessGroupKwargs, - ProjectConfiguration, - ) - - -def apply_fsdp2_ptd( - model: torch.nn.Module, - dp_mesh: torch.distributed.device_mesh.DeviceMesh, - param_dtype: torch.dtype, - reduce_dtype: torch.dtype, - output_dtype: torch.dtype, - pp_enabled: bool = False, - cpu_offload: bool = False, -) -> None: - r"""Apply FSDP2 on a model.""" - mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} - - if cpu_offload: - fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True) - - def apply_fully_shard(blocks): - for layer_index, block in enumerate(blocks): - if pp_enabled: - # For PP, do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = False - 
else: - # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = layer_index < len(blocks) - 1 - fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward) - - for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES: - blocks = getattr(model, transformer_block_name, None) - if blocks is not None: - apply_fully_shard(blocks) - - fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) - - -def apply_ddp_accelerate( - model: torch.nn.Module, - project_config: Optional[ProjectConfiguration] = None, - ddp_kwargs: Optional[DistributedDataParallelKwargs] = None, - init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None, - dataloader_config: Optional[DataLoaderConfiguration] = None, - gradient_accumulation_steps: Optional[int] = None, - accelerator: Optional[Accelerator] = None, -) -> torch.nn.Module: - if accelerator is None: - accelerator = Accelerator( - project_config=project_config, - dataloader_config=dataloader_config, - gradient_accumulation_steps=gradient_accumulation_steps, - log_with=None, - kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], - ) - if torch.backends.mps.is_available(): - accelerator.native_amp = False - accelerator.prepare_model(model) - return accelerator, model - - -def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None: - replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) - - -def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: - if isinstance(x, torch.distributed.tensor.DTensor): - # functional collectives do not support DTensor inputs - x = x.full_tensor() - assert x.numel() == 1 # required by `.item()` - return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() - - -def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: - return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh) - - -def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: - return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh) diff --git a/finetrainers/patches/__init__.py b/finetrainers/patches/__init__.py deleted file mode 100644 index 26ce0ff4c134339655e3820a7dd84b2d09b4faf5..0000000000000000000000000000000000000000 --- a/finetrainers/patches/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import TYPE_CHECKING - - -if TYPE_CHECKING: - from ..args import BaseArgs - from ..parallel import ParallelBackendType - - -def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBackendType") -> None: - # To avoid circular imports - from ..config import ModelType, TrainingType - - if args.model_name == ModelType.LTX_VIDEO: - from .models.ltx_video import patch - - patch.patch_transformer_forward() - if parallel_backend.tensor_parallel_enabled: - patch.patch_apply_rotary_emb_for_tp_compatibility() - - if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules: - from .models.wan import patch - - patch.patch_time_text_image_embedding_forward() - - if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0: - from .dependencies.peft import patch - - patch.patch_peft_move_adapter_to_device_of_base_layer() diff --git a/finetrainers/patches/dependencies/peft/patch.py b/finetrainers/patches/dependencies/peft/patch.py 
deleted file mode 100644 index 3c0bf968cf4894b8ccd90c76631e57541e10642d..0000000000000000000000000000000000000000 --- a/finetrainers/patches/dependencies/peft/patch.py +++ /dev/null @@ -1,25 +0,0 @@ -import functools - -from peft.tuners.tuners_utils import BaseTunerLayer - -from ...utils import DisableTensorToDtype - - -def patch_peft_move_adapter_to_device_of_base_layer() -> None: - _perform_patch_move_adapter_to_device_of_base_layer() - - -def _perform_patch_move_adapter_to_device_of_base_layer() -> None: - BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer( - BaseTunerLayer._move_adapter_to_device_of_base_layer - ) - - -def _patched_move_adapter_to_device_of_base_layer(func) -> None: - # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor. - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - with DisableTensorToDtype(): - return func(self, *args, **kwargs) - - return wrapper diff --git a/finetrainers/patches/models/ltx_video/patch.py b/finetrainers/patches/models/ltx_video/patch.py deleted file mode 100644 index 9e8caa803f0716280ff066d6e7865746344fb8e9..0000000000000000000000000000000000000000 --- a/finetrainers/patches/models/ltx_video/patch.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Any, Dict, Optional, Tuple - -import diffusers -import torch -from diffusers import LTXVideoTransformer3DModel -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.utils.import_utils import is_torch_version - - -def patch_transformer_forward() -> None: - _perform_ltx_transformer_forward_patch() - - -def patch_apply_rotary_emb_for_tp_compatibility() -> None: - _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() - - -def _perform_ltx_transformer_forward_patch() -> None: - LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward - - -def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None: - def apply_rotary_emb(x, freqs): - cos, sin = freqs - # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ======== - # The change is made due to unsupported DTensor operation aten.ops.unbind - # FIXME: Once aten.ops.unbind support lands, this will no longer be required - # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] - x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1) # [B, S, H, D // 2] - # ================================================================== - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out - - diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb - - -def _patched_LTXVideoTransformer3D_forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: torch.LongTensor, - encoder_attention_mask: torch.Tensor, - num_frames: int, - height: int, - width: int, - rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, - return_dict: bool = True, - *args, - **kwargs, -) -> torch.Tensor: - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = 
encoder_attention_mask.unsqueeze(1) - - batch_size = hidden_states.size(0) - - # ===== This is modified compared to Diffusers ===== - # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep - if timestep.ndim == 1: - timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1) - # ================================================== - - temb, embedded_timestep = self.time_embed( - timestep.flatten(), - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - # ===== This is modified compared to Diffusers ===== - # temb = temb.view(batch_size, -1, temb.size(-1)) - # embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) - # ================================================== - # This is done to make it possible to use per-token timestep embedding - temb = temb.view(batch_size, *hidden_states.shape[1:-1], temb.size(-1)) - embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1)) - # ================================================== - - hidden_states = self.proj_in(hidden_states) - - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - encoder_attention_mask, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - ) - - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - output = self.proj_out(hidden_states) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) diff --git a/finetrainers/patches/models/wan/patch.py b/finetrainers/patches/models/wan/patch.py deleted file mode 100644 index e5c44ae42637fb9c6fc0a9803930f1728a92b693..0000000000000000000000000000000000000000 --- a/finetrainers/patches/models/wan/patch.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -import diffusers -import torch - - -def patch_time_text_image_embedding_forward() -> None: - _patch_time_text_image_embedding_forward() - - -def _patch_time_text_image_embedding_forward() -> None: - diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = ( - _patched_WanTimeTextImageEmbedding_forward - ) - - -def _patched_WanTimeTextImageEmbedding_forward( - self, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - encoder_hidden_states_image: Optional[torch.Tensor] = None, -): - # Some code has been removed compared to original implementation in Diffusers - # Also, timestep is typed as that of encoder_hidden_states - timestep = 
self.timesteps_proj(timestep).type_as(encoder_hidden_states) - temb = self.time_embedder(timestep).type_as(encoder_hidden_states) - timestep_proj = self.time_proj(self.act_fn(temb)) - - encoder_hidden_states = self.text_embedder(encoder_hidden_states) - if encoder_hidden_states_image is not None: - encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) - - return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image diff --git a/finetrainers/patches/utils.py b/finetrainers/patches/utils.py deleted file mode 100644 index 9d7f4726cc8183a461310570762ee95b5c4e6187..0000000000000000000000000000000000000000 --- a/finetrainers/patches/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch - - -class DisableTensorToDtype: - def __enter__(self): - self.original_to = torch.Tensor.to - - def modified_to(tensor, *args, **kwargs): - # remove dtype from args if present - args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] - if "dtype" in kwargs: - kwargs.pop("dtype") - return self.original_to(tensor, *args, **kwargs) - - torch.Tensor.to = modified_to - - def __exit__(self, *args, **kwargs): - torch.Tensor.to = self.original_to diff --git a/finetrainers/processors/__init__.py b/finetrainers/processors/__init__.py deleted file mode 100644 index 0df7216d329da5259a4d18ac4489de7aaf2542af..0000000000000000000000000000000000000000 --- a/finetrainers/processors/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .base import ProcessorMixin -from .clip import CLIPPooledProcessor -from .glm import CogView4GLMProcessor -from .llama import LlamaProcessor -from .t5 import T5Processor -from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor diff --git a/finetrainers/processors/base.py b/finetrainers/processors/base.py deleted file mode 100644 index 9853ead0ef49610bdfed05ff067918bf80558109..0000000000000000000000000000000000000000 --- a/finetrainers/processors/base.py +++ /dev/null @@ -1,20 +0,0 @@ -import inspect -from typing import Any, Dict, List - - -class ProcessorMixin: - def __init__(self) -> None: - self._forward_parameter_names = inspect.signature(self.forward).parameters.keys() - self.output_names: List[str] = None - self.input_names: Dict[str, Any] = None - - def __call__(self, *args, **kwargs) -> Any: - shallow_copy_kwargs = dict(kwargs.items()) - if self.input_names is not None: - for k, v in self.input_names.items(): - shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k) - acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names} - return self.forward(*args, **acceptable_kwargs) - - def forward(self, *args, **kwargs) -> Any: - raise NotImplementedError("ProcessorMixin::forward method should be implemented by the subclass.") diff --git a/finetrainers/processors/clip.py b/finetrainers/processors/clip.py deleted file mode 100644 index 178addf8b2556e7c1ae952084790f2575d27f007..0000000000000000000000000000000000000000 --- a/finetrainers/processors/clip.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch - from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast - -from .base import ProcessorMixin - - -class CLIPPooledProcessor(ProcessorMixin): - r""" - Processor for the CLIP family of models. This processor is used to encode text inputs and return the pooled - embeddings for the input text. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor should return. 
This processor returns a single output: the pooled - embedding of the input text. - """ - - def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None: - super().__init__() - - self.output_names = output_names - self.input_names = input_names - - assert len(output_names) == 1 - if input_names is not None: - assert len(input_names) <= 3 - - def forward( - self, - tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast], - text_encoder: CLIPTextModel, - caption: Union[str, List[str]], - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Encode the input text and return the pooled embeddings for the input text. - - Args: - tokenizer (`Union[CLIPTokenizer, CLIPTokenizerFast]`): - The tokenizer used to tokenize the input text. - text_encoder (`CLIPTextModel`): - The text encoder used to encode the input text. - caption (`Union[str, List[str]]`): - The input text to be encoded. - """ - if isinstance(caption, str): - caption = [caption] - - device = text_encoder.device - dtype = text_encoder.dtype - - text_inputs = tokenizer( - caption, - padding="max_length", - max_length=77, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids.to(device) - - prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - return {self.output_names[0]: prompt_embeds} diff --git a/finetrainers/processors/glm.py b/finetrainers/processors/glm.py deleted file mode 100644 index bf742130bb7da8808710ec562c85d9c64a535cb6..0000000000000000000000000000000000000000 --- a/finetrainers/processors/glm.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import List, Tuple, Union - -import torch -from transformers import AutoTokenizer, GlmModel - -from .base import ProcessorMixin - - -class CogView4GLMProcessor(ProcessorMixin): - r""" - Processor for the GLM family of models. This processor is used to encode text inputs and return the - embeddings for the input text. - - This processor is specific to CogView4 but can be used with any other model. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor should return. This processor returns a single output: the - embeddings of the input text. - """ - - def __init__(self, output_names: List[str]): - super().__init__() - - self.output_names = output_names - - assert len(self.output_names) == 1 - - def forward( - self, - tokenizer: AutoTokenizer, - text_encoder: GlmModel, - caption: Union[str, List[str]], - max_sequence_length: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Encode the input text and return the embeddings for the input text. - - Args: - tokenizer (`AutoTokenizer`): - The tokenizer used to tokenize the input text. - text_encoder (`GlmModel`): - The text encoder used to encode the input text. - caption (`Union[str, List[str]]`): - The input text to be encoded. - max_sequence_length (`int`): - The maximum sequence length of the input text. 
- """ - if isinstance(caption, str): - caption = [caption] - - device = text_encoder.device - dtype = text_encoder.dtype - - text_inputs = tokenizer( - caption, - padding="longest", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids.to(device) - - current_length = text_input_ids.size(1) - pad_length = 16 - current_length % 16 - if pad_length > 0: - pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id) - text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) - - prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - return {self.output_names[0]: prompt_embeds} diff --git a/finetrainers/processors/llama.py b/finetrainers/processors/llama.py deleted file mode 100644 index 749e5f313541b92317279669faf915edeb9129c4..0000000000000000000000000000000000000000 --- a/finetrainers/processors/llama.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast - -from .base import ProcessorMixin - - -DEFAULT_PROMPT_TEMPLATE = { - "template": ( - "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." - "4. background environment, light, style and atmosphere." - "5. camera angles, movements, and transitions used in the video:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" - ), - "crop_start": 95, -} - - -class LlamaProcessor(ProcessorMixin): - r""" - Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings - and attention masks for the input text. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor should return. The first output is the embeddings of the input - text and the second output is the attention mask for the input text. - """ - - def __init__(self, output_names: List[str] = None): - super().__init__() - - self.output_names = output_names - - assert len(output_names) == 2 - - def forward( - self, - tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast], - text_encoder: LlamaModel, - caption: Union[str, List[str]], - max_sequence_length: int, - prompt_template: Optional[Dict[str, Any]] = None, - num_layers_to_skip: int = 2, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Encode the input text and return the embeddings and attention mask for the input text. - - Args: - tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): - The tokenizer used to tokenize the input text. - text_encoder (`LlamaModel`): - The text encoder used to encode the input text. - caption (`Union[str, List[str]]`): - The input text to be encoded. - max_sequence_length (`int`): - The maximum sequence length of the input text. - prompt_template (`Optional[Dict[str, Any]]`): - The prompt template to be used to encode the input text. 
- """ - if prompt_template is None: - prompt_template = DEFAULT_PROMPT_TEMPLATE - if isinstance(caption, str): - caption = [caption] - - device = text_encoder.device - dtype = text_encoder.dtype - - batch_size = len(caption) - caption = [prompt_template["template"].format(c) for c in caption] - - crop_start = prompt_template.get("crop_start", None) - if crop_start is None: - prompt_template_input = tokenizer( - prompt_template["template"], - padding="max_length", - return_tensors="pt", - return_length=False, - return_overflowing_tokens=False, - return_attention_mask=False, - ) - crop_start = prompt_template_input["input_ids"].shape[-1] - # Remove <|eot_id|> token and placeholder {} - crop_start -= 2 - - max_sequence_length += crop_start - text_inputs = tokenizer( - caption, - max_length=max_sequence_length, - padding="max_length", - truncation=True, - return_tensors="pt", - return_length=False, - return_overflowing_tokens=False, - return_attention_mask=True, - ) - text_input_ids = text_inputs.input_ids.to(device) - prompt_attention_mask = text_inputs.attention_mask.bool().to(device) - - prompt_embeds = text_encoder( - text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True - ).hidden_states[-(num_layers_to_skip + 1)] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - if crop_start is not None and crop_start > 0: - prompt_embeds = prompt_embeds[:, crop_start:] - prompt_attention_mask = prompt_attention_mask[:, crop_start:] - - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - - return { - self.output_names[0]: prompt_embeds, - self.output_names[1]: prompt_attention_mask, - } diff --git a/finetrainers/processors/t5.py b/finetrainers/processors/t5.py deleted file mode 100644 index 96c2c194ca01635de404d618afffa93b88cdf953..0000000000000000000000000000000000000000 --- a/finetrainers/processors/t5.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import List, Tuple, Union - -import torch -from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast - -from .base import ProcessorMixin - - -class T5Processor(ProcessorMixin): - r""" - Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings - and attention masks for the input text. - - Args: - output_names (`List[str]`): - The names of the outputs that the processor should return. The first output is the embeddings of the input - text and the second output is the attention mask for the input text. - """ - - def __init__(self, output_names: List[str]): - super().__init__() - - self.output_names = output_names - - assert len(self.output_names) == 2 - - def forward( - self, - tokenizer: Union[T5Tokenizer, T5TokenizerFast], - text_encoder: T5EncoderModel, - caption: Union[str, List[str]], - max_sequence_length: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Encode the input text and return the embeddings and attention mask for the input text. - - Args: - tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`): - The tokenizer used to tokenize the input text. - text_encoder (`T5EncoderModel`): - The text encoder used to encode the input text. - caption (`Union[str, List[str]]`): - The input text to be encoded. - max_sequence_length (`int`): - The maximum sequence length of the input text. 
- """ - if isinstance(caption, str): - caption = [caption] - - device = text_encoder.device - dtype = text_encoder.dtype - - batch_size = len(caption) - text_inputs = tokenizer( - caption, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.bool().to(device) - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - - return { - self.output_names[0]: prompt_embeds, - self.output_names[1]: prompt_attention_mask, - } diff --git a/finetrainers/processors/text.py b/finetrainers/processors/text.py deleted file mode 100644 index b51dca68214ba36c5813775f6dbc6d40592e9b3c..0000000000000000000000000000000000000000 --- a/finetrainers/processors/text.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import List, Union - -import torch - -from .. import functional as FF -from .base import ProcessorMixin - - -class CaptionTextDropoutProcessor(ProcessorMixin): - def __init__(self, dropout_p: float = 0.0) -> None: - self.dropout_p = dropout_p - - def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]: - return FF.dropout_caption(caption, self.dropout_p) - - -class CaptionEmbeddingDropoutProcessor(ProcessorMixin): - def __init__(self, dropout_p: float = 0.0) -> None: - self.dropout_p = dropout_p - - def forward(self, embedding: torch.Tensor) -> torch.Tensor: - return FF.dropout_embeddings_to_zero(embedding, self.dropout_p) diff --git a/finetrainers/state.py b/finetrainers/state.py deleted file mode 100644 index 5cda7d91b3f0e82b493d7e88b3565b9df985a228..0000000000000000000000000000000000000000 --- a/finetrainers/state.py +++ /dev/null @@ -1,69 +0,0 @@ -import io -from dataclasses import dataclass, field -from typing import Any, Dict, List - -import torch -import torch.distributed.checkpoint.stateful - -from .parallel import ParallelBackendType -from .utils import get_device_info - - -_device_type, _ = get_device_info() - - -@dataclass -class TrainState(torch.distributed.checkpoint.stateful.Stateful): - step: int = 0 - observed_data_samples: int = 0 - observed_num_tokens: int = 0 - global_avg_losses: List[float] = field(default_factory=list) - global_max_losses: List[float] = field(default_factory=list) - log_steps: List[int] = field(default_factory=list) - - def state_dict(self) -> Dict[str, Any]: - # Only checkpoint global_avg_losses and global_max_losses per log frequency - # to avoid sync overhead in every iteration. 
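(A minimal, self-contained sketch of the BytesIO round-trip that this state_dict/load_state_dict pair relies on; the list values here are illustrative, not taken from the trainer.)

import io

import torch

# Serialize a plain Python list into an in-memory buffer, as state_dict does below.
losses = [0.25, 0.21, 0.19]
buffer = io.BytesIO()
torch.save(losses, buffer)

# Deserialize it again, as load_state_dict does: rewind the buffer, then load.
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)
assert restored == losses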
- global_avg_losses_bytes = io.BytesIO() - torch.save(self.global_avg_losses, global_avg_losses_bytes) - global_max_losses_bytes = io.BytesIO() - torch.save(self.global_max_losses, global_max_losses_bytes) - log_steps_bytes = io.BytesIO() - torch.save(self.log_steps, log_steps_bytes) - return { - "step": torch.tensor(self.step, dtype=torch.int32), - "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32), - "observed_num_tokens": torch.tensor(self.observed_num_tokens, dtype=torch.int32), - "global_avg_losses": global_avg_losses_bytes, - "global_max_losses": global_max_losses_bytes, - "log_steps": log_steps_bytes, - } - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - state_dict["global_avg_losses"].seek(0) - state_dict["global_max_losses"].seek(0) - state_dict["log_steps"].seek(0) - - self.step = state_dict["step"].item() - self.observed_data_samples = state_dict["observed_data_samples"].item() - self.observed_num_tokens = state_dict["observed_num_tokens"].item() - self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False) - self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False) - self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) - - -@dataclass -class State: - # Parallel state - parallel_backend: ParallelBackendType = None - - # Training state - train_state: TrainState = None - num_trainable_parameters: int = 0 - generator: torch.Generator = None - - # Hub state - repo_id: str = None - - # Artifacts state - output_dir: str = None diff --git a/finetrainers/trackers.py b/finetrainers/trackers.py deleted file mode 100644 index a48716605e1ed2e39ab5d86cd39f72467496fd52..0000000000000000000000000000000000000000 --- a/finetrainers/trackers.py +++ /dev/null @@ -1,92 +0,0 @@ -import pathlib -from enum import Enum -from typing import Any, Dict, List, Optional, Union - -from .logging import get_logger - - -logger = get_logger() - - -class BaseTracker: - r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging.""" - - def log(self, metrics: Dict[str, Any], step: int) -> None: - pass - - def finish(self) -> None: - pass - - -class WandbTracker(BaseTracker): - r"""Logger implementation for Weights & Biases.""" - - def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None: - import wandb - - self.wandb = wandb - - # WandB does not create a directory if it does not exist and instead starts using the system temp directory. 
- pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) - - self.run = wandb.init(project=experiment_name, dir=log_dir, config=config) - logger.info("WandB logging enabled") - - def log(self, metrics: Dict[str, Any], step: int) -> None: - self.run.log(metrics, step=step) - - def finish(self) -> None: - self.run.finish() - - -class SequentialTracker(BaseTracker): - r"""Sequential tracker that logs to multiple trackers in sequence.""" - - def __init__(self, trackers: List[BaseTracker]) -> None: - self.trackers = trackers - - def log(self, metrics: Dict[str, Any], step: int) -> None: - for tracker in self.trackers: - tracker.log(metrics, step) - - def finish(self) -> None: - for tracker in self.trackers: - tracker.finish() - - -class Trackers(str, Enum): - r"""Enum for supported trackers.""" - - NONE = "none" - WANDB = "wandb" - - -_SUPPORTED_TRACKERS = [tracker.value for tracker in Trackers.__members__.values()] - - -def initialize_trackers( - trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str -) -> Union[BaseTracker, SequentialTracker]: - r"""Initialize loggers based on the provided configuration.""" - - logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}") - - if len(trackers) == 0: - return BaseTracker() - - if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in set(trackers)): - raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}") - - tracker_instances = [] - for tracker_name in set(trackers): - if tracker_name == Trackers.NONE: - tracker = BaseTracker() - elif tracker_name == Trackers.WANDB: - tracker = WandbTracker(experiment_name, log_dir, config) - tracker_instances.append(tracker) - - tracker = SequentialTracker(tracker_instances) - return tracker - - -TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker] diff --git a/finetrainers/trainer/__init__.py b/finetrainers/trainer/__init__.py deleted file mode 100644 index 6ba65c53fa75115084ceec6dabc2667a8a5d6a29..0000000000000000000000000000000000000000 --- a/finetrainers/trainer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sft_trainer.trainer import SFTTrainer diff --git a/finetrainers/trainer/config_utils.py b/finetrainers/trainer/config_utils.py deleted file mode 100644 index 5c354a30bac60265868c3ba8ef5313c12b7fc224..0000000000000000000000000000000000000000 --- a/finetrainers/trainer/config_utils.py +++ /dev/null @@ -1,17 +0,0 @@ -import argparse -from typing import TYPE_CHECKING - - -if TYPE_CHECKING: - from ..args import BaseArgs - - -class ConfigMixin: - def add_args(self, parser: argparse.ArgumentParser): - raise NotImplementedError("ConfigMixin::add_args should be implemented by subclasses.") - - def validate_args(self, args: "BaseArgs"): - raise NotImplementedError("ConfigMixin::map_args should be implemented by subclasses.") - - def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): - raise NotImplementedError("ConfigMixin::validate_args should be implemented by subclasses.") diff --git a/finetrainers/trainer/sft_trainer/config.py b/finetrainers/trainer/sft_trainer/config.py deleted file mode 100644 index 539426c7c9509c1b417d3471cc4391436f532cb7..0000000000000000000000000000000000000000 --- a/finetrainers/trainer/sft_trainer/config.py +++ /dev/null @@ -1,58 +0,0 @@ -import argparse -from typing import TYPE_CHECKING, List, Union - -from ..config_utils import ConfigMixin - - -if TYPE_CHECKING: - from ...args import BaseArgs - - -class SFTLowRankConfig(ConfigMixin): - r""" - Configuration 
class for SFT low rank training. - - Args: - rank (int): - Rank of the low rank approximation. - lora_alpha (int): - The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices. - target_modules (`str` or `List[str]`): - Target modules for the low rank approximation. Can be a regex string or a list of regex strings. - """ - - rank: int = 64 - lora_alpha: int = 64 - target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)" - - def add_args(self, parser: argparse.ArgumentParser): - parser.add_argument("--rank", type=int, default=64) - parser.add_argument("--lora_alpha", type=int, default=64) - parser.add_argument( - "--target_modules", - type=str, - nargs="+", - default=["(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"], - ) - - def validate_args(self, args: "BaseArgs"): - assert self.rank > 0, "Rank must be a positive integer." - assert self.lora_alpha > 0, "lora_alpha must be a positive integer." - - def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): - mapped_args.rank = argparse_args.rank - mapped_args.lora_alpha = argparse_args.lora_alpha - mapped_args.target_modules = ( - argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules - ) - - -class SFTFullRankConfig(ConfigMixin): - def add_args(self, parser: argparse.ArgumentParser): - pass - - def validate_args(self, args: "BaseArgs"): - pass - - def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): - pass diff --git a/finetrainers/trainer/sft_trainer/trainer.py b/finetrainers/trainer/sft_trainer/trainer.py deleted file mode 100644 index 28c94387022e95a8b7ab617e5201072b6c3fe837..0000000000000000000000000000000000000000 --- a/finetrainers/trainer/sft_trainer/trainer.py +++ /dev/null @@ -1,989 +0,0 @@ -import functools -import json -import math -import os -import time -from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union - -import datasets.distributed -import diffusers -import torch -import torch.backends -import transformers -import wandb -from diffusers import DiffusionPipeline -from diffusers.hooks import apply_layerwise_casting -from diffusers.training_utils import cast_training_params -from diffusers.utils import export_to_video -from huggingface_hub import create_repo, upload_folder -from peft import LoraConfig, get_peft_model_state_dict -from tqdm import tqdm - -from ... 
import data, logging, optimizer, parallel, patches, utils -from ...config import TrainingType -from ...state import State, TrainState - - -if TYPE_CHECKING: - from ...args import BaseArgs - from ...models import ModelSpecification - - -logger = logging.get_logger() - - -class SFTTrainer: - # fmt: off - _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"] - _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"] - _latent_component_names = ["vae"] - _diffusion_component_names = ["transformer", "unet", "scheduler"] - # fmt: on - - def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None: - self.args = args - self.state = State() - self.state.train_state = TrainState() - - # Tokenizers - self.tokenizer = None - self.tokenizer_2 = None - self.tokenizer_3 = None - - # Text encoders - self.text_encoder = None - self.text_encoder_2 = None - self.text_encoder_3 = None - - # Denoisers - self.transformer = None - self.unet = None - - # Autoencoders - self.vae = None - - # Scheduler - self.scheduler = None - - # Optimizer & LR scheduler - self.optimizer = None - self.lr_scheduler = None - - # Checkpoint manager - self.checkpointer = None - - self._init_distributed() - self._init_config_options() - - # Perform any patches that might be necessary for training to work as expected - patches.perform_patches_for_training(self.args, self.state.parallel_backend) - - self.model_specification = model_specification - self._are_condition_models_loaded = False - - def run(self) -> None: - try: - self._prepare_models() - self._prepare_trainable_parameters() - self._prepare_for_training() - self._prepare_dataset() - self._prepare_checkpointing() - self._train() - # trainer._evaluate() - except Exception as e: - logger.error(f"Error during training: {e}") - self.state.parallel_backend.destroy() - raise e - - def _prepare_models(self) -> None: - logger.info("Initializing models") - - diffusion_components = self.model_specification.load_diffusion_models() - self._set_components(diffusion_components) - - if self.state.parallel_backend.pipeline_parallel_enabled: - raise NotImplementedError( - "Pipeline parallelism is not supported yet. This will be supported in the future." - ) - - def _prepare_trainable_parameters(self) -> None: - logger.info("Initializing trainable parameters") - - parallel_backend = self.state.parallel_backend - - if self.args.training_type == TrainingType.FULL_FINETUNE: - logger.info("Finetuning transformer with no additional parameters") - utils.set_requires_grad([self.transformer], True) - else: - logger.info("Finetuning transformer with PEFT parameters") - utils.set_requires_grad([self.transformer], False) - - # Layerwise upcasting must be applied before adding the LoRA adapter. - # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on - # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly. 
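(A minimal sketch of what the layerwise-upcasting call below does, using the same apply_layerwise_casting entry point from diffusers.hooks; the toy module, dtypes, and skip pattern are illustrative only.)

import torch
from diffusers.hooks import apply_layerwise_casting


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)
        self.norm = torch.nn.LayerNorm(64)

    def forward(self, x):
        return self.norm(self.proj(x))


block = ToyBlock()
block.requires_grad_(False)  # mirrors the frozen transformer in the LoRA path
# Store weights in fp8 to save memory, but upcast to bf16 on the fly for computation.
# Normalization layers are skipped because naive fp8 storage degrades them the most.
apply_layerwise_casting(
    block,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["norm"],
    non_blocking=True,
)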
- if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules: - apply_layerwise_casting( - self.transformer, - storage_dtype=self.args.layerwise_upcasting_storage_dtype, - compute_dtype=self.args.transformer_dtype, - skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, - non_blocking=True, - ) - - transformer_lora_config = None - if self.args.training_type == TrainingType.LORA: - transformer_lora_config = LoraConfig( - r=self.args.rank, - lora_alpha=self.args.lora_alpha, - init_lora_weights=True, - target_modules=self.args.target_modules, - ) - self.transformer.add_adapter(transformer_lora_config) - - # # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32 - # # even if layerwise upcasting. Would be nice to have a test as well - # self.register_saving_loading_hooks(transformer_lora_config) - - # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all - # parameters to be of the same dtype. - if parallel_backend.data_sharding_enabled: - self.transformer.to(dtype=self.args.transformer_dtype) - else: - if self.args.training_type == TrainingType.LORA: - cast_training_params([self.transformer], dtype=torch.float32) - - def _prepare_for_training(self) -> None: - # 1. Apply parallelism - parallel_backend = self.state.parallel_backend - world_mesh = parallel_backend.get_mesh() - model_specification = self.model_specification - - if parallel_backend.context_parallel_enabled: - raise NotImplementedError( - "Context parallelism is not supported yet. This will be supported in the future." - ) - - if parallel_backend.tensor_parallel_enabled: - # TODO(aryan): handle fp8 from TorchAO here - model_specification.apply_tensor_parallel( - backend=parallel.ParallelBackendEnum.PTD, - device_mesh=parallel_backend.get_mesh()["tp"], - transformer=self.transformer, - ) - - # Enable gradient checkpointing - if self.args.gradient_checkpointing: - # TODO(aryan): support other checkpointing types - utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full") - - # Enable DDP, FSDP or HSDP - if parallel_backend.data_sharding_enabled: - # TODO(aryan): remove this when supported - if self.args.parallel_backend == "accelerate": - raise NotImplementedError("Data sharding is not supported with Accelerate yet.") - - if parallel_backend.data_replication_enabled: - logger.info("Applying HSDP to the model") - else: - logger.info("Applying FSDP to the model") - - # Apply FSDP or HSDP - if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled: - dp_mesh_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_names = ("dp_shard_cp",) - - parallel.apply_fsdp2_ptd( - model=self.transformer, - dp_mesh=world_mesh[dp_mesh_names], - param_dtype=self.args.transformer_dtype, - reduce_dtype=torch.float32, - output_dtype=None, - pp_enabled=parallel_backend.pipeline_parallel_enabled, - cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later - ) - elif parallel_backend.data_replication_enabled: - logger.info("Applying DDP to the model") - - if world_mesh.ndim > 1: - raise ValueError("DDP not supported for > 1D parallelism") - - parallel_backend.apply_ddp(self.transformer, world_mesh) - - self._move_components_to_device() - - # 2. Prepare optimizer and lr scheduler - # For training LoRAs, we can be a little more optimal. Currently, the OptimizerWrapper only accepts torch::nn::Module. 
- # This causes us to loop over all the parameters (even ones that don't require gradients, as in LoRA) at each optimizer - # step. This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99) - # but can be optimized a bit by maybe creating a simple wrapper module encompassing the actual parameters that require - # gradients. TODO(aryan): look into it in the future. - model_parts = [self.transformer] - self.state.num_trainable_parameters = sum( - p.numel() for m in model_parts for p in m.parameters() if p.requires_grad - ) - - # Setup distributed optimizer and lr scheduler - logger.info("Initializing optimizer and lr scheduler") - self.state.train_state = TrainState() - self.optimizer = optimizer.get_optimizer( - parallel_backend=self.args.parallel_backend, - name=self.args.optimizer, - model_parts=model_parts, - learning_rate=self.args.lr, - beta1=self.args.beta1, - beta2=self.args.beta2, - beta3=self.args.beta3, - epsilon=self.args.epsilon, - weight_decay=self.args.weight_decay, - fused=False, - ) - self.lr_scheduler = optimizer.get_lr_scheduler( - parallel_backend=self.args.parallel_backend, - name=self.args.lr_scheduler, - optimizer=self.optimizer, - num_warmup_steps=self.args.lr_warmup_steps, - num_training_steps=self.args.train_steps, - # TODO(aryan): handle last_epoch - ) - self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler) - - # 3. Initialize trackers, directories and repositories - self._init_logging() - self._init_trackers() - self._init_directories_and_repositories() - - def _prepare_dataset(self) -> None: - logger.info("Initializing dataset and dataloader") - - with open(self.args.dataset_config, "r") as file: - dataset_configs = json.load(file)["datasets"] - logger.info(f"Training configured to use {len(dataset_configs)} datasets") - - datasets = [] - for config in dataset_configs: - data_root = config.pop("data_root", None) - dataset_file = config.pop("dataset_file", None) - dataset_type = config.pop("dataset_type") - caption_options = config.pop("caption_options", {}) - - if data_root is not None and dataset_file is not None: - raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.") - - dataset_name_or_root = data_root or dataset_file - dataset = data.initialize_dataset( - dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options - ) - - if not dataset._precomputable_once and self.args.precomputation_once: - raise ValueError( - f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once." 
- ) - - logger.info(f"Initialized dataset: {dataset_name_or_root}") - dataset = self.state.parallel_backend.prepare_dataset(dataset) - dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config) - datasets.append(dataset) - - dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True) - dataloader = self.state.parallel_backend.prepare_dataloader( - dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory - ) - - self.dataset = dataset - self.dataloader = dataloader - - def _prepare_checkpointing(self) -> None: - parallel_backend = self.state.parallel_backend - - def save_model_hook(state_dict: Dict[str, Any]) -> None: - if parallel_backend.is_main_process: - if self.args.training_type == TrainingType.LORA: - state_dict = get_peft_model_state_dict(self.transformer, state_dict) - self.model_specification._save_lora_weights(self.args.output_dir, state_dict, self.scheduler) - elif self.args.training_type == TrainingType.FULL_FINETUNE: - self.model_specification._save_model( - self.args.output_dir, self.transformer, state_dict, self.scheduler - ) - parallel_backend.wait_for_everyone() - - enable_state_checkpointing = self.args.checkpointing_steps > 0 - self.checkpointer = utils.PTDCheckpointManager( - dataloader=self.dataloader, - model_parts=[self.transformer], - optimizers=self.optimizer, - schedulers=self.lr_scheduler, - states={"train_state": self.state.train_state}, - checkpointing_steps=self.args.checkpointing_steps, - checkpointing_limit=self.args.checkpointing_limit, - output_dir=self.args.output_dir, - enable=enable_state_checkpointing, - _callback_fn=save_model_hook, - ) - - resume_from_checkpoint = self.args.resume_from_checkpoint - if resume_from_checkpoint == "latest": - resume_from_checkpoint = -1 - if resume_from_checkpoint is not None: - self.checkpointer.load(resume_from_checkpoint) - - def _train(self) -> None: - logger.info("Starting training") - - parallel_backend = self.state.parallel_backend - train_state = self.state.train_state - device = parallel_backend.device - dtype = self.args.transformer_dtype - - memory_statistics = utils.get_memory_statistics() - logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") - - global_batch_size = self.args.batch_size * parallel_backend._dp_degree - info = { - "trainable parameters": self.state.num_trainable_parameters, - "train steps": self.args.train_steps, - "per-replica batch size": self.args.batch_size, - "global batch size": global_batch_size, - "gradient accumulation steps": self.args.gradient_accumulation_steps, - } - logger.info(f"Training configuration: {json.dumps(info, indent=4)}") - - progress_bar = tqdm( - range(0, self.args.train_steps), - initial=train_state.step, - desc="Training steps", - disable=not parallel_backend.is_local_main_process, - ) - - generator = torch.Generator(device=device) - if self.args.seed is not None: - generator = generator.manual_seed(self.args.seed) - self.state.generator = generator - - patch_size = 1 - if ( - getattr(self.transformer.config, "patch_size", None) is not None - and getattr(self.transformer.config, "patch_size_t", None) is not None - ): - patch_size = self.transformer.config.patch_size * self.transformer.config.patch_size_t - elif isinstance(getattr(self.transformer.config, "patch_size", None), int): - patch_size = self.transformer.config.patch_size - elif isinstance(getattr(self.transformer.config, "patch_size", None), (list, 
tuple)): - patch_size = math.prod(self.transformer.config.patch_size) - - scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler) - scheduler_sigmas = ( - scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None - ) - scheduler_alphas = utils.get_scheduler_alphas(self.scheduler) - scheduler_alphas = ( - scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None - ) - timesteps_buffer = [] - - self.transformer.train() - data_iterator = iter(self.dataloader) - - preprocessor = data.initialize_preprocessor( - rank=parallel_backend.rank, - num_items=self.args.precomputation_items if self.args.enable_precomputation else 1, - processor_fn={ - "condition": self.model_specification.prepare_conditions, - "latent": functools.partial( - self.model_specification.prepare_latents, compute_posterior=not self.args.precomputation_once - ), - }, - save_dir=self.args.precomputation_dir, - enable_precomputation=self.args.enable_precomputation, - ) - precomputed_condition_iterator: Iterable[Dict[str, Any]] = None - precomputed_latent_iterator: Iterable[Dict[str, Any]] = None - sampler = data.ResolutionSampler( - batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys - ) - requires_gradient_step = True - accumulated_loss = 0.0 - - while ( - train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples - ): - # 1. Load & preprocess data if required - if preprocessor.requires_data: - # TODO(aryan): We should do the following here: - # - Force checkpoint the trainable models, optimizers, schedulers and train state - # - Do the precomputation - # - Load the checkpointed models, optimizers, schedulers and train state back, and continue training - # This way we can be more memory efficient again, since the latest rewrite of precomputation removed - # this logic. - precomputed_condition_iterator, precomputed_latent_iterator = self._prepare_data( - preprocessor, data_iterator - ) - - # 2. Prepare batch - try: - condition_item = next(precomputed_condition_iterator) - latent_item = next(precomputed_latent_iterator) - sampler.consume(condition_item, latent_item) - except StopIteration: - if requires_gradient_step: - self.optimizer.step() - self.lr_scheduler.step() - requires_gradient_step = False - logger.info("Data exhausted. Exiting training loop.") - break - - if sampler.is_ready: - condition_batch, latent_batch = sampler.get_batch() - condition_model_conditions = self.model_specification.collate_conditions(condition_batch) - latent_model_conditions = self.model_specification.collate_latents(latent_batch) - else: - continue - - train_state.step += 1 - train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree - - lmc_latents = latent_model_conditions["latents"] - train_state.observed_num_tokens += math.prod(lmc_latents.shape[:-1]) // patch_size - - logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})") - - latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype) - condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype) - latent_model_conditions = utils.make_contiguous(latent_model_conditions) - condition_model_conditions = utils.make_contiguous(condition_model_conditions) - - # 3. 
Forward pass - sigmas = utils.prepare_sigmas( - scheduler=self.scheduler, - sigmas=scheduler_sigmas, - batch_size=self.args.batch_size, - num_train_timesteps=self.scheduler.config.num_train_timesteps, - flow_weighting_scheme=self.args.flow_weighting_scheme, - flow_logit_mean=self.args.flow_logit_mean, - flow_logit_std=self.args.flow_logit_std, - flow_mode_scale=self.args.flow_mode_scale, - device=device, - generator=self.state.generator, - ) - sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim) - - pred, target, sigmas = self.model_specification.forward( - transformer=self.transformer, - scheduler=self.scheduler, - condition_model_conditions=condition_model_conditions, - latent_model_conditions=latent_model_conditions, - sigmas=sigmas, - compute_posterior=not self.args.precomputation_once, - ) - - timesteps = (sigmas * 1000.0).long() - weights = utils.prepare_loss_weights( - scheduler=self.scheduler, - alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None, - sigmas=sigmas, - flow_weighting_scheme=self.args.flow_weighting_scheme, - ) - weights = utils.expand_tensor_dims(weights, pred.ndim) - - # 4. Compute loss & backward pass - loss = weights.float() * (pred.float() - target.float()).pow(2) - # Average loss across all but batch dimension - loss = loss.mean(list(range(1, loss.ndim))) - # Average loss across batch dimension - loss = loss.mean() - if self.args.gradient_accumulation_steps > 1: - loss = loss / self.args.gradient_accumulation_steps - loss.backward() - accumulated_loss += loss.detach().item() - requires_gradient_step = True - - # 5. Clip gradients - model_parts = [self.transformer] - grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases( - [p for m in model_parts for p in m.parameters()], - self.args.max_grad_norm, - foreach=True, - pp_mesh=parallel_backend.get_mesh("pp") if parallel_backend.pipeline_parallel_enabled else None, - ) - - # 6. Step optimizer & log metrics - logs = {} - - if train_state.step % self.args.gradient_accumulation_steps == 0: - # TODO(aryan): revisit no_sync() for FSDP - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - - if grad_norm is not None: - logs["grad_norm"] = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item() - if ( - parallel_backend.data_replication_enabled - or parallel_backend.data_sharding_enabled - or parallel_backend.context_parallel_enabled - ): - dp_cp_mesh = parallel_backend.get_mesh("dp_cp") - global_avg_loss, global_max_loss = ( - parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), - parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), - ) - else: - global_avg_loss = global_max_loss = accumulated_loss - - logs["global_avg_loss"] = global_avg_loss - logs["global_max_loss"] = global_max_loss - train_state.global_avg_losses.append(global_avg_loss) - train_state.global_max_losses.append(global_max_loss) - accumulated_loss = 0.0 - requires_gradient_step = False - - progress_bar.update(1) - progress_bar.set_postfix(logs) - - timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()]) - - if train_state.step % self.args.logging_steps == 0: - # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts - # TODO(aryan): causes NCCL hang for some reason. 
look into later - # logs.update(self.lr_scheduler.get_last_lr()) - - # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"]) - # logs["timesteps"] = wandb.plot.scatter( - # timesteps_table, "step", "timesteps", title="Timesteps distribution" - # ) - timesteps_buffer = [] - - logs["observed_data_samples"] = train_state.observed_data_samples - logs["observed_num_tokens"] = train_state.observed_num_tokens - - parallel_backend.log(logs, step=train_state.step) - train_state.log_steps.append(train_state.step) - - # 7. Save checkpoint if required - self.checkpointer.save( - step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process - ) - - # 8. Perform validation if required - if train_state.step % self.args.validation_steps == 0: - self._validate(step=train_state.step, final_validation=False) - - # 9. Final checkpoint, validation & cleanup - self.checkpointer.save( - train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process - ) - parallel_backend.wait_for_everyone() - self._validate(step=train_state.step, final_validation=True) - - self._delete_components() - memory_statistics = utils.get_memory_statistics() - logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") - - # 10. Upload artifacts to hub - if parallel_backend.is_main_process and self.args.push_to_hub: - upload_folder( - repo_id=self.state.repo_id, - folder_path=self.args.output_dir, - ignore_patterns=[f"{self.checkpointer._prefix}_*"], - ) - - parallel_backend.destroy() - - def _validate(self, step: int, final_validation: bool = False) -> None: - if self.args.validation_dataset_file is None: - return - - logger.info("Starting validation") - - # 1. Load validation dataset - parallel_backend = self.state.parallel_backend - dp_mesh = parallel_backend.get_mesh("dp_replicate") - - if dp_mesh is not None: - local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() - else: - local_rank, dp_world_size = 0, 1 - - dataset = data.ValidationDataset(self.args.validation_dataset_file) - dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, local_rank, dp_world_size) - validation_dataloader = data.DPDataLoader( - local_rank, - dataset, - batch_size=1, - num_workers=self.args.dataloader_num_workers, - collate_fn=lambda items: items, - ) - data_iterator = iter(validation_dataloader) - main_process_prompts_to_filenames = {} # Used to save model card - all_processes_artifacts = [] # Used to gather artifacts from all processes - - memory_statistics = utils.get_memory_statistics() - logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") - - seed = self.args.seed if self.args.seed is not None else 0 - generator = torch.Generator(device=parallel_backend.device).manual_seed(seed) - pipeline = self._init_pipeline(final_validation=final_validation) - - # 2. Run validation - # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we - # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset - # size is not divisible by dp_shards. 
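(A minimal sketch of the early check suggested in the TODO above; the helper name and message are illustrative and not part of the trainer.)

def check_validation_split(num_validation_items: int, dp_shards: int) -> None:
    # Fail fast at initialization instead of hanging mid-validation when FSDP shards
    # would receive an uneven number of validation samples.
    if dp_shards > 1 and num_validation_items % dp_shards != 0:
        raise ValueError(
            f"Validation dataset size ({num_validation_items}) must be divisible by "
            f"dp_shards ({dp_shards}); pad or trim the validation dataset."
        )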
- self.transformer.eval() - while True: - validation_data = next(data_iterator, None) - if validation_data is None: - break - - logger.debug( - f"Validating {validation_data=} on rank={parallel_backend.rank}.", local_main_process_only=False - ) - - validation_data = validation_data[0] - validation_artifacts = self.model_specification.validation( - pipeline=pipeline, generator=generator, **validation_data - ) - - PROMPT = validation_data["prompt"] - IMAGE = validation_data.get("image", None) - VIDEO = validation_data.get("video", None) - EXPORT_FPS = validation_data.get("export_fps", 30) - - # 2.1. If there are any initial images or videos, they will be logged to keep track of them as - # conditioning for generation. - prompt_filename = utils.string_to_filename(PROMPT)[:25] - artifacts = { - "input_image": data.ImageArtifact(value=IMAGE), - "input_video": data.VideoArtifact(value=VIDEO), - } - - # 2.2. Track the artifacts generated from validation - for i, validation_artifact in enumerate(validation_artifacts): - if validation_artifact.value is None: - continue - artifacts.update({f"artifact_{i}": validation_artifact}) - - # 2.3. Save the artifacts to the output directory and create appropriate logging objects - # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited. - for index, (key, artifact) in enumerate(list(artifacts.items())): - assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact)) - - time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension - filename = "validation-" if not final_validation else "final-" - filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}" - output_filename = os.path.join(self.args.output_dir, filename) - - if parallel_backend.is_main_process and artifact.file_extension == "mp4": - main_process_prompts_to_filenames[PROMPT] = filename - - if artifact.type == "image" and artifact.value is not None: - logger.debug( - f"Saving image from rank={parallel_backend.rank} to {output_filename}", - local_main_process_only=False, - ) - artifact.value.save(output_filename) - all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT)) - elif artifact.type == "video" and artifact.value is not None: - logger.debug( - f"Saving video from rank={parallel_backend.rank} to {output_filename}", - local_main_process_only=False, - ) - export_to_video(artifact.value, output_filename, fps=EXPORT_FPS) - all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT)) - - # 3. Cleanup & log artifacts - parallel_backend.wait_for_everyone() - - memory_statistics = utils.get_memory_statistics() - logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") - - # Remove all hooks that might have been added during pipeline initialization to the models - module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "vae"] - pipeline.remove_all_hooks() - del pipeline - self._delete_components(module_names) - torch.cuda.reset_peak_memory_stats(parallel_backend.device) - - # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts. 
- # TODO(aryan): probably should only all gather from dp mesh process group - all_artifacts = [None] * parallel_backend.world_size - torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts) - all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts] - - if parallel_backend.is_main_process: - tracker_key = "final" if final_validation else "validation" - artifact_log_dict = {} - - image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] - if len(image_artifacts) > 0: - artifact_log_dict["images"] = image_artifacts - video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] - if len(video_artifacts) > 0: - artifact_log_dict["videos"] = video_artifacts - parallel_backend.log({tracker_key: artifact_log_dict}, step=step) - - if self.args.push_to_hub and final_validation: - video_filenames = list(main_process_prompts_to_filenames.values()) - prompts = list(main_process_prompts_to_filenames.keys()) - utils.save_model_card( - args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts - ) - - parallel_backend.wait_for_everyone() - if not final_validation: - self._move_components_to_device() - self.transformer.train() - - def _evaluate(self) -> None: - raise NotImplementedError("Evaluation has not been implemented yet.") - - def _init_distributed(self) -> None: - # TODO: Accelerate disables native_amp for MPS. Probably need to do the same with implementation. - world_size = int(os.environ["WORLD_SIZE"]) - - # TODO(aryan): handle other backends - backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend) - self.state.parallel_backend = backend_cls( - world_size=world_size, - pp_degree=self.args.pp_degree, - dp_degree=self.args.dp_degree, - dp_shards=self.args.dp_shards, - cp_degree=self.args.cp_degree, - tp_degree=self.args.tp_degree, - backend="nccl", - timeout=self.args.init_timeout, - logging_dir=self.args.logging_dir, - output_dir=self.args.output_dir, - gradient_accumulation_steps=self.args.gradient_accumulation_steps, - ) - - if self.args.seed is not None: - world_mesh = self.state.parallel_backend.get_mesh() - utils.enable_determinism(self.args.seed, world_mesh) - - def _init_logging(self) -> None: - transformers_log_level = transformers.utils.logging.set_verbosity_error - diffusers_log_level = diffusers.utils.logging.set_verbosity_error - - if self.args.verbose == 0: - if self.state.parallel_backend.is_local_main_process: - transformers_log_level = transformers.utils.logging.set_verbosity_warning - diffusers_log_level = diffusers.utils.logging.set_verbosity_warning - elif self.args.verbose == 1: - if self.state.parallel_backend.is_local_main_process: - transformers_log_level = transformers.utils.logging.set_verbosity_info - diffusers_log_level = diffusers.utils.logging.set_verbosity_info - elif self.args.verbose == 2: - if self.state.parallel_backend.is_local_main_process: - transformers_log_level = transformers.utils.logging.set_verbosity_debug - diffusers_log_level = diffusers.utils.logging.set_verbosity_debug - else: - transformers_log_level = transformers.utils.logging.set_verbosity_debug - diffusers_log_level = diffusers.utils.logging.set_verbosity_debug - - transformers_log_level() - diffusers_log_level() - - logging._set_parallel_backend(self.state.parallel_backend) - logger.info("Initialized FineTrainers") - - def _init_trackers(self) -> None: - # TODO(aryan): handle multiple trackers 
- trackers = [self.args.report_to] - experiment_name = self.args.tracker_name or "finetrainers-experiment" - self.state.parallel_backend.initialize_trackers( - trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir - ) - - def _init_directories_and_repositories(self) -> None: - if self.state.parallel_backend.is_main_process: - self.args.output_dir = Path(self.args.output_dir) - self.args.output_dir.mkdir(parents=True, exist_ok=True) - self.state.output_dir = Path(self.args.output_dir) - - if self.args.push_to_hub: - repo_id = self.args.hub_model_id or Path(self.args.output_dir).name - self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id - - def _init_config_options(self) -> None: - # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if self.args.allow_tf32 and torch.cuda.is_available(): - torch.backends.cuda.matmul.allow_tf32 = True - - def _move_components_to_device( - self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None - ) -> None: - if device is None: - device = self.state.parallel_backend.device - if components is None: - components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae] - components = utils.get_non_null_items(components) - components = list(filter(lambda x: hasattr(x, "to"), components)) - for component in components: - component.to(device) - - def _set_components(self, components: Dict[str, Any]) -> None: - for component_name in self._all_component_names: - existing_component = getattr(self, component_name, None) - new_component = components.get(component_name, existing_component) - setattr(self, component_name, new_component) - - def _delete_components(self, component_names: Optional[List[str]] = None) -> None: - if component_names is None: - component_names = self._all_component_names - for component_name in component_names: - setattr(self, component_name, None) - utils.free_memory() - utils.synchronize_device() - - def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline: - module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"] - - if not final_validation: - module_names.remove("transformer") - pipeline = self.model_specification.load_pipeline( - tokenizer=self.tokenizer, - tokenizer_2=self.tokenizer_2, - tokenizer_3=self.tokenizer_3, - text_encoder=self.text_encoder, - text_encoder_2=self.text_encoder_2, - text_encoder_3=self.text_encoder_3, - # TODO(aryan): handle unwrapping for compiled modules - # transformer=utils.unwrap_model(accelerator, self.transformer), - transformer=self.transformer, - vae=self.vae, - enable_slicing=self.args.enable_slicing, - enable_tiling=self.args.enable_tiling, - enable_model_cpu_offload=self.args.enable_model_cpu_offload, - training=True, - ) - else: - self._delete_components() - - # Load the transformer weights from the final checkpoint if performing full-finetune - transformer = None - if self.args.training_type == TrainingType.FULL_FINETUNE: - transformer = self.model_specification.load_diffusion_models()["transformer"] - - pipeline = self.model_specification.load_pipeline( - transformer=transformer, - enable_slicing=self.args.enable_slicing, - enable_tiling=self.args.enable_tiling, - enable_model_cpu_offload=self.args.enable_model_cpu_offload, - training=False, - ) - - # Load the LoRA weights if 
performing LoRA finetuning - if self.args.training_type == TrainingType.LORA: - pipeline.load_lora_weights(self.args.output_dir) - - components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names} - self._set_components(components) - if not self.args.enable_model_cpu_offload: - self._move_components_to_device(list(components.values())) - return pipeline - - def _prepare_data( - self, - preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor], - data_iterator, - ): - if not self.args.enable_precomputation: - if not self._are_condition_models_loaded: - logger.info( - "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs." - ) - condition_components = self.model_specification.load_condition_models() - latent_components = self.model_specification.load_latent_models() - all_components = {**condition_components, **latent_components} - self._set_components(all_components) - self._move_components_to_device(list(all_components.values())) - utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling) - else: - condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))} - latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))} - - condition_iterator = preprocessor.consume( - "condition", - components=condition_components, - data_iterator=data_iterator, - generator=self.state.generator, - cache_samples=True, - ) - latent_iterator = preprocessor.consume( - "latent", - components=latent_components, - data_iterator=data_iterator, - generator=self.state.generator, - use_cached_samples=True, - drop_samples=True, - ) - - self._are_condition_models_loaded = True - else: - logger.info("Precomputed condition & latent data exhausted. 
Loading & preprocessing new data.") - - parallel_backend = self.state.parallel_backend - if parallel_backend.world_size == 1: - self._move_components_to_device([self.transformer], "cpu") - utils.free_memory() - utils.synchronize_device() - torch.cuda.reset_peak_memory_stats(parallel_backend.device) - - if self.args.precomputation_once: - consume_fn = preprocessor.consume_once - else: - consume_fn = preprocessor.consume - - # Prepare condition iterators - condition_components = self.model_specification.load_condition_models() - component_names = list(condition_components.keys()) - component_modules = list(condition_components.values()) - self._set_components(condition_components) - self._move_components_to_device(component_modules) - condition_iterator = consume_fn( - "condition", - components=condition_components, - data_iterator=data_iterator, - generator=self.state.generator, - cache_samples=True, - ) - self._delete_components(component_names) - del condition_components, component_names, component_modules - - # Prepare latent iterators - latent_components = self.model_specification.load_latent_models() - utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling) - component_names = list(latent_components.keys()) - component_modules = list(latent_components.values()) - self._set_components(latent_components) - self._move_components_to_device(component_modules) - latent_iterator = consume_fn( - "latent", - components=latent_components, - data_iterator=data_iterator, - generator=self.state.generator, - use_cached_samples=True, - drop_samples=True, - ) - self._delete_components(component_names) - del latent_components, component_names, component_modules - - if parallel_backend.world_size == 1: - self._move_components_to_device([self.transformer]) - - return condition_iterator, latent_iterator - - def _get_training_info(self) -> Dict[str, Any]: - info = self.args.to_dict() - - # Removing flow matching arguments when not using flow-matching objective - diffusion_args = info.get("diffusion_arguments", {}) - scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else "" - if scheduler_name != "FlowMatchEulerDiscreteScheduler": - filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k} - else: - filtered_diffusion_args = diffusion_args - - info.update({"diffusion_arguments": filtered_diffusion_args}) - return info diff --git a/finetrainers/typing.py b/finetrainers/typing.py deleted file mode 100644 index b7b3b339f252d8f47ef0ff67aa6c6733a2ccd7cf..0000000000000000000000000000000000000000 --- a/finetrainers/typing.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Union - -from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler -from transformers import CLIPTokenizer, LlamaTokenizer, LlamaTokenizerFast, T5Tokenizer, T5TokenizerFast - -from .data import ImageArtifact, VideoArtifact - - -ArtifactType = Union[ImageArtifact, VideoArtifact] -SchedulerType = Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler] -TokenizerType = Union[CLIPTokenizer, T5Tokenizer, T5TokenizerFast, LlamaTokenizer, LlamaTokenizerFast] diff --git a/finetrainers/utils/__init__.py b/finetrainers/utils/__init__.py deleted file mode 100644 index 6a4d84ec7ee75f5f4e760788d2fe8855977fc036..0000000000000000000000000000000000000000 --- a/finetrainers/utils/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -import inspect -from typing import Any, Dict, List, Optional, Set, Tuple, Union - -from .activation_checkpoint 
import apply_activation_checkpointing -from .data import determine_batch_size, should_perform_precomputation -from .diffusion import ( - _enable_vae_memory_optimizations, - default_flow_shift, - get_scheduler_alphas, - get_scheduler_sigmas, - prepare_loss_weights, - prepare_sigmas, - prepare_target, - resolution_dependent_timestep_flow_shift, -) -from .file import delete_files, find_files, string_to_filename -from .hub import save_model_card -from .memory import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous -from .model import resolve_component_cls -from .state_checkpoint import PTDCheckpointManager -from .torch import ( - align_device_and_dtype, - clip_grad_norm_, - enable_determinism, - expand_tensor_dims, - get_device_info, - set_requires_grad, - synchronize_device, - unwrap_model, -) - - -def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]: - if method_name is not None: - obj = getattr(obj, method_name) - return {name for name, _ in inspect.signature(obj).parameters.items()} - - -def get_non_null_items( - x: Union[List[Any], Tuple[Any], Dict[str, Any]] -) -> Union[List[Any], Tuple[Any], Dict[str, Any]]: - if isinstance(x, dict): - return {k: v for k, v in x.items() if v is not None} - if isinstance(x, (list, tuple)): - return type(x)(v for v in x if v is not None) diff --git a/finetrainers/utils/_common.py b/finetrainers/utils/_common.py deleted file mode 100644 index c230e878d6fe715d696d8285c51f0ba073fd6b3e..0000000000000000000000000000000000000000 --- a/finetrainers/utils/_common.py +++ /dev/null @@ -1,6 +0,0 @@ -DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [ - "transformer_blocks", - "single_transformer_blocks", - "temporal_transformer_blocks", - "blocks", -] diff --git a/finetrainers/utils/activation_checkpoint.py b/finetrainers/utils/activation_checkpoint.py deleted file mode 100644 index cc4193a6cc027a771fe1fc2c3cb34595fbc336b2..0000000000000000000000000000000000000000 --- a/finetrainers/utils/activation_checkpoint.py +++ /dev/null @@ -1,71 +0,0 @@ -import collections -from enum import Enum - -import torch -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper - -from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES - - -class CheckpointType(str, Enum): - FULL = "full" - OPS = "ops" - BLOCK_SKIP = "block_skip" - - -_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, -} - - -def apply_activation_checkpointing( - module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1 -) -> torch.nn.Module: - if checkpointing_type == CheckpointType.FULL: - module = _apply_activation_checkpointing_blocks(module) - elif checkpointing_type == CheckpointType.OPS: - module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS) - elif checkpointing_type == CheckpointType.BLOCK_SKIP: - module = _apply_activation_checkpointing_blocks(module, n_layer) - else: - raise ValueError( - f"Checkpointing type '{checkpointing_type}' not supported. 
Supported types are {CheckpointType.__members__.keys()}" - ) - return module - - -def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: int = None) -> torch.nn.Module: - for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES: - blocks: torch.nn.Module = getattr(module, transformer_block_name, None) - if blocks is None: - continue - for index, (layer_id, block) in enumerate(blocks.named_children()): - if n_layer is None or index % n_layer == 0: - block = checkpoint_wrapper(block, preserve_rng_state=False) - blocks.register_module(layer_id, block) - return module - - -def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module: - from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts - - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0) - return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE - - return _custom_policy - - def selective_checkpointing_context_fn(): - meta = collections.defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) - - return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False) diff --git a/finetrainers/utils/data.py b/finetrainers/utils/data.py deleted file mode 100644 index ae3fcc35262f7ec85d015c468983b033d61a154c..0000000000000000000000000000000000000000 --- a/finetrainers/utils/data.py +++ /dev/null @@ -1,51 +0,0 @@ -from pathlib import Path -from typing import Any, Union - -import torch -from accelerate.logging import get_logger - -from ..constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME - - -logger = get_logger("finetrainers") - - -def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool: - if isinstance(precomputation_dir, str): - precomputation_dir = Path(precomputation_dir) - conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME - latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME - if conditions_dir.exists() and latents_dir.exists(): - num_files_conditions = len(list(conditions_dir.glob("*.pt"))) - num_files_latents = len(list(latents_dir.glob("*.pt"))) - if num_files_conditions != num_files_latents: - logger.warning( - f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents})." - f"Cleaning up precomputed directories and re-running precomputation." - ) - # clean up precomputed directories - for file in conditions_dir.glob("*.pt"): - file.unlink() - for file in latents_dir.glob("*.pt"): - file.unlink() - return True - if num_files_conditions > 0: - logger.info(f"Found {num_files_conditions} precomputed conditions and latents.") - return False - logger.info("Precomputed data not found. 
Running precomputation.") - return True - - -def determine_batch_size(x: Any) -> int: - if isinstance(x, list): - return len(x) - if isinstance(x, torch.Tensor): - return x.size(0) - if isinstance(x, dict): - for key in x: - try: - return determine_batch_size(x[key]) - except ValueError: - pass - return 1 - raise ValueError("Could not determine batch size from input.") diff --git a/finetrainers/utils/diffusion.py b/finetrainers/utils/diffusion.py deleted file mode 100644 index 9ed3746c160b7aa1ea96fb382ccbece85db6ae42..0000000000000000000000000000000000000000 --- a/finetrainers/utils/diffusion.py +++ /dev/null @@ -1,152 +0,0 @@ -import math -from typing import Optional, Union - -import torch -from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler -from diffusers.training_utils import compute_loss_weighting_for_sd3 - - -# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47 -def resolution_dependent_timestep_flow_shift( - latents: torch.Tensor, - sigmas: torch.Tensor, - base_image_seq_len: int = 256, - max_image_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.15, -) -> torch.Tensor: - image_or_video_sequence_length = 0 - if latents.ndim == 4: - image_or_video_sequence_length = latents.shape[2] * latents.shape[3] - elif latents.ndim == 5: - image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4] - else: - raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor") - - m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len) - b = base_shift - m * base_image_seq_len - mu = m * image_or_video_sequence_length + b - sigmas = default_flow_shift(latents, sigmas, shift=mu) - return sigmas - - -def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor: - sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) - return sigmas - - -def compute_density_for_timestep_sampling( - weighting_scheme: str, - batch_size: int, - logit_mean: float = None, - logit_std: float = None, - mode_scale: float = None, - device: torch.device = torch.device("cpu"), - generator: Optional[torch.Generator] = None, -) -> torch.Tensor: - r""" - Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). 
- u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator) - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device=device, generator=generator) - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device=device, generator=generator) - return u - - -def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor: - if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): - return None - elif isinstance(scheduler, CogVideoXDDIMScheduler): - return scheduler.alphas_cumprod.clone() - else: - raise ValueError(f"Unsupported scheduler type {type(scheduler)}") - - -def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor: - if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): - return scheduler.sigmas.clone() - elif isinstance(scheduler, CogVideoXDDIMScheduler): - return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps) - else: - raise ValueError(f"Unsupported scheduler type {type(scheduler)}") - - -def prepare_sigmas( - scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], - sigmas: torch.Tensor, - batch_size: int, - num_train_timesteps: int, - flow_weighting_scheme: str = "none", - flow_logit_mean: float = 0.0, - flow_logit_std: float = 1.0, - flow_mode_scale: float = 1.29, - device: torch.device = torch.device("cpu"), - generator: Optional[torch.Generator] = None, -) -> torch.Tensor: - if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): - weights = compute_density_for_timestep_sampling( - weighting_scheme=flow_weighting_scheme, - batch_size=batch_size, - logit_mean=flow_logit_mean, - logit_std=flow_logit_std, - mode_scale=flow_mode_scale, - device=device, - generator=generator, - ) - indices = (weights * num_train_timesteps).long() - elif isinstance(scheduler, CogVideoXDDIMScheduler): - # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes. - weights = torch.rand(size=(batch_size,), device=device, generator=generator) - indices = (weights * num_train_timesteps).long() - else: - raise ValueError(f"Unsupported scheduler type {type(scheduler)}") - - return sigmas[indices] - - -def prepare_loss_weights( - scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], - alphas: Optional[torch.Tensor] = None, - sigmas: Optional[torch.Tensor] = None, - flow_weighting_scheme: str = "none", -) -> torch.Tensor: - if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): - return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme) - elif isinstance(scheduler, CogVideoXDDIMScheduler): - # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas). - # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results. 
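As a quick numeric illustration of the two weightings contrasted in the comment above (example values only, not taken from any scheduler configuration):

```py
# Compare the standard SNR weighting with the 1 / (1 - alphas) variant returned below.
import torch

alphas_cumprod = torch.tensor([0.9, 0.5, 0.1])      # example values of alpha_bar at three timesteps
snr_weight = alphas_cumprod / (1 - alphas_cumprod)  # -> [9.000, 1.000, 0.111]
cogvideox_weight = 1 / (1 - alphas_cumprod)         # -> [10.000, 2.000, 1.111]
```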
- return 1 / (1 - alphas) - else: - raise ValueError(f"Unsupported scheduler type {type(scheduler)}") - - -def prepare_target( - scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], - noise: torch.Tensor, - latents: torch.Tensor, -) -> torch.Tensor: - if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): - target = noise - latents - elif isinstance(scheduler, CogVideoXDDIMScheduler): - target = latents - else: - raise ValueError(f"Unsupported scheduler type {type(scheduler)}") - - return target - - -def _enable_vae_memory_optimizations(vae, enable_slicing: bool = False, enable_tiling: bool = False): - if hasattr(vae, "enable_slicing") and enable_slicing: - vae.enable_slicing() - if hasattr(vae, "enable_tiling") and enable_tiling: - vae.enable_tiling() diff --git a/finetrainers/utils/file.py b/finetrainers/utils/file.py deleted file mode 100644 index ba01213e758685aaa339f3ecad12c312a540dd9e..0000000000000000000000000000000000000000 --- a/finetrainers/utils/file.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -import shutil -from pathlib import Path -from typing import List, Union - -from ..logging import get_logger - - -logger = get_logger() - - -def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]: - if not isinstance(dir, Path): - dir = Path(dir) - if not dir.exists(): - return [] - checkpoints = os.listdir(dir.as_posix()) - checkpoints = [c for c in checkpoints if c.startswith(prefix)] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - return checkpoints - - -def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None: - if not isinstance(dirs, list): - dirs = [dirs] - dirs = [Path(d) if isinstance(d, str) else d for d in dirs] - logger.debug(f"Deleting files: {dirs}") - for dir in dirs: - if not dir.exists(): - continue - shutil.rmtree(dir, ignore_errors=True) - - -def string_to_filename(s: str) -> str: - return ( - s.replace(" ", "-") - .replace("/", "-") - .replace(":", "-") - .replace(".", "-") - .replace(",", "-") - .replace(";", "-") - .replace("!", "-") - .replace("?", "-") - ) diff --git a/finetrainers/utils/hub.py b/finetrainers/utils/hub.py deleted file mode 100644 index ea1a16eb42cbb1f2848376440817a3e1680ce61c..0000000000000000000000000000000000000000 --- a/finetrainers/utils/hub.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -from typing import List, Union - -import numpy as np -import wandb -from diffusers.utils import export_to_video -from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card -from PIL import Image - - -def save_model_card( - args, - repo_id: str, - videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]], - validation_prompts: List[str], - fps: int = 30, -) -> None: - widget_dict = [] - output_dir = str(args.output_dir) - if videos is not None and len(videos) > 0: - for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)): - if not isinstance(video, str): - export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps) - widget_dict.append( - { - "text": validation_prompt if validation_prompt else " ", - "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"}, - } - ) - - model_description = f""" -# LoRA Finetune - - - -## Model description - -This is a lora finetune of model: `{args.pretrained_model_name_or_path}`. - -The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers). 
- -## Download model - -[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. - -## Usage - -Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. - -```py -TODO -``` - -For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. -""" - if wandb.run.url: - model_description += f""" -Find out the wandb run URL and training configurations [here]({wandb.run.url}). -""" - - model_card = load_or_create_model_card( - repo_id_or_path=repo_id, - from_training=True, - base_model=args.pretrained_model_name_or_path, - model_description=model_description, - widget=widget_dict, - ) - tags = [ - "text-to-video", - "diffusers-training", - "diffusers", - "lora", - "template:sd-lora", - ] - - model_card = populate_model_card(model_card, tags=tags) - model_card.save(os.path.join(args.output_dir, "README.md")) diff --git a/finetrainers/utils/import_utils.py b/finetrainers/utils/import_utils.py deleted file mode 100644 index 56c19db6e4f032296bca103e38155da1ded1c074..0000000000000000000000000000000000000000 --- a/finetrainers/utils/import_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import importlib - -import importlib_metadata - -from ..logging import get_logger - - -logger = get_logger() - - -_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None -try: - _bitsandbytes_version = importlib_metadata.version("bitsandbytes") - logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") -except importlib_metadata.PackageNotFoundError: - _bitsandbytes_available = False - - -def is_bitsandbytes_available(): - return _bitsandbytes_available diff --git a/finetrainers/utils/memory.py b/finetrainers/utils/memory.py deleted file mode 100644 index d7616b190ebe474484e4aaf438b9a80eabf6ab66..0000000000000000000000000000000000000000 --- a/finetrainers/utils/memory.py +++ /dev/null @@ -1,58 +0,0 @@ -import gc -from typing import Any, Dict, Union - -import torch -from accelerate.logging import get_logger - - -logger = get_logger("finetrainers") - - -def get_memory_statistics(precision: int = 3) -> Dict[str, Any]: - memory_allocated = None - memory_reserved = None - max_memory_allocated = None - max_memory_reserved = None - - if torch.cuda.is_available(): - device = torch.cuda.current_device() - memory_allocated = torch.cuda.memory_allocated(device) - memory_reserved = torch.cuda.memory_reserved(device) - max_memory_allocated = torch.cuda.max_memory_allocated(device) - max_memory_reserved = torch.cuda.max_memory_reserved(device) - - elif torch.backends.mps.is_available(): - memory_allocated = torch.mps.current_allocated_memory() - - else: - logger.warning("No CUDA, MPS, or ROCm device found. 
Memory statistics are not available.") - - return { - "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), - "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), - "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), - "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), - } - - -def bytes_to_gigabytes(x: int) -> float: - if x is not None: - return x / 1024**3 - - -def free_memory() -> None: - if torch.cuda.is_available(): - gc.collect() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - - # TODO(aryan): handle non-cuda devices - - -def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if isinstance(x, torch.Tensor): - return x.contiguous() - elif isinstance(x, dict): - return {k: make_contiguous(v) for k, v in x.items()} - else: - return x diff --git a/finetrainers/utils/model.py b/finetrainers/utils/model.py deleted file mode 100644 index 4427f97d25ed44b2d9832cf456b082f65d66c2a8..0000000000000000000000000000000000000000 --- a/finetrainers/utils/model.py +++ /dev/null @@ -1,32 +0,0 @@ -import importlib -import json -import os -from typing import Optional - -from huggingface_hub import hf_hub_download - - -def resolve_component_cls( - pretrained_model_name_or_path: str, - component_name: str, - filename: str = "model_index.json", - revision: Optional[str] = None, - cache_dir: Optional[str] = None, -): - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.exists(str(pretrained_model_name_or_path)) and os.path.isdir(pretrained_model_name_or_path): - index_path = os.path.join(pretrained_model_name_or_path, filename) - else: - index_path = hf_hub_download( - repo_id=pretrained_model_name_or_path, filename=filename, revision=revision, cache_dir=cache_dir - ) - - with open(index_path, "r") as f: - model_index_dict = json.load(f) - - if component_name not in model_index_dict: - raise ValueError(f"No {component_name} found in the model index dict.") - - cls_config = model_index_dict[component_name] - library = importlib.import_module(cls_config[0]) - return getattr(library, cls_config[1]) diff --git a/finetrainers/utils/state_checkpoint.py b/finetrainers/utils/state_checkpoint.py deleted file mode 100644 index ab0e0b9b5f6214ba56d0c308714e29b9e11f4d8a..0000000000000000000000000000000000000000 --- a/finetrainers/utils/state_checkpoint.py +++ /dev/null @@ -1,203 +0,0 @@ -import functools -import pathlib -import shutil -import time -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union - -import torch -import torch.distributed.checkpoint -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - set_model_state_dict, -) -from torch.distributed.checkpoint.stateful import Stateful - -from ..logging import get_logger - - -if TYPE_CHECKING: - from .. 
import optimizer - - -logger = get_logger() - - -class ModelWrapper(Stateful): - def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None: - self.model = [model] if isinstance(model, torch.nn.Module) else model - - def state_dict(self) -> Dict[str, Any]: - return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - func = functools.partial( - set_model_state_dict, - model_state_dict=state_dict, - options=StateDictOptions(strict=False), - ) - list(map(func, self.model)) - - -class PTDCheckpointManager: - def __init__( - self, - dataloader: torch.utils.data.DataLoader, - model_parts: List[torch.nn.Module], - optimizers: "optimizer.OptimizerWrapper", - schedulers: "optimizer.SchedulerWrapper", - states: Dict[str, Any], - checkpointing_steps: int, - checkpointing_limit: int, - output_dir: str, - enable: bool = True, - _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, - _prefix: str = "finetrainers_step", - ) -> None: - self.states = states - self.states.update( - { - "model": ModelWrapper(model_parts), - "optimizer": optimizers, - "dataloader": dataloader, - } - ) - self.states.update(schedulers.get_lr_scheduler_state()) - - self.checkpointing_steps = checkpointing_steps - self.checkpointing_limit = checkpointing_limit - self.output_dir = pathlib.Path(output_dir) - self.enable = enable - self._callback_fn = _callback_fn - self._prefix = _prefix - - logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'") - - def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str: - if not self._should_checkpoint(step, force): - return None - - checkpoint_dir = self._get_checkpoint_dir(step) - begin_time = time.monotonic() - torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix()) - end_time = time.monotonic() - logger.info( - f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. 
Directory: {checkpoint_dir}" - ) - self._purge_stale_checkpoints() - - state_dicts = [ - gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process) - for model in self.states["model"].model - ] - if self._callback_fn is not None: - list(map(self._callback_fn, state_dicts)) - - return checkpoint_dir.as_posix() - - def load(self, step: int = -1) -> bool: - if not self.enable: - return False - if not self.output_dir.exists(): - return False - if step != -1 and not self._get_checkpoint_dir(step).exists(): - return False - - if step == -1: - latest_checkpoint_dir = self._find_latest_checkpoint_dir() - if latest_checkpoint_dir is None: - return False - step = int(latest_checkpoint_dir.name.split("_")[-1]) - - checkpoint_dir = self._get_checkpoint_dir(step) - logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}") - - # For step 0, optimizers/schedulers are not available as they are created during training after first step - states = {"model": self.states["model"]} if step == 0 else self.states - - # See bug: https://github.com/pytorch/pytorch/pull/138575 - original_stateful_states = {k: v for k, v in states.items() if isinstance(v, Stateful)} - begin_time = time.monotonic() - torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix()) - end_time = time.monotonic() - logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.") - - # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load() - states.update(original_stateful_states) - - return True - - def _should_checkpoint(self, step: int, force: bool) -> bool: - if not self.enable: - return False - - if not force: - if step % self.checkpointing_steps != 0: - return False - - return True - - def _get_checkpoint_dir(self, step: int) -> pathlib.Path: - return self.output_dir / f"{self._prefix}_{step}" - - def _find_latest_checkpoint_dir(self) -> Union[pathlib.Path, None]: - checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1])) - return checkpoints[-1] if len(checkpoints) > 0 else None - - def _purge_stale_checkpoints(self) -> None: - if self.checkpointing_limit is None or self.checkpointing_limit <= 0: - return - checkpoints = sorted( - self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True - ) - for checkpoint in checkpoints[self.checkpointing_limit :]: - logger.info(f"Deleting stale checkpoint: {checkpoint}") - shutil.rmtree(checkpoint, ignore_errors=True) - - -def gather_state_dict_on_cpu_rank0( - model, device: Optional[torch.device] = None, *, is_main_process: bool -) -> Dict[str, Any]: - cpu_state_dict = {} - sharded_sd = model.state_dict() - for param_name, param in sharded_sd.items(): - if param.is_cpu: - # Move back to device if offloaded to CPU - param = param.to(device) - if hasattr(param, "_local_tensor"): - # Gather DTensor - param = param.full_tensor() - if is_main_process: - cpu_state_dict[param_name] = param.cpu() - torch.distributed.barrier() - return cpu_state_dict - - -# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict -# def dcp_to_torch_save( -# dcp_checkpoint_dir: Union[str, os.PathLike], -# torch_save_path: Union[str, os.PathLike], -# callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, -# ): -# """ -# Given a directory containing a DCP checkpoint, this function will convert it into a -# Torch save file. 
- -# Args: -# dcp_checkpoint_dir: Directory containing the DCP checkpoint. -# torch_save_path: Filename to store the converted Torch save file. -# callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict. - -# .. warning:: -# To avoid OOM, it's recommended to only run this function on a single rank. -# """ -# state_dict = {} -# _load_state_dict( -# state_dict, -# storage_reader=FileSystemReader(dcp_checkpoint_dir), -# planner=_EmptyStateDictLoadPlanner(), -# no_dist=True, -# ) -# if callback_fn is not None: -# state_dict = callback_fn(state_dict) -# torch.save(state_dict, torch_save_path) diff --git a/finetrainers/utils/torch.py b/finetrainers/utils/torch.py deleted file mode 100644 index db434d3e361ea03e6c06811c98f08594c4ab773d..0000000000000000000000000000000000000000 --- a/finetrainers/utils/torch.py +++ /dev/null @@ -1,338 +0,0 @@ -import math -import os -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.backends -import torch.distributed as dist -import torch.distributed.tensor -from accelerate import Accelerator -from diffusers.utils.torch_utils import is_compiled_module - -from ..logging import get_logger - - -logger = get_logger() - -_STRING_TO_DTYPE = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, -} - -_DTYPE_TO_STRING = {v: k for k, v in _STRING_TO_DTYPE.items()} - -_HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False - - -def align_device_and_dtype( - x: Union[torch.Tensor, Dict[str, torch.Tensor]], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - if isinstance(x, torch.Tensor): - if device is not None: - x = x.to(device) - if dtype is not None: - x = x.to(dtype) - elif isinstance(x, dict): - if device is not None: - x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} - if dtype is not None: - x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} - return x - - -def _clip_grad_norm_while_handling_failing_dtensor_cases( - parameters: Union[torch.Tensor, List[torch.Tensor]], - max_norm: float, - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: Optional[bool] = None, - pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, -) -> Optional[torch.Tensor]: - global _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES - - if not _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES: - try: - return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach, pp_mesh) - except NotImplementedError as e: - if "DTensor does not support cross-mesh operation" in str(e): - # https://github.com/pytorch/pytorch/issues/134212 - logger.warning( - "DTensor does not support cross-mesh operation. If you haven't fully tensor-parallelized your " - "model, while combining other parallelisms such as FSDP, it could be the reason for this error. " - "Gradient clipping will be skipped and gradient norm will not be logged." - ) - except Exception as e: - logger.warning( - f"An error occurred while clipping gradients: {e}. Gradient clipping will be skipped and gradient " - f"norm will not be logged." 
- ) - _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = True - return None - - -# Copied from https://github.com/pytorch/torchtitan/blob/4a169701555ab9bd6ca3769f9650ae3386b84c6e/torchtitan/utils.py#L362 -@torch.no_grad() -def clip_grad_norm_( - parameters: Union[torch.Tensor, List[torch.Tensor]], - max_norm: float, - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: Optional[bool] = None, - pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, -) -> torch.Tensor: - r""" - Clip the gradient norm of parameters. - - Gradient norm clipping requires computing the gradient norm over the entire model. - `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions. - We need to manually reduce the gradient norm across PP stages. - See https://github.com/pytorch/torchtitan/issues/596 for details. - - Args: - parameters (`torch.Tensor` or `List[torch.Tensor]`): - Tensors that will have gradients normalized. - max_norm (`float`): - Maximum norm of the gradients after clipping. - norm_type (`float`, defaults to `2.0`): - Type of p-norm to use. Can be `inf` for infinity norm. - error_if_nonfinite (`bool`, defaults to `False`): - If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`. - foreach (`bool`, defaults to `None`): - Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors - and silently fall back to the slow implementation for other device types. - pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`): - Pipeline parallel device mesh. If not `None`, will reduce gradient norm across PP stages. - - Returns: - `torch.Tensor`: - Total norm of the gradients - """ - grads = [p.grad for p in parameters if p.grad is not None] - - # TODO(aryan): Wait for next Pytorch release to use `torch.nn.utils.get_total_norm` - # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach) - total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) - - # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`. - # We can simply reduce the DTensor to get the total norm in this tensor's process group - # and then convert it to a local tensor. - # It has two purposes: - # 1. to make sure the total norm is computed correctly when PP is used (see below) - # 2. to return a reduced total_norm tensor whose .item() would return the correct value - if isinstance(total_norm, torch.distributed.tensor.DTensor): - # Will reach here if any non-PP parallelism is used. - # If only using PP, total_norm will be a local tensor. - total_norm = total_norm.full_tensor() - - if pp_mesh is not None: - if math.isinf(norm_type): - dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) - else: - total_norm **= norm_type - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) - total_norm **= 1.0 / norm_type - - _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) - return total_norm - - -def enable_determinism( - seed: int, - world_mesh: Optional[torch.distributed.DeviceMesh] = None, - deterministic: bool = False, -) -> None: - r""" - For all ranks within the same DTensor SPMD group, the same seed will be set. - For PP groups, different seeds will be set. 
- """ - if deterministic: - logger.info("Deterministic algorithms are enabled (expect performance degradation).") - torch.use_deterministic_algorithms(True) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - - if not world_mesh: - if seed is not None: - torch.manual_seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed % 2**32) - logger.debug(f"Single-process job using seed: {seed}") - return - - # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, - # and choose a unique seed for each rank on the PP mesh. - if torch.distributed.distributed_c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names: - pp_mesh = world_mesh["pp"] - seed += pp_mesh.get_local_rank() - seed %= 2**64 - - info = { - "pp_rank": pp_mesh.get_local_rank(), - "global_rank": torch.distributed.distributed_c10d.get_rank(), - "seed": seed, - } - logger.debug(f"Enabling determinism: {info}") - spmd_mesh_dims = list(filter(lambda name: name != "pp", world_mesh.mesh_dim_names)) - spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None - else: - spmd_mesh = world_mesh - info = {"global_rank": torch.distributed.distributed_c10d.get_rank(), "seed": seed} - logger.debug(f"Enabling determinism: {info}") - - # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency - torch.manual_seed(seed) - # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1] - os.environ["PYTHONHASHSEED"] = str(seed % 2**32) - - # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. - # IF PP is also used, this seed is unique per PP rank. 
- if spmd_mesh and spmd_mesh.get_coordinate() is not None: - torch.distributed.tensor._random.manual_seed(seed, spmd_mesh) - - -def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor: - assert len(tensor.shape) <= ndim - return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape))) - - -def get_device_info(): - from torch._utils import _get_available_device_type, _get_device_module - - device_type = _get_available_device_type() - if device_type is None: - device_type = "cuda" - device_module = _get_device_module(device_type) - return device_type, device_module - - -def get_dtype_from_string(dtype: str): - return _STRING_TO_DTYPE[dtype] - - -def get_string_from_dtype(dtype: torch.dtype): - return _DTYPE_TO_STRING[dtype] - - -def set_requires_grad(models: Union[torch.nn.Module, List[torch.nn.Module]], value: bool) -> None: - if isinstance(models, torch.nn.Module): - models = [models] - for model in models: - if model is not None: - model.requires_grad_(value) - - -def synchronize_device() -> None: - if torch.cuda.is_available(): - torch.cuda.synchronize() - elif torch.backends.mps.is_available(): - torch.mps.synchronize() - - -def unwrap_model(accelerator: Accelerator, model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - -# TODO(aryan): remove everything below this after next torch release -def _get_total_norm( - tensors: Union[torch.Tensor, List[torch.Tensor]], - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: Optional[bool] = None, -) -> torch.Tensor: - if isinstance(tensors, torch.Tensor): - tensors = [tensors] - else: - tensors = list(tensors) - norm_type = float(norm_type) - if len(tensors) == 0: - return torch.tensor(0.0) - first_device = tensors[0].device - grouped_tensors: dict[ - tuple[torch.device, torch.dtype], tuple[list[list[torch.Tensor]], list[int]] - ] = _group_tensors_by_device_and_dtype( - [tensors] # type: ignore[list-item] - ) # type: ignore[assignment] - - norms: List[torch.Tensor] = [] - for (device, _), ([device_tensors], _) in grouped_tensors.items(): - if (foreach is None and _has_foreach_support(device_tensors, device)) or ( - foreach and _device_has_foreach_support(device) - ): - norms.extend(torch._foreach_norm(device_tensors, norm_type)) - elif foreach: - raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") - else: - norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_tensors]) - - total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) - - if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): - raise RuntimeError( - f"The total norm of order {norm_type} for gradients from " - "`parameters` is non-finite, so it cannot be clipped. 
To disable " - "this error and scale the gradients by the non-finite norm anyway, " - "set `error_if_nonfinite=False`" - ) - return total_norm - - -@torch.no_grad() -def _clip_grads_with_norm_( - parameters: Union[torch.Tensor, List[torch.Tensor]], - max_norm: float, - total_norm: torch.Tensor, - foreach: Optional[bool] = None, -) -> None: - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - grads = [p.grad for p in parameters if p.grad is not None] - max_norm = float(max_norm) - if len(grads) == 0: - return - grouped_grads: dict[ - Tuple[torch.device, torch.dtype], Tuple[List[List[torch.Tensor]], List[int]] - ] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment] - - clip_coef = max_norm / (total_norm + 1e-6) - # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so - # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization - # when the gradients do not reside in CPU memory. - clip_coef_clamped = torch.clamp(clip_coef, max=1.0) - for (device, _), ([device_grads], _) in grouped_grads.items(): - if (foreach is None and _has_foreach_support(device_grads, device)) or ( - foreach and _device_has_foreach_support(device) - ): - torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) - elif foreach: - raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") - else: - clip_coef_clamped_device = clip_coef_clamped.to(device) - for g in device_grads: - g.mul_(clip_coef_clamped_device) - - -def _get_foreach_kernels_supported_devices() -> list[str]: - r"""Return the device type list that supports foreach kernels.""" - return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()] - - -@torch.no_grad() -def _group_tensors_by_device_and_dtype( - tensorlistlist: List[List[Optional[torch.Tensor]]], - with_indices: bool = False, -) -> dict[tuple[torch.device, torch.dtype], tuple[List[List[Optional[torch.Tensor]]], List[int]]]: - return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices) - - -def _device_has_foreach_support(device: torch.device) -> bool: - return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting() - - -def _has_foreach_support(tensors: List[torch.Tensor], device: torch.device) -> bool: - return _device_has_foreach_support(device) and all(t is None or type(t) in [torch.Tensor] for t in tensors) diff --git a/requirements.txt b/requirements.txt index dd0c7ca5baf92af5c807fe9dd98a0def0fdfaf35..097f78a91012eb1eae0af96c896f3302965f52f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,8 @@ peft>=0.12.0 #eva-decord==0.6.1 decord +finetrainers @ git+https://github.com/a-r-r-o-w/finetrainers.git@4c6844fabf496fb622b6333f57e8ebecee6f2780 + wandb pandas sentencepiece>=0.2.0 diff --git a/requirements_without_flash_attention.txt b/requirements_without_flash_attention.txt index b059dcafbe251224ee8cb8eb1a06523f29a2397e..c5f85fdb8445b108301918293251313436a4506b 100644 --- a/requirements_without_flash_attention.txt +++ b/requirements_without_flash_attention.txt @@ -25,6 +25,8 @@ peft>=0.12.0 eva-decord==0.6.1 # decord +finetrainers @ git+https://github.com/a-r-r-o-w/finetrainers.git@4c6844fabf496fb622b6333f57e8ebecee6f2780 + wandb pandas sentencepiece>=0.2.0