diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b827644ebb071d5cd31ab1c926fb9d5dfdb4e316 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.DS_Store +.venv +.data +__pycache__ +*.mp3 +*.mp4 +*.zip +training_service.log diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..6f823da2a3444f75fec930d217c78c2e89c7e9b1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,44 @@ +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 + +# Prevent interactive prompts during build +ARG DEBIAN_FRONTEND=noninteractive + +# Set environment variables +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 +ENV DEBIAN_FRONTEND=noninteractive + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + python3.10 \ + python3-pip \ + python3-dev \ + git \ + ffmpeg \ + libsm6 \ + libxext6 \ + libgl1-mesa-glx \ + libglib2.0-0 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Create app directory +WORKDIR /app + +# Install Python dependencies first for better caching +COPY requirements.txt . +RUN pip3 install --no-cache-dir -r requirements.txt + +# actually we found a way to put flash attention inside the requirements.txt +# so we are good, we don't need this anymore: +# RUN pip3 install --no-cache-dir -r requirements_without_flash_attention.txt +# RUN pip3 install wheel setuptools flash-attn --no-build-isolation --no-cache-dir + +# Copy application files +COPY . . + +# Expose Gradio port +EXPOSE 7860 + +# Run the application +CMD ["python3", "app.py"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f8fe170e99a403e4700004db07b0277646fa4c96 --- /dev/null +++ b/README.md @@ -0,0 +1,96 @@ +--- +title: Video Model Studio +emoji: 🎥 +colorFrom: gray +colorTo: gray +sdk: gradio +sdk_version: 5.15.0 +app_file: app.py +pinned: true +license: apache-2.0 +short_description: All-in-one tool for AI video training +--- + +# 🎥 Video Model Studio (VMS) + +## Presentation + +VMS is an all-in-one tool to train LoRA models for various open-source AI video models: + +- Data collection from various sources +- Splitting videos into short single camera shots +- Automatic captioning +- Training HunyuanVideo or LTX-Video + +## Similar projects + +I wasn't aware of it when I started this project, +but there is also this: https://github.com/alisson-anjos/diffusion-pipe-ui + +## Installation + +VMS is built on top of Finetrainers and Gradio, and designed to run as a Hugging Face Space (but you can deploy it elsewhere if you want to). + +### Full installation at Hugging Face + +Easy peasy: create a Space (make sure to use the `Gradio` type/template), and push the repo. No Docker needed! + +### Dev mode on Hugging Face + +Enable dev mode in the space, then open VSCode in local or remote and run: + +``` +pip install -r requirements.txt +``` + +As this is not automatic, then click on "Restart" in the space dev mode UI widget. + +### Full installation somewhere else + +I haven't tested it, but you can try to provided Dockerfile + +### Full installation in local + +the full installation requires: +- Linux +- CUDA 12 +- Python 3.10 + +This is because of flash attention, which is defined in the `requirements.txt` using an URL to download a prebuilt wheel (python bindings for a native library) + +```bash +./setup.sh +``` + +### Degraded installation in local + +If you cannot meet the requirements, you can: + +- solution 1: fix requirements.txt to use another prebuilt wheel +- solution 2: manually build/install flash attention +- solution 3: don't use clip captioning + +Here is how to do solution 3: +```bash +./setup_no_captions.sh +``` + +## Run + +### Running the Gradio app + +Note: please make sure you properly define the environment variables for `STORAGE_PATH` (eg. `/data/`) and `HF_HOME` (eg. `/data/huggingface/`) + +```bash +python app.py +``` + +### Running locally + +See above remarks about the environment variable. + +By default `run.sh` will store stuff in `.data/` (located inside the current working directory): + +```bash +./run.sh +``` \ No newline at end of file diff --git a/accelerate_configs/compiled_1.yaml b/accelerate_configs/compiled_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a7660e0dc640b3cd8381bc793ed78b16c79d9c7 --- /dev/null +++ b/accelerate_configs/compiled_1.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +dynamo_config: + dynamo_backend: INDUCTOR + dynamo_mode: max-autotune + dynamo_use_dynamic: true + dynamo_use_fullgraph: false +enable_cpu_affinity: false +gpu_ids: '3' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/accelerate_configs/deepspeed.yaml b/accelerate_configs/deepspeed.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62db0b4214e1faaac734f76086d6c6f7b6d3810b --- /dev/null +++ b/accelerate_configs/deepspeed.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/accelerate_configs/uncompiled_1.yaml b/accelerate_configs/uncompiled_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..348c1cae86a65ab605628fb39d8bc97269a11205 --- /dev/null +++ b/accelerate_configs/uncompiled_1.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: '3' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/accelerate_configs/uncompiled_2.yaml b/accelerate_configs/uncompiled_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..830b6e0daa8b12c74494c20f142da2b4a78d055e --- /dev/null +++ b/accelerate_configs/uncompiled_2.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: 0,1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/accelerate_configs/uncompiled_8.yaml b/accelerate_configs/uncompiled_8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee7f50c287c77246fd0d2042893378abd12a4943 --- /dev/null +++ b/accelerate_configs/uncompiled_8.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b42e5985a788a288527b3bf0e730537696aac1f8 --- /dev/null +++ b/app.py @@ -0,0 +1,1270 @@ +import platform +import subprocess + +#import sys +#print("python = ", sys.version) + +# can be "Linux", "Darwin" +if platform.system() == "Linux": + # for some reason it says "pip not found" + # and also "pip3 not found" + # subprocess.run( + # "pip install flash-attn --no-build-isolation", + # + # # hmm... this should be False, since we are in a CUDA environment, no? + # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, + # + # shell=True, + # ) + pass + +import gradio as gr +from pathlib import Path +import logging +import mimetypes +import shutil +import os +import traceback +import asyncio +import tempfile +import zipfile +from typing import Any, Optional, Dict, List, Union, Tuple +from typing import AsyncGenerator +from training_service import TrainingService +from captioning_service import CaptioningService +from splitting_service import SplittingService +from import_service import ImportService +from config import ( + STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, + TRAINING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, + DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, TRAINING_BUCKETS +) +from utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time +from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset +from training_log_parser import TrainingLogParser + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +httpx_logger = logging.getLogger('httpx') +httpx_logger.setLevel(logging.WARN) + + +class VideoTrainerUI: + def __init__(self): + self.trainer = TrainingService() + self.splitter = SplittingService() + self.importer = ImportService() + self.captioner = CaptioningService() + self._should_stop_captioning = False + self.log_parser = TrainingLogParser() + + def update_training_ui(self, training_state: Dict[str, Any]): + """Update UI components based on training state""" + updates = {} + + # Update status box with high-level information + status_text = [] + if training_state["status"] != "idle": + status_text.extend([ + f"Status: {training_state['status']}", + f"Progress: {training_state['progress']}", + f"Step: {training_state['current_step']}/{training_state['total_steps']}", + + # Epoch information + # there is an issue with how epoch is reported because we display: + # Progress: 96.9%, Step: 872/900, Epoch: 12/50 + # we should probably just show the steps + #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}", + + f"Time elapsed: {training_state['elapsed']}", + f"Estimated remaining: {training_state['remaining']}", + "", + f"Current loss: {training_state['step_loss']}", + f"Learning rate: {training_state['learning_rate']}", + f"Gradient norm: {training_state['grad_norm']}", + f"Memory usage: {training_state['memory']}" + ]) + + if training_state["error_message"]: + status_text.append(f"\nError: {training_state['error_message']}") + + updates["status_box"] = "\n".join(status_text) + + # Update button states + updates["start_btn"] = gr.Button( + "Start training", + interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]), + variant="primary" if training_state["status"] == "idle" else "secondary" + ) + + updates["stop_btn"] = gr.Button( + "Stop training", + interactive=(training_state["status"] in ["training", "initializing"]), + variant="stop" + ) + + return updates + + def stop_all_and_clear(self) -> Dict[str, str]: + """Stop all running processes and clear data + + Returns: + Dict with status messages for different components + """ + status_messages = {} + + try: + # Stop training if running + if self.trainer.is_training_running(): + training_result = self.trainer.stop_training() + status_messages["training"] = training_result["status"] + + # Stop captioning if running + if self.captioner: + self.captioner.stop_captioning() + #self.captioner.close() + #self.captioner = None + status_messages["captioning"] = "Captioning stopped" + + # Stop scene detection if running + if self.splitter.is_processing(): + self.splitter.processing = False + status_messages["splitting"] = "Scene detection stopped" + + # Clear all data directories + for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH, + MODEL_PATH, OUTPUT_PATH]: + if path.exists(): + try: + shutil.rmtree(path) + path.mkdir(parents=True, exist_ok=True) + except Exception as e: + status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}" + else: + status_messages[f"clear_{path.name}"] = f"Cleared {path.name}" + + # Reset any persistent state + self._should_stop_captioning = True + self.splitter.processing = False + + return { + "status": "All processes stopped and data cleared", + "details": status_messages + } + + except Exception as e: + return { + "status": f"Error during cleanup: {str(e)}", + "details": status_messages + } + + def update_titles(self) -> Tuple[Any]: + """Update all dynamic titles with current counts + + Returns: + Dict of Gradio updates + """ + # Count files for splitting + split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH) + split_title = format_media_title( + "split", split_videos, 0, split_size + ) + + # Count files for captioning + caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH) + caption_title = format_media_title( + "caption", caption_videos, caption_images, caption_size + ) + + # Count files for training + train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH) + train_title = format_media_title( + "train", train_videos, train_images, train_size + ) + + return ( + gr.Markdown(value=split_title), + gr.Markdown(value=caption_title), + gr.Markdown(value=f"{train_title} available for training") + ) + + def copy_files_to_training_dir(self, prompt_prefix: str): + """Run auto-captioning process""" + + # Initialize captioner if not already done + self._should_stop_captioning = False + + try: + copy_files_to_training_dir(prompt_prefix) + + except Exception as e: + traceback.print_exc() + raise gr.Error(f"Error copying assets to training dir: {str(e)}") + + async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]: + """Run auto-captioning process""" + try: + # Initialize captioner if not already done + self._should_stop_captioning = False + + async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix): + # Yield UI update + yield gr.update( + value=rows, + headers=["name", "status"] + ) + + # Final update after completion + yield gr.update( + value=self.list_training_files_to_caption(), + headers=["name", "status"] + ) + + except Exception as e: + yield gr.update( + value=[[str(e), "error"]], + headers=["name", "status"] + ) + + def list_training_files_to_caption(self) -> List[List[str]]: + """List all clips and images - both pending and captioned""" + files = [] + already_listed: Dict[str, bool] = {} + + # Check files in STAGING_PATH + for file in STAGING_PATH.glob("*.*"): + if is_video_file(file) or is_image_file(file): + txt_file = file.with_suffix('.txt') + status = "captioned" if txt_file.exists() else "no caption" + file_type = "video" if is_video_file(file) else "image" + files.append([file.name, f"{status} ({file_type})", str(file)]) + already_listed[str(file.name)] = True + + # Check files in TRAINING_VIDEOS_PATH + for file in TRAINING_VIDEOS_PATH.glob("*.*"): + if not str(file.name) in already_listed: + if is_video_file(file) or is_image_file(file): + txt_file = file.with_suffix('.txt') + if txt_file.exists(): + file_type = "video" if is_video_file(file) else "image" + files.append([file.name, f"captioned ({file_type})", str(file)]) + + # Sort by filename + files.sort(key=lambda x: x[0]) + + # Only return name and status columns for display + return [[file[0], file[1]] for file in files] + + def update_training_buttons(self, training_state: Dict[str, Any]) -> Dict: + """Update training control buttons based on state""" + is_training = training_state["status"] in ["training", "initializing"] + is_paused = training_state["status"] == "paused" + is_completed = training_state["status"] in ["completed", "error", "stopped"] + + return { + start_btn: gr.Button( + interactive=not is_training and not is_paused, + variant="primary" if not is_training else "secondary", + ), + stop_btn: gr.Button( + interactive=is_training or is_paused, + variant="stop", + ), + pause_resume_btn: gr.Button( + value="Resume Training" if is_paused else "Pause Training", + interactive=(is_training or is_paused) and not is_completed, + variant="secondary", + ) + } + + def handle_training_complete(self): + """Handle training completion""" + # Reset button states + return self.update_training_buttons({ + "status": "completed", + "progress": "100%", + "current_step": 0, + "total_steps": 0 + }) + + def handle_pause_resume(self): + status = self.trainer.get_status() + if status["state"] == "paused": + result = self.trainer.resume_training() + new_state = {"status": "training"} + else: + result = self.trainer.pause_training() + new_state = {"status": "paused"} + return ( + *result, + *self.update_training_buttons(new_state).values() + ) + + + def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Handle selection of both video clips and images""" + try: + if not evt: + return [ + gr.Image( + interactive=False, + visible=False + ), + gr.Video( + interactive=False, + visible=False + ), + gr.Textbox( + visible=False + ), + "No file selected" + ] + + file_name = evt.value + if not file_name: + return [ + gr.Image( + interactive=False, + visible=False + ), + gr.Video( + interactive=False, + visible=False + ), + gr.Textbox( + visible=False + ), + "No file selected" + ] + + # Check both possible locations for the file + possible_paths = [ + STAGING_PATH / file_name, + + # note: we use to look into this dir for already-captioned clips, + # but we don't do this anymore + #TRAINING_VIDEOS_PATH / file_name + ] + + # Find the first existing file path + file_path = None + for path in possible_paths: + if path.exists(): + file_path = path + break + + if not file_path: + return [ + gr.Image( + interactive=False, + visible=False + ), + gr.Video( + interactive=False, + visible=False + ), + gr.Textbox( + visible=False + ), + f"File not found: {file_name}" + ] + + txt_path = file_path.with_suffix('.txt') + caption = txt_path.read_text() if txt_path.exists() else "" + + # Handle video files + if is_video_file(file_path): + return [ + gr.Image( + interactive=False, + visible=False + ), + gr.Video( + label="Video Preview", + interactive=False, + visible=True, + value=str(file_path) + ), + gr.Textbox( + label="Caption", + lines=6, + interactive=True, + visible=True, + value=str(caption) + ), + None + ] + # Handle image files + elif is_image_file(file_path): + return [ + gr.Image( + label="Image Preview", + interactive=False, + visible=True, + value=str(file_path) + ), + gr.Video( + interactive=False, + visible=False + ), + gr.Textbox( + label="Caption", + lines=6, + interactive=True, + visible=True, + value=str(caption) + ), + None + ] + else: + return [ + gr.Image( + interactive=False, + visible=False + ), + gr.Video( + interactive=False, + visible=False + ), + gr.Textbox( + interactive=False, + visible=False + ), + f"Unsupported file type: {file_path.suffix}" + ] + except Exception as e: + logger.error(f"Error handling selection: {str(e)}") + return [ + gr.Image( + interactive=False, + visible=False + ), + gr.Video( + interactive=False, + visible=False + ), + gr.Textbox( + interactive=False, + visible=False + ), + f"Error handling selection: {str(e)}" + ] + + def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, prompt_prefix: str): + """Save changes to caption""" + try: + # Add prefix if not already present + if prompt_prefix and not preview_caption.startswith(prompt_prefix): + full_caption = f"{prompt_prefix}{preview_caption}" + else: + full_caption = preview_caption + + path = Path(preview_video if preview_video else preview_image) + if path.suffix == '.txt': + self.trainer.update_file_caption(path.with_suffix(''), full_caption) + else: + self.trainer.update_file_caption(path, full_caption) + return gr.update(value="Caption saved successfully!") + except Exception as e: + return gr.update(value=f"Error saving caption: {str(e)}") + + def get_model_info(self, model_type: str) -> str: + """Get information about the selected model type""" + if model_type == "hunyuan_video": + return """### HunyuanVideo (LoRA) + - Best for learning complex video generation patterns + - Required VRAM: ~47GB minimum + - Recommended batch size: 1-2 + - Typical training time: 2-4 hours + - Default resolution: 49x512x768 + - Default LoRA rank: 128""" + + elif model_type == "ltx_video": + return """### LTX-Video (LoRA) + - Lightweight video model + - Required VRAM: ~18GB minimum + - Recommended batch size: 1-4 + - Typical training time: 1-3 hours + - Default resolution: 49x512x768 + - Default LoRA rank: 128""" + + return "" + + def get_default_params(self, model_type: str) -> Dict[str, Any]: + """Get default training parameters for model type""" + if model_type == "hunyuan_video": + return { + "num_epochs": 70, + "batch_size": 1, + "learning_rate": 2e-5, + "save_iterations": 500, + "video_resolution_buckets": TRAINING_BUCKETS, + "video_reshape_mode": "center", + "caption_dropout_p": 0.05, + "gradient_accumulation_steps": 1, + "rank": 128, + "lora_alpha": 128 + } + else: # ltx_video + return { + "num_epochs": 70, + "batch_size": 1, + "learning_rate": 3e-5, + "save_iterations": 500, + "video_resolution_buckets": TRAINING_BUCKETS, + "video_reshape_mode": "center", + "caption_dropout_p": 0.05, + "gradient_accumulation_steps": 4, + "rank": 128, + "lora_alpha": 128 + } + + def preview_file(self, selected_text: str) -> Dict: + """Generate preview based on selected file + + Args: + selected_text: Text of the selected item containing filename + + Returns: + Dict with preview content for each preview component + """ + if not selected_text or "Caption:" in selected_text: + return { + "video": None, + "image": None, + "text": None + } + + # Extract filename from the preview text (remove size info) + filename = selected_text.split(" (")[0].strip() + file_path = TRAINING_VIDEOS_PATH / filename + + if not file_path.exists(): + return { + "video": None, + "image": None, + "text": f"File not found: {filename}" + } + + # Detect file type + mime_type, _ = mimetypes.guess_type(str(file_path)) + if not mime_type: + return { + "video": None, + "image": None, + "text": f"Unknown file type: {filename}" + } + + # Return appropriate preview + if mime_type.startswith('video/'): + return { + "video": str(file_path), + "image": None, + "text": None + } + elif mime_type.startswith('image/'): + return { + "video": None, + "image": str(file_path), + "text": None + } + elif mime_type.startswith('text/'): + try: + text_content = file_path.read_text() + return { + "video": None, + "image": None, + "text": text_content + } + except Exception as e: + return { + "video": None, + "image": None, + "text": f"Error reading file: {str(e)}" + } + else: + return { + "video": None, + "image": None, + "text": f"Unsupported file type: {mime_type}" + } + + def list_unprocessed_videos(self) -> gr.Dataframe: + """Update list of unprocessed videos""" + videos = self.splitter.list_unprocessed_videos() + # videos is already in [[name, status]] format from splitting_service + return gr.Dataframe( + headers=["name", "status"], + value=videos, + interactive=False + ) + + async def start_scene_detection(self, enable_splitting: bool) -> str: + """Start background scene detection process + + Args: + enable_splitting: Whether to split videos into scenes + """ + if self.splitter.is_processing(): + return "Scene detection already running" + + try: + await self.splitter.start_processing(enable_splitting) + return "Scene detection completed" + except Exception as e: + return f"Error during scene detection: {str(e)}" + + + def refresh_training_status_and_logs(self): + """Refresh all dynamic lists and training state""" + status = self.trainer.get_status() + logs = self.trainer.get_logs() + + status_update = status["message"] + + # Parse new log lines + if logs: + last_state = None + for line in logs.splitlines(): + state_update = self.log_parser.parse_line(line) + if state_update: + last_state = state_update + + if last_state: + ui_updates = self.update_training_ui(last_state) + status_update = ui_updates.get("status_box", status["message"]) + + return (status_update, logs) + + def refresh_training_status(self): + """Refresh training status and update UI""" + status, logs = self.refresh_training_status_and_logs() + + # Parse status for training state + is_completed = "completed" in status.lower() or "100.0%" in status + current_state = { + "status": "completed" if is_completed else "training", + "message": status + } + + if is_completed: + button_updates = self.handle_training_complete() + return ( + status, + logs, + *button_updates.values() + ) + + # Update based on current training state + button_updates = self.update_training_buttons(current_state) + return ( + status, + logs, + *button_updates.values() + ) + + def refresh_dataset(self): + """Refresh all dynamic lists and training state""" + video_list = self.splitter.list_unprocessed_videos() + training_dataset = self.list_training_files_to_caption() + + return ( + video_list, + training_dataset + ) + + def create_ui(self): + """Create Gradio interface""" + + with gr.Blocks(title="🎥 Video Model Studio") as app: + gr.Markdown("# 🎥 Video Model Studio") + + with gr.Tabs() as tabs: + with gr.TabItem("1️⃣ Import", id="import_tab"): + + with gr.Row(): + gr.Markdown("## Optional: automated data cleaning") + + with gr.Row(): + enable_automatic_video_split = gr.Checkbox( + label="Automatically split videos into smaller clips", + info="Note: a clip is a single camera shot, usually a few seconds", + value=True, + visible=False + ) + enable_automatic_content_captioning = gr.Checkbox( + label="Automatically caption photos and videos", + info="Note: this uses LlaVA and takes some extra time to load and process", + value=False, + visible=False, + ) + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Import video files") + gr.Markdown("You can upload either:") + gr.Markdown("- A single MP4 video file") + gr.Markdown("- A ZIP archive containing multiple videos and optional caption files") + gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)") + + with gr.Row(): + files = gr.Files( + label="Upload Images, Videos or ZIP", + #file_count="multiple", + file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip"], + type="filepath" + ) + + with gr.Column(scale=3): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Import a YouTube video") + gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:") + + with gr.Row(): + youtube_url = gr.Textbox( + label="Import YouTube Video", + placeholder="https://www.youtube.com/watch?v=..." + ) + with gr.Row(): + youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary") + with gr.Row(): + import_status = gr.Textbox(label="Status", interactive=False) + + + with gr.TabItem("2️⃣ Split", id="split_tab"): + with gr.Row(): + split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)") + + with gr.Row(): + with gr.Column(): + detect_btn = gr.Button("Split videos into single-camera shots", variant="primary") + detect_status = gr.Textbox(label="Status", interactive=False) + + with gr.Column(): + + video_list = gr.Dataframe( + headers=["name", "status"], + label="Videos to split", + interactive=False, + wrap=True, + #selection_mode="cell" # Enable cell selection + ) + + + with gr.TabItem("3️⃣ Caption"): + with gr.Row(): + caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)") + + with gr.Row(): + + with gr.Column(): + with gr.Row(): + custom_prompt_prefix = gr.Textbox( + scale=3, + label='Prefix to add to ALL captions (eg. "In the style of TOK, ")', + placeholder="In the style of TOK, ", + lines=2, + value=DEFAULT_PROMPT_PREFIX + ) + captioning_bot_instructions = gr.Textbox( + scale=6, + label="System instructions for the automatic captioning model", + placeholder="Please generate a full description of...", + lines=5, + value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS + ) + with gr.Row(): + run_autocaption_btn = gr.Button( + "Automatically fill missing captions", + variant="primary" # Makes it green by default + ) + copy_files_to_training_dir_btn = gr.Button( + "Copy assets to training directory", + variant="primary" # Makes it green by default + ) + stop_autocaption_btn = gr.Button( + "Stop Captioning", + variant="stop", # Red when enabled + interactive=False # Disabled by default + ) + + with gr.Row(): + with gr.Column(): + training_dataset = gr.Dataframe( + headers=["name", "status"], + interactive=False, + wrap=True, + value=self.list_training_files_to_caption(), + row_count=10, # Optional: set a reasonable row count + #selection_mode="cell" + ) + + with gr.Column(): + preview_video = gr.Video( + label="Video Preview", + interactive=False, + visible=False + ) + preview_image = gr.Image( + label="Image Preview", + interactive=False, + visible=False + ) + preview_caption = gr.Textbox( + label="Caption", + lines=6, + interactive=True + ) + save_caption_btn = gr.Button("Save Caption") + preview_status = gr.Textbox( + label="Status", + interactive=False, + visible=True + ) + + with gr.TabItem("4️⃣ Train"): + with gr.Row(): + with gr.Column(): + + with gr.Row(): + train_title = gr.Markdown("## 0 files available for training (0 bytes)") + + with gr.Row(): + with gr.Column(): + model_type = gr.Dropdown( + choices=list(MODEL_TYPES.keys()), + label="Model Type", + value=list(MODEL_TYPES.keys())[0] + ) + model_info = gr.Markdown( + value=self.get_model_info(list(MODEL_TYPES.keys())[0]) + ) + + with gr.Row(): + lora_rank = gr.Dropdown( + label="LoRA Rank", + choices=["16", "32", "64", "128", "256"], + value="128", + type="value" + ) + lora_alpha = gr.Dropdown( + label="LoRA Alpha", + choices=["16", "32", "64", "128", "256"], + value="128", + type="value" + ) + with gr.Row(): + num_epochs = gr.Number( + label="Number of Epochs", + value=70, + minimum=1, + precision=0 + ) + batch_size = gr.Number( + label="Batch Size", + value=1, + minimum=1, + precision=0 + ) + with gr.Row(): + learning_rate = gr.Number( + label="Learning Rate", + value=2e-5, + minimum=1e-7 + ) + save_iterations = gr.Number( + label="Save checkpoint every N iterations", + value=500, + minimum=50, + precision=0, + info="Model will be saved periodically after these many steps" + ) + + with gr.Column(): + with gr.Row(): + start_btn = gr.Button( + "Start Training", + variant="primary", + interactive=not ASK_USER_TO_DUPLICATE_SPACE + ) + pause_resume_btn = gr.Button( + "Resume Training", + variant="secondary", + interactive=False + ) + stop_btn = gr.Button( + "Stop Training", + variant="stop", + interactive=False + ) + + with gr.Row(): + with gr.Column(): + status_box = gr.Textbox( + label="Training Status", + interactive=False, + lines=4 + ) + log_box = gr.TextArea( + label="Training Logs", + interactive=False, + lines=10, + max_lines=40, + autoscroll=True + ) + + with gr.TabItem("5️⃣ Manage"): + + with gr.Column(): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Publishing") + gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)") + + with gr.Row(): + + with gr.Column(): + repo_id = gr.Textbox( + label="HuggingFace Model Repository", + placeholder="username/model-name", + info="The repository will be created if it doesn't exist" + ) + gr.Checkbox(label="Check this to make your model public (ie. visible and downloadable by anyone)", info="You model is private by default"), + global_stop_btn = gr.Button( + "Push my model", + #variant="stop" + ) + + + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Storage management") + with gr.Row(): + download_dataset_btn = gr.DownloadButton( + "Download dataset", + variant="secondary", + size="lg" + ) + download_model_btn = gr.DownloadButton( + "Download model", + variant="secondary", + size="lg" + ) + + + with gr.Row(): + global_stop_btn = gr.Button( + "Stop everything and delete my data", + variant="stop" + ) + global_status = gr.Textbox( + label="Global Status", + interactive=False, + visible=False + ) + + + + # Event handlers + def update_model_info(model): + params = self.get_default_params(MODEL_TYPES[model]) + info = self.get_model_info(MODEL_TYPES[model]) + return { + model_info: info, + num_epochs: params["num_epochs"], + batch_size: params["batch_size"], + learning_rate: params["learning_rate"], + save_iterations: params["save_iterations"] + } + + def validate_repo(repo_id: str) -> dict: + validation = validate_model_repo(repo_id) + if validation["error"]: + return gr.update(value=repo_id, error=validation["error"]) + return gr.update(value=repo_id, error=None) + + # Connect events + model_type.change( + fn=update_model_info, + inputs=[model_type], + outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations] + ) + + async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix): + videos = self.list_unprocessed_videos() + # If scene detection isn't already running and there are videos to process, + # and auto-splitting is enabled, start the detection + if videos and not self.splitter.is_processing() and enable_splitting: + await self.start_scene_detection(enable_splitting) + msg = "Starting automatic scene detection..." + else: + # Just copy files without splitting if auto-split disabled + for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"): + await self.splitter.process_video(video_file, enable_splitting=False) + msg = "Copying videos without splitting..." + + copy_files_to_training_dir(prompt_prefix) + + # Start auto-captioning if enabled + if enable_automatic_content_captioning: + await self.start_caption_generation( + DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, + prompt_prefix + ) + + return { + tabs: gr.Tabs(selected="split_tab"), + video_list: videos, + detect_status: msg + } + + + async def update_titles_after_import(enable_splitting, enable_automatic_content_captioning, prompt_prefix): + """Handle post-import updates including titles""" + import_result = await on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix) + titles = self.update_titles() + return (*import_result, *titles) + + files.upload( + fn=lambda x: self.importer.process_uploaded_files(x), + inputs=[files], + outputs=[import_status] + ).success( + fn=update_titles_after_import, + inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix], + outputs=[ + tabs, video_list, detect_status, + split_title, caption_title, train_title + ] + ) + + youtube_download_btn.click( + fn=self.importer.download_youtube_video, + inputs=[youtube_url], + outputs=[import_status] + ).success( + fn=on_import_success, + inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix], + outputs=[tabs, video_list, detect_status] + ) + + # Scene detection events + detect_btn.click( + fn=self.start_scene_detection, + inputs=[enable_automatic_video_split], + outputs=[detect_status] + ) + + + # Update button states based on captioning status + def update_button_states(is_running): + return { + run_autocaption_btn: gr.Button( + interactive=not is_running, + variant="secondary" if is_running else "primary", + ), + stop_autocaption_btn: gr.Button( + interactive=is_running, + variant="secondary", + ), + } + + run_autocaption_btn.click( + fn=self.start_caption_generation, + inputs=[captioning_bot_instructions, custom_prompt_prefix], + outputs=[training_dataset], + ).then( + fn=lambda: update_button_states(True), + outputs=[run_autocaption_btn, stop_autocaption_btn] + ) + + copy_files_to_training_dir_btn.click( + fn=self.copy_files_to_training_dir, + inputs=[custom_prompt_prefix] + ) + + stop_autocaption_btn.click( + fn=lambda: (self.captioner.stop_captioning() if self.captioner else None, update_button_states(False)), + outputs=[run_autocaption_btn, stop_autocaption_btn] + ) + + training_dataset.select( + fn=self.handle_training_dataset_select, + outputs=[preview_image, preview_video, preview_caption, preview_status] + ) + + save_caption_btn.click( + fn=self.save_caption_changes, + inputs=[preview_caption, preview_image, preview_video, custom_prompt_prefix], + outputs=[preview_status] + ).success( + fn=self.list_training_files_to_caption, + outputs=[training_dataset] + ) + + # Training control events + start_btn.click( + fn=lambda model_type, *args: ( + self.log_parser.reset(), + self.trainer.start_training( + MODEL_TYPES[model_type], + *args + ) + ), + inputs=[ + model_type, + lora_rank, + lora_alpha, + num_epochs, + batch_size, + learning_rate, + save_iterations, + repo_id + ], + outputs=[status_box, log_box] + ).success( + fn=lambda: self.update_training_buttons({ + "status": "training" + }), + outputs=[start_btn, stop_btn, pause_resume_btn] + ) + + + pause_resume_btn.click( + fn=self.handle_pause_resume, + outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn] + ) + + stop_btn.click( + fn=self.trainer.stop_training, + outputs=[status_box, log_box] + ).success( + fn=self.handle_training_complete, + outputs=[start_btn, stop_btn, pause_resume_btn] + ) + + def handle_global_stop(): + result = self.stop_all_and_clear() + # Update all relevant UI components + status = result["status"] + details = "\n".join(f"{k}: {v}" for k, v in result["details"].items()) + full_status = f"{status}\n\nDetails:\n{details}" + + # Get fresh lists after cleanup + videos = self.splitter.list_unprocessed_videos() + clips = self.list_training_files_to_caption() + + return { + global_status: gr.update(value=full_status, visible=True), + video_list: videos, + training_dataset: clips, + status_box: "Training stopped and data cleared", + log_box: "", + detect_status: "Scene detection stopped", + import_status: "All data cleared", + preview_status: "Captioning stopped" + } + + download_dataset_btn.click( + fn=self.trainer.create_training_dataset_zip, + outputs=[download_dataset_btn] + ) + + download_model_btn.click( + fn=self.trainer.get_model_output_safetensors, + outputs=[download_model_btn] + ) + + global_stop_btn.click( + fn=handle_global_stop, + outputs=[ + global_status, + video_list, + training_dataset, + status_box, + log_box, + detect_status, + import_status, + preview_status + ] + ) + + # Auto-refresh timers + app.load( + fn=lambda: ( + self.refresh_dataset() + ), + outputs=[ + video_list, training_dataset + ] + ) + + timer = gr.Timer(value=1) + timer.tick( + fn=lambda: ( + self.refresh_training_status_and_logs() + ), + outputs=[ + status_box, + log_box + ] + ) + + timer = gr.Timer(value=5) + timer.tick( + fn=lambda: ( + self.refresh_dataset() + ), + outputs=[ + video_list, training_dataset + ] + ) + + timer = gr.Timer(value=5) + timer.tick( + fn=lambda: self.update_titles(), + outputs=[ + split_title, caption_title, train_title + ] + ) + + return app + +def create_app(): + if ASK_USER_TO_DUPLICATE_SPACE: + with gr.Blocks() as app: + gr.Markdown("""# Finetrainers UI + +This Hugging Face space needs to be duplicated to your own billing account to work. + +Click the 'Duplicate Space' button at the top of the page to create your own copy. + +It is recommended to use a Nvidia L40S and a persistent storage space. +To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""") + return app + + ui = VideoTrainerUI() + return ui.create_ui() + +if __name__ == "__main__": + app = create_app() + + allowed_paths = [ + str(STORAGE_PATH), # Base storage + str(VIDEOS_TO_SPLIT_PATH), + str(STAGING_PATH), + str(TRAINING_PATH), + str(TRAINING_VIDEOS_PATH), + str(MODEL_PATH), + str(OUTPUT_PATH) + ] + app.queue(default_concurrency_limit=1).launch( + server_name="0.0.0.0", + allowed_paths=allowed_paths + ) \ No newline at end of file diff --git a/captioning_service.py b/captioning_service.py new file mode 100644 index 0000000000000000000000000000000000000000..09a09ec8a6d8bbe7af431e9e0c89cdfed91a15e1 --- /dev/null +++ b/captioning_service.py @@ -0,0 +1,534 @@ +import logging +import torch +import shutil +import gradio as gr +from llava.model.builder import load_pretrained_model +from llava.mm_utils import tokenizer_image_token +import numpy as np +from decord import VideoReader, cpu +from pathlib import Path +from typing import Any, Tuple, Dict, Optional, AsyncGenerator, List +import asyncio +from dataclasses import dataclass +from datetime import datetime +import cv2 +from config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX +from utils import extract_scene_info, is_image_file, is_video_file +from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset + +logger = logging.getLogger(__name__) + +@dataclass +class CaptioningProgress: + video_name: str + total_frames: int + processed_frames: int + status: str + started_at: datetime + completed_at: Optional[datetime] = None + error: Optional[str] = None + +class CaptioningService: + _instance = None + _model = None + _tokenizer = None + _image_processor = None + _model_loading = None + _loop = None + + def __new__(cls, model_name=CAPTIONING_MODEL): + if cls._instance is not None: + return cls._instance + + instance = super().__new__(cls) + if PRELOAD_CAPTIONING_MODEL: + cls._instance = instance + try: + cls._loop = asyncio.get_running_loop() + except RuntimeError: + cls._loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._loop) + + if not USE_MOCK_CAPTIONING_MODEL and cls._model_loading is None: + cls._model_loading = cls._loop.create_task(cls._background_load_model(model_name)) + return instance + + def __init__(self, model_name=CAPTIONING_MODEL): + if hasattr(self, 'model_name'): # Already initialized + return + + self.model_name = model_name + self.tokenizer = None + self.model = None + self.image_processor = None + self.active_tasks: Dict[str, CaptioningProgress] = {} + self._should_stop = False + self._model_loaded = False + + @classmethod + async def _background_load_model(cls, model_name): + """Background task to load the model""" + try: + logger.info("Starting background model loading...") + if not cls._loop: + cls._loop = asyncio.get_running_loop() + + def load_model(): + try: + tokenizer, model, image_processor, _ = load_pretrained_model( + model_name, None, "llava_qwen", + torch_dtype="bfloat16", device_map="auto" + ) + model.eval() + return tokenizer, model, image_processor + except Exception as e: + logger.error(f"Error in load_model: {str(e)}") + raise + + result = await cls._loop.run_in_executor(None, load_model) + + cls._tokenizer, cls._model, cls._image_processor = result + logger.info("Background model loading completed successfully!") + + except Exception as e: + logger.error(f"Background model loading failed: {str(e)}") + cls._model_loading = None + raise + + async def ensure_model_loaded(self): + """Ensure model is loaded before processing""" + if USE_MOCK_CAPTIONING_MODEL: + logger.info("Using mock model, skipping model loading") + self.__class__._model_loading = None + self._model_loaded = True + return + + if not self._model_loaded: + try: + if PRELOAD_CAPTIONING_MODEL and self.__class__._model_loading: + logger.info("Waiting for background model loading to complete...") + if self.__class__._loop and self.__class__._loop != asyncio.get_running_loop(): + logger.warning("Different event loop detected, creating new loading task") + self.__class__._model_loading = None + await self._load_model_sync() + else: + await self.__class__._model_loading + self.model = self.__class__._model + self.tokenizer = self.__class__._tokenizer + self.image_processor = self.__class__._image_processor + else: + await self._load_model_sync() + + self._model_loaded = True + logger.info("Model loading completed!") + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + raise + + async def _load_model_sync(self): + """Synchronously load the model""" + logger.info("Loading model synchronously...") + current_loop = asyncio.get_running_loop() + + def load_model(): + return load_pretrained_model( + self.model_name, None, "llava_qwen", + torch_dtype="bfloat16", device_map="auto" + ) + + self.tokenizer, self.model, self.image_processor, _ = await current_loop.run_in_executor( + None, load_model + ) + self.model.eval() + + def _load_video(self, video_path: Path, max_frames_num: int = 64, fps: int = 1, force_sample: bool = True) -> tuple[np.ndarray, str, float]: + """Load and preprocess video frames""" + + video_path_str = str(video_path) if hasattr(video_path, '__fspath__') else video_path + + logger.debug(f"Loading video: {video_path_str}") + + if max_frames_num == 0: + return np.zeros((1, 336, 336, 3)), "", 0 + + vr = VideoReader(video_path_str, ctx=cpu(0), num_threads=1) + total_frame_num = len(vr) + video_time = total_frame_num / vr.get_avg_fps() + + # Calculate frame indices + fps = round(vr.get_avg_fps()/fps) + frame_idx = [i for i in range(0, len(vr), fps)] + frame_time = [i/fps for i in frame_idx] + + if len(frame_idx) > max_frames_num or force_sample: + sample_fps = max_frames_num + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + frame_time = [i/vr.get_avg_fps() for i in frame_idx] + + frame_time_str = ",".join([f"{i:.2f}s" for i in frame_time]) + + try: + frames = vr.get_batch(frame_idx).asnumpy() + logger.debug(f"Loaded {len(frames)} frames with shape {frames.shape}") + return frames, frame_time_str, video_time + except Exception as e: + logger.error(f"Error loading video frames: {str(e)}") + raise + + async def process_video(self, video_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]: + try: + video_name = video_path.name + logger.info(f"Starting processing of video: {video_name}") + + # Load video metadata + logger.debug(f"Loading video metadata for {video_name}") + loop = asyncio.get_event_loop() + vr = await loop.run_in_executor(None, lambda: VideoReader(str(video_path), ctx=cpu(0))) + total_frames = len(vr) + + progress = CaptioningProgress( + video_name=video_name, + total_frames=total_frames, + processed_frames=0, + status="initializing", + started_at=datetime.now() + ) + self.active_tasks[video_name] = progress + yield progress, None + + # Get parent caption if this is a clip + parent_caption = "" + if "___" in video_path.stem: + parent_name, _ = extract_scene_info(video_path.stem) + #print(f"parent_name is {parent_name}") + parent_txt_path = VIDEOS_TO_SPLIT_PATH / f"{parent_name}.txt" + if parent_txt_path.exists(): + logger.debug(f"Found parent caption file: {parent_txt_path}") + parent_caption = parent_txt_path.read_text().strip() + + # Ensure model is loaded before processing + await self.ensure_model_loaded() + + if USE_MOCK_CAPTIONING_MODEL: + + # Even in mock mode, we'll generate a caption that shows we processed parent info + clip_caption = f"This is a test caption for {video_name}" + + # Combine clip caption with parent caption + if parent_caption and not full_caption.endswith(parent_caption): + #print(f"we have parent_caption, so we define the full_caption as {clip_caption}\n{parent_caption}") + + full_caption = f"{clip_caption}\n{parent_caption}" + else: + #print(f"we don't have a parent_caption, so we define the full_caption as {clip_caption}") + + full_caption = clip_caption + + if prompt_prefix and not full_caption.startswith(prompt_prefix): + full_caption = f"{prompt_prefix}{full_caption}" + + # Write the caption file + txt_path = video_path.with_suffix('.txt') + txt_path.write_text(full_caption) + + logger.debug(f"Mock mode: Saved caption to {txt_path}") + + progress.status = "completed" + progress.processed_frames = total_frames + progress.completed_at = datetime.now() + yield progress, full_caption + + else: + # Process frames in batches + max_frames_num = 64 + frames, frame_times_str, video_time = await loop.run_in_executor( + None, + lambda: self._load_video(video_path, max_frames_num) + ) + + # Process all frames at once using the image processor + processed_frames = await loop.run_in_executor( + None, + lambda: self.image_processor.preprocess( + frames, + return_tensors="pt" + )["pixel_values"] + ) + + # Update progress + progress.processed_frames = len(frames) + progress.status = "generating caption" + yield progress, None + + # Move processed frames to GPU + video_tensor = processed_frames.to('cuda').bfloat16() + + time_instruction = (f"The video lasts for {video_time:.2f} seconds, and {len(frames)} " + f"frames are uniformly sampled from it. These frames are located at {frame_times_str}.") + full_prompt = f"{time_instruction}\n{prompt}" + + input_ids = await loop.run_in_executor( + None, + lambda: tokenizer_image_token(full_prompt, self.tokenizer, return_tensors="pt").unsqueeze(0).to('cuda') + ) + + # Generate caption + with torch.no_grad(): + output = await loop.run_in_executor( + None, + lambda: self.model.generate( + input_ids, + images=[video_tensor], + modalities=["video"], + do_sample=False, + temperature=0, + max_new_tokens=4096, + ) + ) + + clip_caption = await loop.run_in_executor( + None, + lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip() + ) + + # Combine clip caption with parent caption + if parent_caption: + print(f"we have parent_caption, so we define the full_caption as {clip_caption}\n{parent_caption}") + + full_caption = f"{clip_caption}\n{parent_caption}" + else: + print(f"we don't have a parent_caption, so we define the full_caption as {clip_caption}") + + full_caption = clip_caption + + if prompt_prefix: + full_caption = f"{prompt_prefix}{full_caption}" + + + # Write the caption file + txt_path = video_path.with_suffix('.txt') + txt_path.write_text(full_caption) + + progress.status = "completed" + progress.completed_at = datetime.now() + gr.Info(f"Successfully generated caption for {video_name}") + yield progress, full_caption + + except Exception as e: + progress.status = "error" + progress.error = str(e) + progress.completed_at = datetime.now() + yield progress, None + raise gr.Error(f"Error processing video: {str(e)}") + + async def process_image(self, image_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]: + """Process a single image for captioning""" + try: + image_name = image_path.name + logger.info(f"Starting processing of image: {image_name}") + + progress = CaptioningProgress( + video_name=image_name, # Reusing video_name field for images + total_frames=1, + processed_frames=0, + status="initializing", + started_at=datetime.now() + ) + self.active_tasks[image_name] = progress + yield progress, None + + # Ensure model is loaded + await self.ensure_model_loaded() + + if USE_MOCK_CAPTIONING_MODEL: + progress.status = "completed" + progress.processed_frames = 1 + progress.completed_at = datetime.now() + print("yielding fake") + yield progress, "This is a test image caption" + return + + # Read and process image + loop = asyncio.get_event_loop() + image = await loop.run_in_executor( + None, + lambda: cv2.imread(str(image_path)) + ) + if image is None: + raise ValueError(f"Could not read image: {str(image_path)}") + + # Convert BGR to RGB + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Process image + processed_image = await loop.run_in_executor( + None, + lambda: self.image_processor.preprocess( + image, + return_tensors="pt" + )["pixel_values"] + ) + + progress.processed_frames = 1 + progress.status = "generating caption" + yield progress, None + + # Move to GPU and generate caption + image_tensor = processed_image.to('cuda').bfloat16() + full_prompt = f"{prompt}" + + input_ids = await loop.run_in_executor( + None, + lambda: tokenizer_image_token(full_prompt, self.tokenizer, return_tensors="pt").unsqueeze(0).to('cuda') + ) + + with torch.no_grad(): + output = await loop.run_in_executor( + None, + lambda: self.model.generate( + input_ids, + images=[image_tensor], + modalities=["image"], + do_sample=False, + temperature=0, + max_new_tokens=4096, + ) + ) + + caption = await loop.run_in_executor( + None, + lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip() + ) + + progress.status = "completed" + progress.completed_at = datetime.now() + gr.Info(f"Successfully generated caption for {image_name}") + yield progress, caption + + except Exception as e: + progress.status = "error" + progress.error = str(e) + progress.completed_at = datetime.now() + yield progress, None + raise gr.Error(f"Error processing image: {str(e)}") + + + async def start_caption_generation(self, custom_prompt: str, prompt_prefix: str) -> AsyncGenerator[List[List[str]], None]: + """Iterates over clips to auto-generate captions asynchronously.""" + try: + logger.info("Starting auto-caption generation") + + # Use provided prompt or default + default_prompt = DEFAULT_CAPTIONING_BOT_INSTRUCTIONS + prompt = custom_prompt.strip() or default_prompt + logger.debug(f"Using prompt: {prompt}") + + # Find files needing captions + video_files = list(STAGING_PATH.glob("*.mp4")) + image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)] + all_files = video_files + image_files + + # Filter for files missing captions or with empty caption files + files_to_process = [] + for file_path in all_files: + caption_path = file_path.with_suffix('.txt') + needs_caption = ( + not caption_path.exists() or + caption_path.stat().st_size == 0 or + caption_path.read_text().strip() == "" + ) + if needs_caption: + files_to_process.append(file_path) + + logger.info(f"Found {len(files_to_process)} files needing captions") + + if not files_to_process: + logger.info("No files need captioning") + yield [] + return + + self._should_stop = False + self.active_tasks.clear() + status_update: Dict[str, Dict[str, Any]] = {} + + for file_path in all_files: + if self._should_stop: + break + + try: + print(f"we are in file_path {str(file_path)}") + # Choose appropriate processing method based on file type + if is_video_file(file_path): + process_gen = self.process_video(file_path, prompt, prompt_prefix) + else: + process_gen = self.process_image(file_path, prompt, prompt_prefix) + print("got process_gen = ", process_gen) + async for progress, caption in process_gen: + print(f"process_gen contains this caption = {caption}") + if caption and prompt_prefix and not caption.startswith(prompt_prefix): + caption = f"{prompt_prefix}{caption}" + + # Save caption + if caption: + txt_path = file_path.with_suffix('.txt') + txt_path.write_text(caption) + + logger.debug(f"Progress update: {progress.status}") + + # Store progress info + status_update[file_path.name] = { + "status": progress.status, + "frames": progress.processed_frames, + "total": progress.total_frames + } + + # Convert to list format for Gradio DataFrame + rows = [] + for file_name, info in status_update.items(): + status = info["status"] + if status == "processing": + percent = (info["frames"] / info["total"]) * 100 + status = f"Analyzing... {percent:.1f}% ({info['frames']}/{info['total']} frames)" + elif status == "generating caption": + status = "Generating caption..." + elif status == "error": + status = f"Error: {progress.error}" + elif status == "completed": + status = "Completed" + + rows.append([file_name, status]) + + yield rows + await asyncio.sleep(0.1) + + + except Exception as e: + logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True) + rows = [[str(file_path.name), f"Error: {str(e)}"]] + yield rows + continue + + logger.info("Auto-caption generation completed, cyping assets to the training dir..") + + copy_files_to_training_dir(prompt_prefix) + except Exception as e: + logger.error(f"Error in start_caption_generation: {str(e)}") + yield [[str(e), "error"]] + raise + + def stop_captioning(self): + """Stop all ongoing captioning tasks""" + logger.info("Stopping all captioning tasks") + self._should_stop = True + + def close(self): + """Clean up resources""" + logger.info("Cleaning up captioning service resources") + if hasattr(self, 'model'): + del self.model + if hasattr(self, 'tokenizer'): + del self.tokenizer + if hasattr(self, 'image_processor'): + del self.image_processor + torch.cuda.empty_cache() \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..06832063a7bbe32ebc7b1554c8efcc2cdc049498 --- /dev/null +++ b/config.py @@ -0,0 +1,303 @@ +import os +from dataclasses import dataclass, field +from typing import Dict, Any, Optional, List, Tuple +from pathlib import Path +from utils import parse_bool_env + +HF_API_TOKEN = os.getenv("HF_API_TOKEN") +ASK_USER_TO_DUPLICATE_SPACE = parse_bool_env(os.getenv("ASK_USER_TO_DUPLICATE_SPACE")) + +# Base storage path +STORAGE_PATH = Path(os.environ.get('STORAGE_PATH', '.data')) + +# Subdirectories for different data types +VIDEOS_TO_SPLIT_PATH = STORAGE_PATH / "videos_to_split" # Raw uploaded/downloaded files +STAGING_PATH = STORAGE_PATH / "staging" # This is where files that are captioned or need captioning are waiting +TRAINING_PATH = STORAGE_PATH / "training" # Folder containing the final training dataset +TRAINING_VIDEOS_PATH = TRAINING_PATH / "videos" # Captioned clips ready for training +MODEL_PATH = STORAGE_PATH / "model" # Model checkpoints and files +OUTPUT_PATH = STORAGE_PATH / "output" # Training outputs and logs + +# On the production server we can afford to preload the big model +PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL')) + +CAPTIONING_MODEL = "lmms-lab/LLaVA-Video-7B-Qwen2" + +DEFAULT_PROMPT_PREFIX = "In the style of TOK, " + +# This is only use to debug things in local +USE_MOCK_CAPTIONING_MODEL = parse_bool_env(os.environ.get('USE_MOCK_CAPTIONING_MODEL')) + +DEFAULT_CAPTIONING_BOT_INSTRUCTIONS = "Please write a full description of the following video: camera (close-up shot, medium-shot..), genre (music video, horror movie scene, video game footage, go pro footage, japanese anime, noir film, science-fiction, action movie, documentary..), characters (physical appearance, look, skin, facial features, haircut, clothing), scene (action, positions, movements), location (indoor, outdoor, place, building, country..), time and lighting (natural, golden hour, night time, LED lights, kelvin temperature etc), weather and climate (dusty, rainy, fog, haze, snowing..), era/settings" + +# Create directories +STORAGE_PATH.mkdir(parents=True, exist_ok=True) +VIDEOS_TO_SPLIT_PATH.mkdir(parents=True, exist_ok=True) +STAGING_PATH.mkdir(parents=True, exist_ok=True) +TRAINING_PATH.mkdir(parents=True, exist_ok=True) +TRAINING_VIDEOS_PATH.mkdir(parents=True, exist_ok=True) +MODEL_PATH.mkdir(parents=True, exist_ok=True) +OUTPUT_PATH.mkdir(parents=True, exist_ok=True) + +# Image normalization settings +NORMALIZE_IMAGES_TO = os.environ.get('NORMALIZE_IMAGES_TO', 'png').lower() +if NORMALIZE_IMAGES_TO not in ['png', 'jpg']: + raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'") +JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97')) + +MODEL_TYPES = { + "HunyuanVideo (LoRA)": "hunyuan_video", + "LTX-Video (LoRA)": "ltx_video" +} + + +# it is best to use resolutions that are powers of 8 +# The resolution should be divisible by 32 +# so we cannot use 1080, 540 etc as they are not divisible by 32 +TRAINING_WIDTH = 768 # 32 * 24 +TRAINING_HEIGHT = 512 # 32 * 16 + +# 1920 = 32 * 60 (divided by 2: 960 = 32 * 30) +# 1920 = 32 * 60 (divided by 2: 960 = 32 * 30) +# 1056 = 32 * 33 (divided by 2: 544 = 17 * 32) +# 1024 = 32 * 32 (divided by 2: 512 = 16 * 32) +# it is important that the resolution buckets properly cover the training dataset, +# or else that we exclude from the dataset videos that are out of this range +# right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here + +TRAINING_BUCKETS = [ + (8 * 2 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 16 + 1 + (8 * 4 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 32 + 1 + (8 * 6 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 48 + 1 + (8 * 8 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 64 + 1 + (8 * 10 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 80 + 1 + (8 * 12 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 96 + 1 + (8 * 14 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 112 + 1 + (8 * 16 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 128 + 1 + (8 * 18 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 144 + 1 + (8 * 20 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 160 + 1 + (8 * 22 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 176 + 1 + (8 * 24 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 192 + 1 + (8 * 28 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 224 + 1 + (8 * 32 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 256 + 1 +] + +@dataclass +class TrainingConfig: + """Configuration class for finetrainers training""" + + # Required arguments must come first + model_name: str + pretrained_model_name_or_path: str + data_root: str + output_dir: str + + # Optional arguments follow + revision: Optional[str] = None + variant: Optional[str] = None + cache_dir: Optional[str] = None + + # Dataset arguments + + # note: video_column and caption_column serve a dual purpose, + # when using the CSV mode they have to be CSV column names, + # otherwise they have to be filename (relative to the data_root dir path) + video_column: str = "videos.txt" + caption_column: str = "prompts.txt" + + id_token: Optional[str] = None + video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: TRAINING_BUCKETS) + video_reshape_mode: str = "center" + caption_dropout_p: float = 0.05 + caption_dropout_technique: str = "empty" + precompute_conditions: bool = False + + # Diffusion arguments + flow_resolution_shifting: bool = False + flow_weighting_scheme: str = "none" + flow_logit_mean: float = 0.0 + flow_logit_std: float = 1.0 + flow_mode_scale: float = 1.29 + + # Training arguments + training_type: str = "lora" + seed: int = 42 + mixed_precision: str = "bf16" + batch_size: int = 1 + train_epochs: int = 70 + lora_rank: int = 128 + lora_alpha: int = 128 + target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"]) + gradient_accumulation_steps: int = 1 + gradient_checkpointing: bool = True + checkpointing_steps: int = 500 + checkpointing_limit: Optional[int] = 2 + resume_from_checkpoint: Optional[str] = None + enable_slicing: bool = True + enable_tiling: bool = True + + # Optimizer arguments + optimizer: str = "adamw" + lr: float = 3e-5 + scale_lr: bool = False + lr_scheduler: str = "constant_with_warmup" + lr_warmup_steps: int = 100 + lr_num_cycles: int = 1 + lr_power: float = 1.0 + beta1: float = 0.9 + beta2: float = 0.95 + weight_decay: float = 1e-4 + epsilon: float = 1e-8 + max_grad_norm: float = 1.0 + + # Miscellaneous arguments + tracker_name: str = "finetrainers" + report_to: str = "wandb" + nccl_timeout: int = 1800 + + @classmethod + def hunyuan_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig': + """Configuration for Hunyuan video-to-video LoRA training""" + return cls( + model_name="hunyuan_video", + pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo", + data_root=data_path, + output_dir=output_path, + batch_size=1, + train_epochs=70, + lr=2e-5, + gradient_checkpointing=True, + id_token="afkx", + gradient_accumulation_steps=1, + lora_rank=128, + lora_alpha=128, + video_resolution_buckets=TRAINING_BUCKETS, + caption_dropout_p=0.05, + flow_weighting_scheme="none" # Hunyuan specific + ) + + @classmethod + def ltx_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig': + """Configuration for LTX-Video LoRA training""" + return cls( + model_name="ltx_video", + pretrained_model_name_or_path="Lightricks/LTX-Video", + data_root=data_path, + output_dir=output_path, + batch_size=1, + train_epochs=70, + lr=3e-5, + gradient_checkpointing=True, + id_token="BW_STYLE", + gradient_accumulation_steps=4, + lora_rank=128, + lora_alpha=128, + video_resolution_buckets=TRAINING_BUCKETS, + caption_dropout_p=0.05, + flow_weighting_scheme="logit_normal" # LTX specific + ) + + def to_args_list(self) -> List[str]: + """Convert config to command line arguments list""" + args = [] + + # Model arguments + + # Add model_name (required argument) + args.extend(["--model_name", self.model_name]) + + args.extend(["--pretrained_model_name_or_path", self.pretrained_model_name_or_path]) + if self.revision: + args.extend(["--revision", self.revision]) + if self.variant: + args.extend(["--variant", self.variant]) + if self.cache_dir: + args.extend(["--cache_dir", self.cache_dir]) + + # Dataset arguments + args.extend(["--data_root", self.data_root]) + args.extend(["--video_column", self.video_column]) + args.extend(["--caption_column", self.caption_column]) + if self.id_token: + args.extend(["--id_token", self.id_token]) + + # Add video resolution buckets + if self.video_resolution_buckets: + bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets] + args.extend(["--video_resolution_buckets"] + bucket_strs) + + if self.video_reshape_mode: + args.extend(["--video_reshape_mode", self.video_reshape_mode]) + + args.extend(["--caption_dropout_p", str(self.caption_dropout_p)]) + args.extend(["--caption_dropout_technique", self.caption_dropout_technique]) + if self.precompute_conditions: + args.append("--precompute_conditions") + + # Diffusion arguments + if self.flow_resolution_shifting: + args.append("--flow_resolution_shifting") + args.extend(["--flow_weighting_scheme", self.flow_weighting_scheme]) + args.extend(["--flow_logit_mean", str(self.flow_logit_mean)]) + args.extend(["--flow_logit_std", str(self.flow_logit_std)]) + args.extend(["--flow_mode_scale", str(self.flow_mode_scale)]) + + # Training arguments + args.extend(["--training_type", self.training_type]) + args.extend(["--seed", str(self.seed)]) + + # we don't use this, because mixed precision is handled by accelerate launch, not by the training script itself. + #args.extend(["--mixed_precision", self.mixed_precision]) + + args.extend(["--batch_size", str(self.batch_size)]) + args.extend(["--train_epochs", str(self.train_epochs)]) + args.extend(["--rank", str(self.lora_rank)]) + args.extend(["--lora_alpha", str(self.lora_alpha)]) + args.extend(["--target_modules"] + self.target_modules) + args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)]) + if self.gradient_checkpointing: + args.append("--gradient_checkpointing") + args.extend(["--checkpointing_steps", str(self.checkpointing_steps)]) + if self.checkpointing_limit: + args.extend(["--checkpointing_limit", str(self.checkpointing_limit)]) + if self.resume_from_checkpoint: + args.extend(["--resume_from_checkpoint", self.resume_from_checkpoint]) + if self.enable_slicing: + args.append("--enable_slicing") + if self.enable_tiling: + args.append("--enable_tiling") + + # Optimizer arguments + args.extend(["--optimizer", self.optimizer]) + args.extend(["--lr", str(self.lr)]) + if self.scale_lr: + args.append("--scale_lr") + args.extend(["--lr_scheduler", self.lr_scheduler]) + args.extend(["--lr_warmup_steps", str(self.lr_warmup_steps)]) + args.extend(["--lr_num_cycles", str(self.lr_num_cycles)]) + args.extend(["--lr_power", str(self.lr_power)]) + args.extend(["--beta1", str(self.beta1)]) + args.extend(["--beta2", str(self.beta2)]) + args.extend(["--weight_decay", str(self.weight_decay)]) + args.extend(["--epsilon", str(self.epsilon)]) + args.extend(["--max_grad_norm", str(self.max_grad_norm)]) + + # Miscellaneous arguments + args.extend(["--tracker_name", self.tracker_name]) + args.extend(["--output_dir", self.output_dir]) + args.extend(["--report_to", self.report_to]) + args.extend(["--nccl_timeout", str(self.nccl_timeout)]) + + # normally this is disabled by default, but there was a bug in finetrainers + # so I had to fix it in trainer.py to make sure we check for push_to-hub + #args.append("--push_to_hub") + #args.extend(["--hub_token", str(False)]) + #args.extend(["--hub_model_id", str(False)]) + + # If you are using LLM-captioned videos, it is common to see many unwanted starting phrases like + # "In this video, ...", "This video features ...", etc. + # To remove a simple subset of these phrases, you can specify + # --remove_common_llm_caption_prefixes when starting training. + args.append("--remove_common_llm_caption_prefixes") + + return args \ No newline at end of file diff --git a/finetrainers/__init__.py b/finetrainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..412e298eb519f037e08dc755f92d136cfe2ef2e6 --- /dev/null +++ b/finetrainers/__init__.py @@ -0,0 +1,2 @@ +from .args import Args, parse_arguments +from .trainer import Trainer diff --git a/finetrainers/args.py b/finetrainers/args.py new file mode 100644 index 0000000000000000000000000000000000000000..46cd04cca1c0d7368be8395ce3382bc28fc6865b --- /dev/null +++ b/finetrainers/args.py @@ -0,0 +1,1191 @@ +import argparse +import sys +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS +from .models import SUPPORTED_MODEL_CONFIGS + + +class Args: + r""" + The arguments for the finetrainers training script. + + For helpful information about arguments, run `python train.py --help`. + + TODO(aryan): add `python train.py --recommend_configs --model_name ` to recommend + good training configs for a model after extensive testing. + TODO(aryan): add `python train.py --memory_requirements --model_name ` to show + memory requirements per model, per training type with sensible training settings. + + MODEL ARGUMENTS + --------------- + model_name (`str`): + Name of model to train. To get a list of models, run `python train.py --list_models`. + pretrained_model_name_or_path (`str`): + Path to pretrained model or model identifier from https://huggingface.co/models. The model should be + loadable based on specified `model_name`. + revision (`str`, defaults to `None`): + If provided, the model will be loaded from a specific branch of the model repository. + variant (`str`, defaults to `None`): + Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk + storage requirements. + cache_dir (`str`, defaults to `None`): + The directory where the downloaded models and datasets will be stored, or loaded from. + text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder when generating text embeddings. + text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder 2 when generating text embeddings. + text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder 3 when generating text embeddings. + transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the transformer model. + vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the VAE model. + layerwise_upcasting_modules (`List[str]`, defaults to `[]`): + Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer']. + layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`): + Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2']. + layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`): + Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when casted to fp8 precision + naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers + by default, and recommend adding more layers to the default list based on the model architecture. + + DATASET ARGUMENTS + ----------------- + data_root (`str`): + A folder containing the training data. + dataset_file (`str`, defaults to `None`): + Path to a CSV/JSON/JSONL file containing metadata for training. This should be provided if you're not using + a directory dataset format containing a simple `prompts.txt` and `videos.txt`/`images.txt` for example. + video_column (`str`): + The column of the dataset containing videos. Or, the name of the file in `data_root` folder containing the + line-separated path to video data. + caption_column (`str`): + The column of the dataset containing the instance prompt for each video. Or, the name of the file in + `data_root` folder containing the line-separated instance prompts. + id_token (`str`, defaults to `None`): + Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training. + image_resolution_buckets (`List[Tuple[int, int]]`, defaults to `None`): + Resolution buckets for images. This should be a list of integer tuples, where each tuple represents the + resolution (height, width) of the image. All images will be resized to the nearest bucket resolution. + video_resolution_buckets (`List[Tuple[int, int, int]]`, defaults to `None`): + Resolution buckets for videos. This should be a list of integer tuples, where each tuple represents the + resolution (num_frames, height, width) of the video. All videos will be resized to the nearest bucket + resolution. + video_reshape_mode (`str`, defaults to `None`): + All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']. + TODO(aryan): We don't support this. + caption_dropout_p (`float`, defaults to `0.00`): + Probability of dropout for the caption tokens. This is useful to improve the unconditional generation + quality of the model. + caption_dropout_technique (`str`, defaults to `empty`): + Technique to use for caption dropout. Choose between ['empty', 'zero']. Some models apply caption dropout + by setting the prompt condition to an empty string, while others zero-out the text embedding tensors. + precompute_conditions (`bool`, defaults to `False`): + Whether or not to precompute the conditionings for the model. This is useful for faster training, and + reduces the memory requirements. + remove_common_llm_caption_prefixes (`bool`, defaults to `False`): + Whether or not to remove common LLM caption prefixes. This is useful for improving the quality of the + generated text. + + DATALOADER_ARGUMENTS + -------------------- + See https://pytorch.org/docs/stable/data.html for more information. + + dataloader_num_workers (`int`, defaults to `0`): + Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking manner + on the main process. + pin_memory (`bool`, defaults to `False`): + Whether or not to use the pinned memory setting in PyTorch dataloader. This is useful for faster data loading. + + DIFFUSION ARGUMENTS + ------------------- + flow_resolution_shifting (`bool`, defaults to `False`): + Resolution-dependent shifting of timestep schedules. + [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206). + TODO(aryan): We don't support this yet. + flow_base_seq_len (`int`, defaults to `256`): + Base number of tokens for images/video when applying resolution-dependent shifting. + flow_max_seq_len (`int`, defaults to `4096`): + Maximum number of tokens for images/video when applying resolution-dependent shifting. + flow_base_shift (`float`, defaults to `0.5`): + Base shift for timestep schedules when applying resolution-dependent shifting. + flow_max_shift (`float`, defaults to `1.15`): + Maximum shift for timestep schedules when applying resolution-dependent shifting. + flow_shift (`float`, defaults to `1.0`): + Instead of training with uniform/logit-normal sigmas, shift them as (shift * sigma) / (1 + (shift - 1) * sigma). + Setting it higher is helpful when trying to train models for high-resolution generation or to produce better + samples in lower number of inference steps. + flow_weighting_scheme (`str`, defaults to `none`): + We default to the "none" weighting scheme for uniform sampling and uniform loss. + Choose between ['sigma_sqrt', 'logit_normal', 'mode', 'cosmap', 'none']. + flow_logit_mean (`float`, defaults to `0.0`): + Mean to use when using the `'logit_normal'` weighting scheme. + flow_logit_std (`float`, defaults to `1.0`): + Standard deviation to use when using the `'logit_normal'` weighting scheme. + flow_mode_scale (`float`, defaults to `1.29`): + Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. + + TRAINING ARGUMENTS + ------------------ + training_type (`str`, defaults to `None`): + Type of training to perform. Choose between ['lora']. + seed (`int`, defaults to `42`): + A seed for reproducible training. + batch_size (`int`, defaults to `1`): + Per-device batch size. + train_epochs (`int`, defaults to `1`): + Number of training epochs. + train_steps (`int`, defaults to `None`): + Total number of training steps to perform. If provided, overrides `train_epochs`. + rank (`int`, defaults to `128`): + The rank for LoRA matrices. + lora_alpha (`float`, defaults to `64`): + The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices. + target_modules (`List[str]`, defaults to `["to_k", "to_q", "to_v", "to_out.0"]`): + The target modules for LoRA. Make sure to modify this based on the model. + gradient_accumulation_steps (`int`, defaults to `1`): + Number of gradients steps to accumulate before performing an optimizer step. + gradient_checkpointing (`bool`, defaults to `False`): + Whether or not to use gradient/activation checkpointing to save memory at the expense of slower + backward pass. + checkpointing_steps (`int`, defaults to `500`): + Save a checkpoint of the training state every X training steps. These checkpoints can be used both + as final checkpoints in case they are better than the last checkpoint, and are also suitable for + resuming training using `resume_from_checkpoint`. + checkpointing_limit (`int`, defaults to `None`): + Max number of checkpoints to store. + resume_from_checkpoint (`str`, defaults to `None`): + Whether training should be resumed from a previous checkpoint. Use a path saved by `checkpointing_steps`, + or `"latest"` to automatically select the last available checkpoint. + + OPTIMIZER ARGUMENTS + ------------------- + optimizer (`str`, defaults to `adamw`): + The optimizer type to use. Choose between ['adam', 'adamw']. + use_8bit_bnb (`bool`, defaults to `False`): + Whether to use 8bit variant of the `optimizer` using `bitsandbytes`. + lr (`float`, defaults to `1e-4`): + Initial learning rate (after the potential warmup period) to use. + scale_lr (`bool`, defaults to `False`): + Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size. + lr_scheduler (`str`, defaults to `cosine_with_restarts`): + The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', + 'constant', 'constant_with_warmup']. + lr_warmup_steps (`int`, defaults to `500`): + Number of steps for the warmup in the lr scheduler. + lr_num_cycles (`int`, defaults to `1`): + Number of hard resets of the lr in cosine_with_restarts scheduler. + lr_power (`float`, defaults to `1.0`): + Power factor of the polynomial scheduler. + beta1 (`float`, defaults to `0.9`): + beta2 (`float`, defaults to `0.95`): + beta3 (`float`, defaults to `0.999`): + weight_decay (`float`, defaults to `0.0001`): + Penalty for large weights in the model. + epsilon (`float`, defaults to `1e-8`): + Small value to avoid division by zero in the optimizer. + max_grad_norm (`float`, defaults to `1.0`): + Maximum gradient norm to clip the gradients. + + VALIDATION ARGUMENTS + -------------------- + validation_prompts (`List[str]`, defaults to `None`): + List of prompts to use for validation. If not provided, a random prompt will be selected from the training + dataset. + validation_images (`List[str]`, defaults to `None`): + List of image paths to use for validation. + validation_videos (`List[str]`, defaults to `None`): + List of video paths to use for validation. + validation_heights (`List[int]`, defaults to `None`): + List of heights for the validation videos. + validation_widths (`List[int]`, defaults to `None`): + List of widths for the validation videos. + validation_num_frames (`List[int]`, defaults to `None`): + List of number of frames for the validation videos. + num_validation_videos_per_prompt (`int`, defaults to `1`): + Number of videos to use for validation per prompt. + validation_every_n_epochs (`int`, defaults to `None`): + Perform validation every `n` training epochs. + validation_every_n_steps (`int`, defaults to `None`): + Perform validation every `n` training steps. + enable_model_cpu_offload (`bool`, defaults to `False`): + Whether or not to offload different modeling components to CPU during validation. + validation_frame_rate (`int`, defaults to `25`): + Frame rate to use for the validation videos. This value is defaulted to 25, as used in LTX Video pipeline. + + MISCELLANEOUS ARGUMENTS + ----------------------- + tracker_name (`str`, defaults to `finetrainers`): + Name of the tracker/project to use for logging training metrics. + push_to_hub (`bool`, defaults to `False`): + Whether or not to push the model to the Hugging Face Hub. + hub_token (`str`, defaults to `None`): + The API token to use for pushing the model to the Hugging Face Hub. + hub_model_id (`str`, defaults to `None`): + The model identifier to use for pushing the model to the Hugging Face Hub. + output_dir (`str`, defaults to `None`): + The directory where the model checkpoints and logs will be stored. + logging_dir (`str`, defaults to `logs`): + The directory where the logs will be stored. + allow_tf32 (`bool`, defaults to `False`): + Whether or not to allow the use of TF32 matmul on compatible hardware. + nccl_timeout (`int`, defaults to `1800`): + Timeout for the NCCL communication. + report_to (`str`, defaults to `wandb`): + The name of the logger to use for logging training metrics. Choose between ['wandb']. + """ + + # Model arguments + model_name: str = None + pretrained_model_name_or_path: str = None + revision: Optional[str] = None + variant: Optional[str] = None + cache_dir: Optional[str] = None + text_encoder_dtype: torch.dtype = torch.bfloat16 + text_encoder_2_dtype: torch.dtype = torch.bfloat16 + text_encoder_3_dtype: torch.dtype = torch.bfloat16 + transformer_dtype: torch.dtype = torch.bfloat16 + vae_dtype: torch.dtype = torch.bfloat16 + layerwise_upcasting_modules: List[str] = [] + layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn + layerwise_upcasting_skip_modules_pattern: List[str] = [ + "patch_embed", + "pos_embed", + "x_embedder", + "context_embedder", + "time_embed", + "^proj_in$", + "^proj_out$", + "norm", + ] + + # Dataset arguments + data_root: str = None + dataset_file: Optional[str] = None + video_column: str = None + caption_column: str = None + id_token: Optional[str] = None + image_resolution_buckets: List[Tuple[int, int]] = None + video_resolution_buckets: List[Tuple[int, int, int]] = None + video_reshape_mode: Optional[str] = None + caption_dropout_p: float = 0.00 + caption_dropout_technique: str = "empty" + precompute_conditions: bool = False + remove_common_llm_caption_prefixes: bool = False + + # Dataloader arguments + dataloader_num_workers: int = 0 + pin_memory: bool = False + + # Diffusion arguments + flow_resolution_shifting: bool = False + flow_base_seq_len: int = 256 + flow_max_seq_len: int = 4096 + flow_base_shift: float = 0.5 + flow_max_shift: float = 1.15 + flow_shift: float = 1.0 + flow_weighting_scheme: str = "none" + flow_logit_mean: float = 0.0 + flow_logit_std: float = 1.0 + flow_mode_scale: float = 1.29 + + # Training arguments + training_type: str = None + seed: int = 42 + batch_size: int = 1 + train_epochs: int = 1 + train_steps: int = None + rank: int = 128 + lora_alpha: float = 64 + target_modules: List[str] = ["to_k", "to_q", "to_v", "to_out.0"] + gradient_accumulation_steps: int = 1 + gradient_checkpointing: bool = False + checkpointing_steps: int = 500 + checkpointing_limit: Optional[int] = None + resume_from_checkpoint: Optional[str] = None + enable_slicing: bool = False + enable_tiling: bool = False + + # Optimizer arguments + optimizer: str = "adamw" + use_8bit_bnb: bool = False + lr: float = 1e-4 + scale_lr: bool = False + lr_scheduler: str = "cosine_with_restarts" + lr_warmup_steps: int = 0 + lr_num_cycles: int = 1 + lr_power: float = 1.0 + beta1: float = 0.9 + beta2: float = 0.95 + beta3: float = 0.999 + weight_decay: float = 0.0001 + epsilon: float = 1e-8 + max_grad_norm: float = 1.0 + + # Validation arguments + validation_prompts: List[str] = None + validation_images: List[str] = None + validation_videos: List[str] = None + validation_heights: List[int] = None + validation_widths: List[int] = None + validation_num_frames: List[int] = None + num_validation_videos_per_prompt: int = 1 + validation_every_n_epochs: Optional[int] = None + validation_every_n_steps: Optional[int] = None + enable_model_cpu_offload: bool = False + validation_frame_rate: int = 25 + + # Miscellaneous arguments + tracker_name: str = "finetrainers" + push_to_hub: bool = False + hub_token: Optional[str] = None + hub_model_id: Optional[str] = None + output_dir: str = None + logging_dir: Optional[str] = "logs" + allow_tf32: bool = False + nccl_timeout: int = 1800 # 30 minutes + report_to: str = "wandb" + + def to_dict(self) -> Dict[str, Any]: + return { + "model_arguments": { + "model_name": self.model_name, + "pretrained_model_name_or_path": self.pretrained_model_name_or_path, + "revision": self.revision, + "variant": self.variant, + "cache_dir": self.cache_dir, + "text_encoder_dtype": self.text_encoder_dtype, + "text_encoder_2_dtype": self.text_encoder_2_dtype, + "text_encoder_3_dtype": self.text_encoder_3_dtype, + "transformer_dtype": self.transformer_dtype, + "vae_dtype": self.vae_dtype, + "layerwise_upcasting_modules": self.layerwise_upcasting_modules, + "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype, + "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern, + }, + "dataset_arguments": { + "data_root": self.data_root, + "dataset_file": self.dataset_file, + "video_column": self.video_column, + "caption_column": self.caption_column, + "id_token": self.id_token, + "image_resolution_buckets": self.image_resolution_buckets, + "video_resolution_buckets": self.video_resolution_buckets, + "video_reshape_mode": self.video_reshape_mode, + "caption_dropout_p": self.caption_dropout_p, + "caption_dropout_technique": self.caption_dropout_technique, + "precompute_conditions": self.precompute_conditions, + "remove_common_llm_caption_prefixes": self.remove_common_llm_caption_prefixes, + }, + "dataloader_arguments": { + "dataloader_num_workers": self.dataloader_num_workers, + "pin_memory": self.pin_memory, + }, + "diffusion_arguments": { + "flow_resolution_shifting": self.flow_resolution_shifting, + "flow_base_seq_len": self.flow_base_seq_len, + "flow_max_seq_len": self.flow_max_seq_len, + "flow_base_shift": self.flow_base_shift, + "flow_max_shift": self.flow_max_shift, + "flow_shift": self.flow_shift, + "flow_weighting_scheme": self.flow_weighting_scheme, + "flow_logit_mean": self.flow_logit_mean, + "flow_logit_std": self.flow_logit_std, + "flow_mode_scale": self.flow_mode_scale, + }, + "training_arguments": { + "training_type": self.training_type, + "seed": self.seed, + "batch_size": self.batch_size, + "train_epochs": self.train_epochs, + "train_steps": self.train_steps, + "rank": self.rank, + "lora_alpha": self.lora_alpha, + "target_modules": self.target_modules, + "gradient_accumulation_steps": self.gradient_accumulation_steps, + "gradient_checkpointing": self.gradient_checkpointing, + "checkpointing_steps": self.checkpointing_steps, + "checkpointing_limit": self.checkpointing_limit, + "resume_from_checkpoint": self.resume_from_checkpoint, + "enable_slicing": self.enable_slicing, + "enable_tiling": self.enable_tiling, + }, + "optimizer_arguments": { + "optimizer": self.optimizer, + "use_8bit_bnb": self.use_8bit_bnb, + "lr": self.lr, + "scale_lr": self.scale_lr, + "lr_scheduler": self.lr_scheduler, + "lr_warmup_steps": self.lr_warmup_steps, + "lr_num_cycles": self.lr_num_cycles, + "lr_power": self.lr_power, + "beta1": self.beta1, + "beta2": self.beta2, + "beta3": self.beta3, + "weight_decay": self.weight_decay, + "epsilon": self.epsilon, + "max_grad_norm": self.max_grad_norm, + }, + "validation_arguments": { + "validation_prompts": self.validation_prompts, + "validation_images": self.validation_images, + "validation_videos": self.validation_videos, + "num_validation_videos_per_prompt": self.num_validation_videos_per_prompt, + "validation_every_n_epochs": self.validation_every_n_epochs, + "validation_every_n_steps": self.validation_every_n_steps, + "enable_model_cpu_offload": self.enable_model_cpu_offload, + "validation_frame_rate": self.validation_frame_rate, + }, + "miscellaneous_arguments": { + "tracker_name": self.tracker_name, + "push_to_hub": self.push_to_hub, + "hub_token": self.hub_token, + "hub_model_id": self.hub_model_id, + "output_dir": self.output_dir, + "logging_dir": self.logging_dir, + "allow_tf32": self.allow_tf32, + "nccl_timeout": self.nccl_timeout, + "report_to": self.report_to, + }, + } + + +# TODO(aryan): handle more informative messages +_IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv + + +def parse_arguments() -> Args: + parser = argparse.ArgumentParser() + + if _IS_ARGUMENTS_REQUIRED: + _add_model_arguments(parser) + _add_dataset_arguments(parser) + _add_dataloader_arguments(parser) + _add_diffusion_arguments(parser) + _add_training_arguments(parser) + _add_optimizer_arguments(parser) + _add_validation_arguments(parser) + _add_miscellaneous_arguments(parser) + + args = parser.parse_args() + return _map_to_args_type(args) + else: + _add_helper_arguments(parser) + + args = parser.parse_args() + _display_helper_messages(args) + sys.exit(0) + + +def validate_args(args: Args): + _validated_model_args(args) + _validate_training_args(args) + _validate_validation_args(args) + + +def _add_model_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--model_name", + type=str, + required=True, + choices=list(SUPPORTED_MODEL_CONFIGS.keys()), + help="Name of model to train.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.") + parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.") + parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.") + parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.") + parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.") + parser.add_argument( + "--layerwise_upcasting_modules", + type=str, + default=[], + nargs="+", + choices=["transformer"], + help="Modules that should have fp8 storage weights but higher precision computation.", + ) + parser.add_argument( + "--layerwise_upcasting_storage_dtype", + type=str, + default="float8_e4m3fn", + choices=["float8_e4m3fn", "float8_e5m2"], + help="Data type for the layerwise upcasting storage.", + ) + parser.add_argument( + "--layerwise_upcasting_skip_modules_pattern", + type=str, + default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"], + nargs="+", + help="Modules to skip for layerwise upcasting.", + ) + + +def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None: + def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]: + return tuple(map(int, resolution_bucket.split("x"))) + + def parse_image_resolution_bucket(resolution_bucket: str) -> Tuple[int, int]: + resolution_bucket = parse_resolution_bucket(resolution_bucket) + assert ( + len(resolution_bucket) == 2 + ), f"Expected 2D resolution bucket, got {len(resolution_bucket)}D resolution bucket" + return resolution_bucket + + def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int]: + resolution_bucket = parse_resolution_bucket(resolution_bucket) + assert ( + len(resolution_bucket) == 3 + ), f"Expected 3D resolution bucket, got {len(resolution_bucket)}D resolution bucket" + return resolution_bucket + + parser.add_argument( + "--data_root", + type=str, + required=True, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--dataset_file", + type=str, + default=None, + help=("Path to a CSV file if loading prompts/video paths using this format."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", + type=str, + default=None, + help="Identifier token appended to the start of each prompt if provided.", + ) + parser.add_argument( + "--image_resolution_buckets", + type=parse_image_resolution_bucket, + default=None, + nargs="+", + help="Resolution buckets for images.", + ) + parser.add_argument( + "--video_resolution_buckets", + type=parse_video_resolution_bucket, + default=None, + nargs="+", + help="Resolution buckets for videos.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default=None, + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) + parser.add_argument( + "--caption_dropout_p", + type=float, + default=0.00, + help="Probability of dropout for the caption tokens.", + ) + parser.add_argument( + "--caption_dropout_technique", + type=str, + default="empty", + choices=["empty", "zero"], + help="Technique to use for caption dropout.", + ) + parser.add_argument( + "--precompute_conditions", + action="store_true", + help="Whether or not to precompute the conditionings for the model.", + ) + parser.add_argument( + "--remove_common_llm_caption_prefixes", + action="store_true", + help="Whether or not to remove common LLM caption prefixes.", + ) + + +def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + + +def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--flow_resolution_shifting", + action="store_true", + help="Resolution-dependent shifting of timestep schedules.", + ) + parser.add_argument( + "--flow_base_seq_len", + type=int, + default=256, + help="Base image/video sequence length for the diffusion model.", + ) + parser.add_argument( + "--flow_max_seq_len", + type=int, + default=4096, + help="Maximum image/video sequence length for the diffusion model.", + ) + parser.add_argument( + "--flow_base_shift", + type=float, + default=0.5, + help="Base shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)", + ) + parser.add_argument( + "--flow_max_shift", + type=float, + default=1.15, + help="Maximum shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)", + ) + parser.add_argument( + "--flow_shift", + type=float, + default=1.0, + help="Shift value to use for the flow matching timestep schedule.", + ) + parser.add_argument( + "--flow_weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help='We default to the "none" weighting scheme for uniform sampling and uniform loss', + ) + parser.add_argument( + "--flow_logit_mean", + type=float, + default=0.0, + help="Mean to use when using the `'logit_normal'` weighting scheme.", + ) + parser.add_argument( + "--flow_logit_std", + type=float, + default=1.0, + help="Standard deviation to use when using the `'logit_normal'` weighting scheme.", + ) + parser.add_argument( + "--flow_mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + +def _add_training_arguments(parser: argparse.ArgumentParser) -> None: + # TODO: support full finetuning and other kinds + parser.add_argument( + "--training_type", + type=str, + choices=["lora", "full-finetune"], + required=True, + help="Type of training to perform. Choose between ['lora', 'full-finetune']", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--train_epochs", type=int, default=1, help="Number of training epochs.") + parser.add_argument( + "--train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.") + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", + ) + parser.add_argument( + "--target_modules", + type=str, + default=["to_k", "to_q", "to_v", "to_out.0"], + nargs="+", + help="The target modules for LoRA.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpointing_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_slicing", + action="store_true", + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + help="Whether or not to use VAE tiling for saving memory.", + ) + + +def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--lr", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit_bnb", + action="store_true", + help=("Whether to use 8bit variant of the `--optimizer` using `bitsandbytes`."), + ) + parser.add_argument( + "--beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta2", + type=float, + default=0.95, + help="The beta2 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for optimizer.", + ) + parser.add_argument( + "--epsilon", + type=float, + default=1e-8, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + + +def _add_validation_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_videos", + type=str, + default=None, + help="One or more video path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=None, + help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=None, + help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--validation_frame_rate", + type=int, + default=25, + help="Frame rate to use for the validation videos.", + ) + parser.add_argument( + "--enable_model_cpu_offload", + action="store_true", + help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", + ) + + +def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--tracker_name", type=str, default="finetrainers", help="Project tracker name") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="finetrainers-training", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--nccl_timeout", + type=int, + default=600, + help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.", + ) + parser.add_argument( + "--report_to", + type=str, + default="none", + choices=["none", "wandb"], + help="The integration to report the results and logs to.", + ) + + +def _add_helper_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--list_models", + action="store_true", + help="List all the supported models.", + ) + + +_DTYPE_MAP = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "float8_e4m3fn": torch.float8_e4m3fn, + "float8_e5m2": torch.float8_e5m2, +} + + +def _map_to_args_type(args: Dict[str, Any]) -> Args: + result_args = Args() + + # Model arguments + result_args.model_name = args.model_name + result_args.pretrained_model_name_or_path = args.pretrained_model_name_or_path + result_args.revision = args.revision + result_args.variant = args.variant + result_args.cache_dir = args.cache_dir + result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype] + result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype] + result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype] + result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype] + result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype] + result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules + result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype] + result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern + + # Dataset arguments + if args.data_root is None and args.dataset_file is None: + raise ValueError("At least one of `data_root` or `dataset_file` should be provided.") + + result_args.data_root = args.data_root + result_args.dataset_file = args.dataset_file + result_args.video_column = args.video_column + result_args.caption_column = args.caption_column + result_args.id_token = args.id_token + result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS + result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS + result_args.video_reshape_mode = args.video_reshape_mode + result_args.caption_dropout_p = args.caption_dropout_p + result_args.caption_dropout_technique = args.caption_dropout_technique + result_args.precompute_conditions = args.precompute_conditions + result_args.remove_common_llm_caption_prefixes = args.remove_common_llm_caption_prefixes + + # Dataloader arguments + result_args.dataloader_num_workers = args.dataloader_num_workers + result_args.pin_memory = args.pin_memory + + # Diffusion arguments + result_args.flow_resolution_shifting = args.flow_resolution_shifting + result_args.flow_base_seq_len = args.flow_base_seq_len + result_args.flow_max_seq_len = args.flow_max_seq_len + result_args.flow_base_shift = args.flow_base_shift + result_args.flow_max_shift = args.flow_max_shift + result_args.flow_shift = args.flow_shift + result_args.flow_weighting_scheme = args.flow_weighting_scheme + result_args.flow_logit_mean = args.flow_logit_mean + result_args.flow_logit_std = args.flow_logit_std + result_args.flow_mode_scale = args.flow_mode_scale + + # Training arguments + result_args.training_type = args.training_type + result_args.seed = args.seed + result_args.batch_size = args.batch_size + result_args.train_epochs = args.train_epochs + result_args.train_steps = args.train_steps + result_args.rank = args.rank + result_args.lora_alpha = args.lora_alpha + result_args.target_modules = args.target_modules + result_args.gradient_accumulation_steps = args.gradient_accumulation_steps + result_args.gradient_checkpointing = args.gradient_checkpointing + result_args.checkpointing_steps = args.checkpointing_steps + result_args.checkpointing_limit = args.checkpointing_limit + result_args.resume_from_checkpoint = args.resume_from_checkpoint + result_args.enable_slicing = args.enable_slicing + result_args.enable_tiling = args.enable_tiling + + # Optimizer arguments + result_args.optimizer = args.optimizer or "adamw" + result_args.use_8bit_bnb = args.use_8bit_bnb + result_args.lr = args.lr or 1e-4 + result_args.scale_lr = args.scale_lr + result_args.lr_scheduler = args.lr_scheduler + result_args.lr_warmup_steps = args.lr_warmup_steps + result_args.lr_num_cycles = args.lr_num_cycles + result_args.lr_power = args.lr_power + result_args.beta1 = args.beta1 + result_args.beta2 = args.beta2 + result_args.beta3 = args.beta3 + result_args.weight_decay = args.weight_decay + result_args.epsilon = args.epsilon + result_args.max_grad_norm = args.max_grad_norm + + # Validation arguments + validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else [] + validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None + validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None + stripped_validation_prompts = [] + validation_heights = [] + validation_widths = [] + validation_num_frames = [] + for prompt in validation_prompts: + prompt: str + prompt = prompt.strip() + actual_prompt, separator, resolution = prompt.rpartition("@@@") + stripped_validation_prompts.append(actual_prompt) + num_frames, height, width = None, None, None + if len(resolution) > 0: + num_frames, height, width = map(int, resolution.split("x")) + validation_num_frames.append(num_frames) + validation_heights.append(height) + validation_widths.append(width) + + if validation_images is None: + validation_images = [None] * len(validation_prompts) + if validation_videos is None: + validation_videos = [None] * len(validation_prompts) + + result_args.validation_prompts = stripped_validation_prompts + result_args.validation_heights = validation_heights + result_args.validation_widths = validation_widths + result_args.validation_num_frames = validation_num_frames + result_args.validation_images = validation_images + result_args.validation_videos = validation_videos + + result_args.num_validation_videos_per_prompt = args.num_validation_videos + result_args.validation_every_n_epochs = args.validation_epochs + result_args.validation_every_n_steps = args.validation_steps + result_args.enable_model_cpu_offload = args.enable_model_cpu_offload + result_args.validation_frame_rate = args.validation_frame_rate + + # Miscellaneous arguments + result_args.tracker_name = args.tracker_name + result_args.push_to_hub = args.push_to_hub + result_args.hub_token = args.hub_token + result_args.hub_model_id = args.hub_model_id + result_args.output_dir = args.output_dir + result_args.logging_dir = args.logging_dir + result_args.allow_tf32 = args.allow_tf32 + result_args.nccl_timeout = args.nccl_timeout + result_args.report_to = args.report_to + + return result_args + + +def _validated_model_args(args: Args): + if args.training_type == "full-finetune": + assert ( + "transformer" not in args.layerwise_upcasting_modules + ), "Layerwise upcasting is not supported for full-finetune training" + + +def _validate_training_args(args: Args): + if args.training_type == "lora": + assert args.rank is not None, "Rank is required for LoRA training" + assert args.lora_alpha is not None, "LoRA alpha is required for LoRA training" + assert ( + args.target_modules is not None and len(args.target_modules) > 0 + ), "Target modules are required for LoRA training" + + +def _validate_validation_args(args: Args): + assert args.validation_prompts is not None, "Validation prompts are required for validation" + if args.validation_images is not None: + assert len(args.validation_images) == len( + args.validation_prompts + ), "Validation images and prompts should be of same length" + if args.validation_videos is not None: + assert len(args.validation_videos) == len( + args.validation_prompts + ), "Validation videos and prompts should be of same length" + assert len(args.validation_prompts) == len( + args.validation_heights + ), "Validation prompts and heights should be of same length" + assert len(args.validation_prompts) == len( + args.validation_widths + ), "Validation prompts and widths should be of same length" + + +def _display_helper_messages(args: argparse.Namespace): + if args.list_models: + print("Supported models:") + for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()): + print(f" {index + 1}. {model_name}") diff --git a/finetrainers/constants.py b/finetrainers/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f6318f4cc43f96db9adf94b22ae92684f976f6fe --- /dev/null +++ b/finetrainers/constants.py @@ -0,0 +1,80 @@ +import os + + +DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +DEFAULT_FRAME_BUCKETS = [49] + +DEFAULT_IMAGE_RESOLUTION_BUCKETS = [] +for height in DEFAULT_HEIGHT_BUCKETS: + for width in DEFAULT_WIDTH_BUCKETS: + DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width)) + +DEFAULT_VIDEO_RESOLUTION_BUCKETS = [] +for frames in DEFAULT_FRAME_BUCKETS: + for height in DEFAULT_HEIGHT_BUCKETS: + for width in DEFAULT_WIDTH_BUCKETS: + DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width)) + + +FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO") + +PRECOMPUTED_DIR_NAME = "precomputed" +PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions" +PRECOMPUTED_LATENTS_DIR_NAME = "latents" + +MODEL_DESCRIPTION = r""" +\# {model_id} {training_type} finetune + + + +\#\# Model Description + +This model is a {training_type} of the `{model_id}` model. + +This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers). + +\#\# Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +\#\# Usage + +Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed. + +```python +{model_example} +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +\#\# License + +Please adhere to the license of the base model. +""".strip() + +_COMMON_BEGINNING_PHRASES = ( + "This video", + "The video", + "This clip", + "The clip", + "The animation", + "This image", + "The image", + "This picture", + "The picture", +) +_COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents") + +COMMON_LLM_START_PHRASES = ( + "In the video,", + "In this video,", + "In this video clip,", + "In the clip,", + "Caption:", + *( + f"{beginning} {continuation}" + for beginning in _COMMON_BEGINNING_PHRASES + for continuation in _COMMON_CONTINUATION_WORDS + ), +) diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cb364bd31113a695d8fa2714db9112ec4489d288 --- /dev/null +++ b/finetrainers/dataset.py @@ -0,0 +1,467 @@ +import json +import os +import random +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torchvision.transforms as TT +import torchvision.transforms.functional as TTF +from accelerate.logging import get_logger +from torch.utils.data import Dataset, Sampler +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import resize + + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +from .constants import ( # noqa + COMMON_LLM_START_PHRASES, + PRECOMPUTED_CONDITIONS_DIR_NAME, + PRECOMPUTED_DIR_NAME, + PRECOMPUTED_LATENTS_DIR_NAME, +) + + +logger = get_logger(__name__) + + +# TODO(aryan): This needs a refactor with separation of concerns. +# Images should be handled separately. Videos should be handled separately. +# Loading should be handled separately. +# Preprocessing (aspect ratio, resizing) should be handled separately. +# URL loading should be handled. +# Parquet format should be handled. +# Loading from ZIP should be handled. +class ImageOrVideoDataset(Dataset): + def __init__( + self, + data_root: str, + caption_column: str, + video_column: str, + resolution_buckets: List[Tuple[int, int, int]], + dataset_file: Optional[str] = None, + id_token: Optional[str] = None, + remove_llm_prefixes: bool = False, + ) -> None: + super().__init__() + + self.data_root = Path(data_root) + self.dataset_file = dataset_file + self.caption_column = caption_column + self.video_column = video_column + self.id_token = f"{id_token.strip()} " if id_token else "" + self.resolution_buckets = resolution_buckets + + # Four methods of loading data are supported. + # - Using a CSV: caption_column and video_column must be some column in the CSV. One could + # make use of other columns too, such as a motion score or aesthetic score, by modifying the + # logic in CSV processing. + # - Using two files containing line-separate captions and relative paths to videos. + # - Using a JSON file containing a list of dictionaries, where each dictionary has a `caption_column` and `video_column` key. + # - Using a JSONL file containing a list of line-separated dictionaries, where each dictionary has a `caption_column` and `video_column` key. + # For a more detailed explanation about preparing dataset format, checkout the README. + if dataset_file is None: + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_local_path() + elif dataset_file.endswith(".csv"): + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_csv() + elif dataset_file.endswith(".json"): + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_json() + elif dataset_file.endswith(".jsonl"): + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_jsonl() + else: + raise ValueError( + "Expected `--dataset_file` to be a path to a CSV file or a directory containing line-separated text prompts and video paths." + ) + + if len(self.video_paths) != len(self.prompts): + raise ValueError( + f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + # Clean LLM start phrases + if remove_llm_prefixes: + for i in range(len(self.prompts)): + self.prompts[i] = self.prompts[i].strip() + for phrase in COMMON_LLM_START_PHRASES: + if self.prompts[i].startswith(phrase): + self.prompts[i] = self.prompts[i].removeprefix(phrase).strip() + + self.video_transforms = transforms.Compose( + [ + transforms.Lambda(self.scale_transform), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + @staticmethod + def scale_transform(x): + return x / 255.0 + + def __len__(self) -> int: + return len(self.video_paths) + + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + # Here, index is actually a list of data objects that we need to return. + # The BucketSampler should ideally return indices. But, in the sampler, we'd like + # to have information about num_frames, height and width. Since this is not stored + # as metadata, we need to read the video to get this information. You could read this + # information without loading the full video in memory, but we do it anyway. In order + # to not load the video twice (once to get the metadata, and once to return the loaded video + # based on sampled indices), we cache it in the BucketSampler. When the sampler is + # to yield, we yield the cache data instead of indices. So, this special check ensures + # that data is not loaded a second time. PRs are welcome for improvements. + return index + + prompt = self.id_token + self.prompts[index] + + video_path: Path = self.video_paths[index] + if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: + video = self._preprocess_image(video_path) + else: + video = self._preprocess_video(video_path) + + return { + "prompt": prompt, + "video": video, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } + + def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: + if not self.data_root.exists(): + raise ValueError("Root folder for videos does not exist") + + prompt_path = self.data_root.joinpath(self.caption_column) + video_path = self.data_root.joinpath(self.video_column) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: + df = pd.read_csv(self.dataset_file) + prompts = df[self.caption_column].tolist() + video_paths = df[self.video_column].tolist() + video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _load_dataset_from_json(self) -> Tuple[List[str], List[str]]: + with open(self.dataset_file, "r", encoding="utf-8") as file: + data = json.load(file) + + prompts = [entry[self.caption_column] for entry in data] + video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _load_dataset_from_jsonl(self) -> Tuple[List[str], List[str]]: + with open(self.dataset_file, "r", encoding="utf-8") as file: + data = [json.loads(line) for line in file] + + prompts = [entry[self.caption_column] for entry in data] + video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _preprocess_image(self, path: Path) -> torch.Tensor: + # TODO(aryan): Support alpha channel in future by whitening background + image = TTF.Image.open(path.as_posix()).convert("RGB") + image = TTF.to_tensor(image) + image = image * 2.0 - 1.0 + image = image.unsqueeze(0).contiguous() # [C, H, W] -> [1, C, H, W] (1-frame video) + return image + + def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + r""" + Loads a single video, or latent and prompt embedding, based on initialization parameters. + + Returns a [F, C, H, W] video tensor. + """ + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + + indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames)) + frames = video_reader.get_batch(indices) + frames = frames[: self.max_num_frames].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) + return frames + + +class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0] + + def _preprocess_image(self, path: Path) -> torch.Tensor: + # TODO(aryan): Support alpha channel in future by whitening background + image = TTF.Image.open(path.as_posix()).convert("RGB") + image = TTF.to_tensor(image) + + nearest_res = self._find_nearest_resolution(image.shape[1], image.shape[2]) + image = resize(image, nearest_res) + + image = image * 2.0 - 1.0 + image = image.unsqueeze(0).contiguous() + return image + + def _preprocess_video(self, path: Path) -> torch.Tensor: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets) + print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames) + print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames) + + video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames] + + if not video_buckets: + _, h, w = self.resolution_buckets[0] + video_buckets = [(1, h, w)] + + nearest_frame_bucket = min( + video_buckets, + key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)), + default=video_buckets[0], + )[0] + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + return frames + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset): + def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.video_reshape_mode = video_reshape_mode + self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0] + + def _resize_for_rectangle_crop(self, arr, image_size): + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + def _preprocess_video(self, path: Path) -> torch.Tensor: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.resolution_buckets = ", self.resolution_buckets) + print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.max_num_frames = ", self.max_num_frames) + print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: video_num_frames = ", video_num_frames) + + video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames] + + if not video_buckets: + _, h, w = self.resolution_buckets[0] + video_buckets = [(1, h, w)] + + nearest_frame_bucket = min( + video_buckets, + key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)), + default=video_buckets[0], + )[0] + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + return frames + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class PrecomputedDataset(Dataset): + def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None: + super().__init__() + + self.data_root = Path(data_root) + + if model_name and cleaned_model_id: + precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}" + self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME + self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME + else: + self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME + self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME + + self.latent_conditions = sorted(os.listdir(self.latents_path)) + self.text_conditions = sorted(os.listdir(self.conditions_path)) + + assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match" + + def __len__(self) -> int: + return len(self.latent_conditions) + + def __getitem__(self, index: int) -> Dict[str, Any]: + conditions = {} + latent_path = self.latents_path / self.latent_conditions[index] + condition_path = self.conditions_path / self.text_conditions[index] + conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True) + conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True) + return conditions + + +class BucketSampler(Sampler): + r""" + PyTorch Sampler that groups 3D data by height, width and frames. + + Args: + data_source (`ImageOrVideoDataset`): + A PyTorch dataset object that is an instance of `ImageOrVideoDataset`. + batch_size (`int`, defaults to `8`): + The batch size to use for training. + shuffle (`bool`, defaults to `True`): + Whether or not to shuffle the data in each batch before dispatching to dataloader. + drop_last (`bool`, defaults to `False`): + Whether or not to drop incomplete buckets of data after completely iterating over all data + in the dataset. If set to True, only batches that have `batch_size` number of entries will + be yielded. If set to False, it is guaranteed that all data in the dataset will be processed + and batches that do not have `batch_size` number of entries will also be yielded. + """ + + def __init__( + self, data_source: ImageOrVideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False + ) -> None: + self.data_source = data_source + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + + self.buckets = {resolution: [] for resolution in data_source.resolution_buckets} + + self._raised_warning_for_drop_last = False + + def __len__(self): + if self.drop_last and not self._raised_warning_for_drop_last: + self._raised_warning_for_drop_last = True + logger.warning( + "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." + ) + return (len(self.data_source) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + for index, data in enumerate(self.data_source): + video_metadata = data["video_metadata"] + f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + + self.buckets[(f, h, w)].append(data) + if len(self.buckets[(f, h, w)]) == self.batch_size: + if self.shuffle: + random.shuffle(self.buckets[(f, h, w)]) + yield self.buckets[(f, h, w)] + del self.buckets[(f, h, w)] + self.buckets[(f, h, w)] = [] + + if self.drop_last: + return + + for fhw, bucket in list(self.buckets.items()): + if len(bucket) == 0: + continue + if self.shuffle: + random.shuffle(bucket) + yield bucket + del self.buckets[fhw] + self.buckets[fhw] = [] diff --git a/finetrainers/hooks/__init__.py b/finetrainers/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c3a432f4021ec9b2666b48047c6fd40f3849b9 --- /dev/null +++ b/finetrainers/hooks/__init__.py @@ -0,0 +1 @@ +from .layerwise_upcasting import apply_layerwise_upcasting diff --git a/finetrainers/hooks/hooks.py b/finetrainers/hooks/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..e779795279e2302de286096563538d2beb818bac --- /dev/null +++ b/finetrainers/hooks/hooks.py @@ -0,0 +1,176 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Dict, Optional, Tuple + +import torch +from accelerate.logging import get_logger + +from ..constants import FINETRAINERS_LOG_LEVEL + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. + """ + + _is_stateful = False + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is deinitalized. + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + module.forward = module._old_forward + del module._old_forward + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + + +class HookRegistry: + def __init__(self, module_ref: torch.nn.Module) -> None: + super().__init__() + + self.hooks: Dict[str, ModelHook] = {} + + self._module_ref = module_ref + self._hook_order = [] + + def register_hook(self, hook: ModelHook, name: str) -> None: + if name in self.hooks.keys(): + logger.warning(f"Hook with name {name} already exists, replacing it.") + + if hasattr(self._module_ref, "_old_forward"): + old_forward = self._module_ref._old_forward + else: + old_forward = self._module_ref.forward + self._module_ref._old_forward = self._module_ref.forward + + self._module_ref = hook.initialize_hook(self._module_ref) + + if hasattr(hook, "new_forward"): + rewritten_forward = hook.new_forward + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = rewritten_forward(module, *args, **kwargs) + return hook.post_forward(module, output) + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = old_forward(*args, **kwargs) + return hook.post_forward(module, output) + + self._module_ref.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) + + self.hooks[name] = hook + self._hook_order.append(name) + + def get_hook(self, name: str) -> Optional[ModelHook]: + if name not in self.hooks.keys(): + return None + return self.hooks[name] + + def remove_hook(self, name: str) -> None: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + self.hooks[name].deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + def reset_stateful_hooks(self, recurse: bool = True) -> None: + for hook_name in self._hook_order: + hook = self.hooks[hook_name] + if hook._is_stateful: + hook.reset_state(self._module_ref) + + if recurse: + for module in self._module_ref.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.reset_stateful_hooks(recurse=False) + + @classmethod + def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": + if not hasattr(module, "_diffusers_hook"): + module._diffusers_hook = cls(module) + return module._diffusers_hook + + def __repr__(self) -> str: + hook_repr = "" + for i, hook_name in enumerate(self._hook_order): + hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if i < len(self._hook_order) - 1: + hook_repr += "\n" + return f"HookRegistry(\n{hook_repr}\n)" diff --git a/finetrainers/hooks/layerwise_upcasting.py b/finetrainers/hooks/layerwise_upcasting.py new file mode 100644 index 0000000000000000000000000000000000000000..b7bdc38021c5145a2a6ac515270dc356385d03a7 --- /dev/null +++ b/finetrainers/hooks/layerwise_upcasting.py @@ -0,0 +1,140 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Optional, Tuple, Type + +import torch +from accelerate.logging import get_logger + +from ..constants import FINETRAINERS_LOG_LEVEL +from .hooks import HookRegistry, ModelHook + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +# fmt: off +_SUPPORTED_PYTORCH_LAYERS = ( + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, +) + +_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm") +# fmt: on + + +class LayerwiseUpcastingHook(ModelHook): + r""" + A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype + for storage. This process may lead to quality loss in the output, but can significantly reduce the memory + footprint. + """ + + _is_stateful = False + + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None: + self.storage_dtype = storage_dtype + self.compute_dtype = compute_dtype + self.non_blocking = non_blocking + + def initialize_hook(self, module: torch.nn.Module): + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking) + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output): + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) + return output + + +def apply_layerwise_upcasting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN, + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None, + non_blocking: bool = False, + _prefix: str = "", +) -> None: + r""" + Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any + nn.Module using diffusers layers or pytorch primitives. + Args: + module (`torch.nn.Module`): + The module whose leaf modules will be cast to a high precision dtype for computation, and to a low + precision dtype for storage. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before/after the forward pass for storage. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass for computation. + skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): + A list of patterns to match the names of the modules to skip during the layerwise upcasting process. + skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`): + A list of module classes to skip during the layerwise upcasting process. + non_blocking (`bool`, defaults to `False`): + If `True`, the weight casting operations are non-blocking. + """ + if skip_modules_classes is None and skip_modules_pattern is None: + apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or ( + skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern) + ) + if should_skip: + logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"') + return + + if isinstance(module, _SUPPORTED_PYTORCH_LAYERS): + logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"') + apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + for name, submodule in module.named_children(): + layer_name = f"{_prefix}.{name}" if _prefix else name + apply_layerwise_upcasting( + submodule, + storage_dtype, + compute_dtype, + skip_modules_pattern, + skip_modules_classes, + non_blocking, + _prefix=layer_name, + ) + + +def apply_layerwise_upcasting_hook( + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool +) -> None: + r""" + Applies a `LayerwiseUpcastingHook` to a given module. + Args: + module (`torch.nn.Module`): + The module to attach the hook to. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before the forward pass. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass. + non_blocking (`bool`): + If `True`, the weight casting operations are non-blocking. + """ + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking) + registry.register_hook(hook, "layerwise_upcasting") diff --git a/finetrainers/models/__init__.py b/finetrainers/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c24ab951d8b3cd8e52dfa0b3647c7d8183c1352c --- /dev/null +++ b/finetrainers/models/__init__.py @@ -0,0 +1,33 @@ +from typing import Any, Dict + +from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG +from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG +from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG + + +SUPPORTED_MODEL_CONFIGS = { + "hunyuan_video": { + "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG, + "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, + }, + "ltx_video": { + "lora": LTX_VIDEO_T2V_LORA_CONFIG, + "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, + }, + "cogvideox": { + "lora": COGVIDEOX_T2V_LORA_CONFIG, + "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, + }, +} + + +def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]: + if model_name not in SUPPORTED_MODEL_CONFIGS: + raise ValueError( + f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" + ) + if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]: + raise ValueError( + f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" + ) + return SUPPORTED_MODEL_CONFIGS[model_name][training_type] diff --git a/finetrainers/models/cogvideox/__init__.py b/finetrainers/models/cogvideox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a72064347e083b0d277437e6a4e6e2e54164277 --- /dev/null +++ b/finetrainers/models/cogvideox/__init__.py @@ -0,0 +1,2 @@ +from .full_finetune import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG +from .lora import COGVIDEOX_T2V_LORA_CONFIG diff --git a/finetrainers/models/cogvideox/full_finetune.py b/finetrainers/models/cogvideox/full_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f2b4bbcff806c7c12b736c36bb9733b9980353 --- /dev/null +++ b/finetrainers/models/cogvideox/full_finetune.py @@ -0,0 +1,32 @@ +from diffusers import CogVideoXPipeline + +from .lora import ( + calculate_noisy_latents, + collate_fn_t2v, + forward_pass, + initialize_pipeline, + load_condition_models, + load_diffusion_models, + load_latent_models, + post_latent_preparation, + prepare_conditions, + prepare_latents, + validation, +) + + +# TODO(aryan): refactor into model specs for better re-use +COGVIDEOX_T2V_FULL_FINETUNE_CONFIG = { + "pipeline_cls": CogVideoXPipeline, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, + "initialize_pipeline": initialize_pipeline, + "prepare_conditions": prepare_conditions, + "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, + "collate_fn": collate_fn_t2v, + "calculate_noisy_latents": calculate_noisy_latents, + "forward_pass": forward_pass, + "validation": validation, +} diff --git a/finetrainers/models/cogvideox/lora.py b/finetrainers/models/cogvideox/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..65d86ee901d73296c94c2abe20f21293cace45b3 --- /dev/null +++ b/finetrainers/models/cogvideox/lora.py @@ -0,0 +1,334 @@ +from typing import Any, Dict, List, Optional, Union + +import torch +from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + +from .utils import prepare_rotary_positional_embeddings + + +def load_condition_models( + model_id: str = "THUDM/CogVideoX-5b", + text_encoder_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +): + tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) + text_encoder = T5EncoderModel.from_pretrained( + model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir + ) + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + +def load_latent_models( + model_id: str = "THUDM/CogVideoX-5b", + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +): + vae = AutoencoderKLCogVideoX.from_pretrained( + model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir + ) + return {"vae": vae} + + +def load_diffusion_models( + model_id: str = "THUDM/CogVideoX-5b", + transformer_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +): + transformer = CogVideoXTransformer3DModel.from_pretrained( + model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir + ) + scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler") + return {"transformer": transformer, "scheduler": scheduler} + + +def initialize_pipeline( + model_id: str = "THUDM/CogVideoX-5b", + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + tokenizer: Optional[T5Tokenizer] = None, + text_encoder: Optional[T5EncoderModel] = None, + transformer: Optional[CogVideoXTransformer3DModel] = None, + vae: Optional[AutoencoderKLCogVideoX] = None, + scheduler: Optional[CogVideoXDDIMScheduler] = None, + device: Optional[torch.device] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + is_training: bool = False, + **kwargs, +) -> CogVideoXPipeline: + component_name_pairs = [ + ("tokenizer", tokenizer), + ("text_encoder", text_encoder), + ("transformer", transformer), + ("vae", vae), + ("scheduler", scheduler), + ] + components = {} + for name, component in component_name_pairs: + if component is not None: + components[name] = component + + pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) + pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) + pipe.vae = pipe.vae.to(dtype=vae_dtype) + + # The transformer should already be in the correct dtype when training, so we don't need to cast it here. + # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during + # DDP optimizer step. + if not is_training: + pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) + + if enable_slicing: + pipe.vae.enable_slicing() + if enable_tiling: + pipe.vae.enable_tiling() + + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload(device=device) + else: + pipe.to(device=device) + + return pipe + + +def prepare_conditions( + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 226, # TODO: this should be configurable + **kwargs, +): + device = device or text_encoder.device + dtype = dtype or text_encoder.dtype + return _get_t5_prompt_embeds( + tokenizer=tokenizer, + text_encoder=text_encoder, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + +def prepare_latents( + vae: AutoencoderKLCogVideoX, + image_or_video: torch.Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + generator: Optional[torch.Generator] = None, + precompute: bool = False, + **kwargs, +) -> torch.Tensor: + device = device or vae.device + dtype = dtype or vae.dtype + + if image_or_video.ndim == 4: + image_or_video = image_or_video.unsqueeze(2) + assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" + + image_or_video = image_or_video.to(device=device, dtype=vae.dtype) + image_or_video = image_or_video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + if not precompute: + latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) + if not vae.config.invert_scale_latents: + latents = latents * vae.config.scaling_factor + # For training Cog 1.5, we don't need to handle the scaling factor here. + # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents + # is probably only needed for image-to-video training. + # TODO(aryan): investigate this + # else: + # latents = 1 / vae.config.scaling_factor * latents + latents = latents.to(dtype=dtype) + return {"latents": latents} + else: + # handle vae scaling in the `train()` method directly. + if vae.use_slicing and image_or_video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] + h = torch.cat(encoded_slices) + else: + h = vae._encode(image_or_video) + return {"latents": h} + + +def post_latent_preparation( + vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs +) -> torch.Tensor: + if not vae_config.invert_scale_latents: + latents = latents * vae_config.scaling_factor + # For training Cog 1.5, we don't need to handle the scaling factor here. + # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents + # is probably only needed for image-to-video training. + # TODO(aryan): investigate this + # else: + # latents = 1 / vae_config.scaling_factor * latents + latents = _pad_frames(latents, patch_size_t) + latents = latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + return {"latents": latents} + + +def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + return { + "prompts": [x["prompt"] for x in batch[0]], + "videos": torch.stack([x["video"] for x in batch[0]]), + } + + +def calculate_noisy_latents( + scheduler: CogVideoXDDIMScheduler, + noise: torch.Tensor, + latents: torch.Tensor, + timesteps: torch.LongTensor, +) -> torch.Tensor: + noisy_latents = scheduler.add_noise(latents, noise, timesteps) + return noisy_latents + + +def forward_pass( + transformer: CogVideoXTransformer3DModel, + scheduler: CogVideoXDDIMScheduler, + prompt_embeds: torch.Tensor, + latents: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.LongTensor, + ofs_emb: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself. + VAE_SPATIAL_SCALE_FACTOR = 8 + transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + batch_size, num_frames, num_channels, height, width = noisy_latents.shape + rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR + rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR + + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SPATIAL_SCALE_FACTOR, + width=width * VAE_SPATIAL_SCALE_FACTOR, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR, + patch_size=transformer_config.patch_size, + patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None, + attention_head_dim=transformer_config.attention_head_dim, + device=transformer.device, + base_height=rope_base_height, + base_width=rope_base_width, + ) + if transformer_config.use_rotary_positional_embeddings + else None + ) + ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0) + + velocity = transformer( + hidden_states=noisy_latents, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same + # code paths as scheduler.get_velocity(), which can be confusing to understand. + denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps) + + return {"latents": denoised_latents} + + +def validation( + pipeline: CogVideoXPipeline, + prompt: str, + image: Optional[Image.Image] = None, + video: Optional[List[Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_videos_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + **kwargs, +): + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "num_videos_per_prompt": num_videos_per_prompt, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} + output = pipeline(**generation_kwargs).frames[0] + return [("video", output)] + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return {"prompt_embeds": prompt_embeds} + + +def _pad_frames(latents: torch.Tensor, patch_size_t: int): + if patch_size_t is None or patch_size_t == 1: + return latents + + # `latents` should be of the following format: [B, C, F, H, W]. + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + latent_num_frames = latents.shape[2] + additional_frames = patch_size_t - latent_num_frames % patch_size_t + + if additional_frames > 0: + last_frame = latents[:, :, -1:, :, :] + padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1) + latents = torch.cat([latents, padding_frames], dim=2) + + return latents + + +# TODO(aryan): refactor into model specs for better re-use +COGVIDEOX_T2V_LORA_CONFIG = { + "pipeline_cls": CogVideoXPipeline, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, + "initialize_pipeline": initialize_pipeline, + "prepare_conditions": prepare_conditions, + "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, + "collate_fn": collate_fn_t2v, + "calculate_noisy_latents": calculate_noisy_latents, + "forward_pass": forward_pass, + "validation": validation, +} diff --git a/finetrainers/models/cogvideox/utils.py b/finetrainers/models/cogvideox/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bd98c1f3653dbe23a6f53fa54dfe3e7073ea9b99 --- /dev/null +++ b/finetrainers/models/cogvideox/utils.py @@ -0,0 +1,51 @@ +from typing import Optional, Tuple + +import torch +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + patch_size_t: int = None, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + if patch_size_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin diff --git a/finetrainers/models/hunyuan_video/__init__.py b/finetrainers/models/hunyuan_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac729e91bb0d8af781ea51e856a43bfff1990df --- /dev/null +++ b/finetrainers/models/hunyuan_video/__init__.py @@ -0,0 +1,2 @@ +from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG +from .lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG diff --git a/finetrainers/models/hunyuan_video/full_finetune.py b/finetrainers/models/hunyuan_video/full_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..65e73f5451cacbc4ea4540b6bfcd3c3bd1b9e531 --- /dev/null +++ b/finetrainers/models/hunyuan_video/full_finetune.py @@ -0,0 +1,30 @@ +from diffusers import HunyuanVideoPipeline + +from .lora import ( + collate_fn_t2v, + forward_pass, + initialize_pipeline, + load_condition_models, + load_diffusion_models, + load_latent_models, + post_latent_preparation, + prepare_conditions, + prepare_latents, + validation, +) + + +# TODO(aryan): refactor into model specs for better re-use +HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG = { + "pipeline_cls": HunyuanVideoPipeline, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, + "initialize_pipeline": initialize_pipeline, + "prepare_conditions": prepare_conditions, + "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, + "collate_fn": collate_fn_t2v, + "forward_pass": forward_pass, + "validation": validation, +} diff --git a/finetrainers/models/hunyuan_video/lora.py b/finetrainers/models/hunyuan_video/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8ccd1f61f3131f9fb0c2ba1235070ce7439ba0 --- /dev/null +++ b/finetrainers/models/hunyuan_video/lora.py @@ -0,0 +1,368 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from accelerate.logging import get_logger +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from PIL import Image +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name + + +def load_condition_models( + model_id: str = "hunyuanvideo-community/HunyuanVideo", + text_encoder_dtype: torch.dtype = torch.float16, + text_encoder_2_dtype: torch.dtype = torch.float16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: + tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) + text_encoder = LlamaModel.from_pretrained( + model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir + ) + text_encoder_2 = CLIPTextModel.from_pretrained( + model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir + ) + return { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "tokenizer_2": tokenizer_2, + "text_encoder_2": text_encoder_2, + } + + +def load_latent_models( + model_id: str = "hunyuanvideo-community/HunyuanVideo", + vae_dtype: torch.dtype = torch.float16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: + vae = AutoencoderKLHunyuanVideo.from_pretrained( + model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir + ) + return {"vae": vae} + + +def load_diffusion_models( + model_id: str = "hunyuanvideo-community/HunyuanVideo", + transformer_dtype: torch.dtype = torch.bfloat16, + shift: float = 1.0, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir + ) + scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) + return {"transformer": transformer, "scheduler": scheduler} + + +def initialize_pipeline( + model_id: str = "hunyuanvideo-community/HunyuanVideo", + text_encoder_dtype: torch.dtype = torch.float16, + text_encoder_2_dtype: torch.dtype = torch.float16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.float16, + tokenizer: Optional[LlamaTokenizer] = None, + text_encoder: Optional[LlamaModel] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + text_encoder_2: Optional[CLIPTextModel] = None, + transformer: Optional[HunyuanVideoTransformer3DModel] = None, + vae: Optional[AutoencoderKLHunyuanVideo] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + device: Optional[torch.device] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + is_training: bool = False, + **kwargs, +) -> HunyuanVideoPipeline: + component_name_pairs = [ + ("tokenizer", tokenizer), + ("text_encoder", text_encoder), + ("tokenizer_2", tokenizer_2), + ("text_encoder_2", text_encoder_2), + ("transformer", transformer), + ("vae", vae), + ("scheduler", scheduler), + ] + components = {} + for name, component in component_name_pairs: + if component is not None: + components[name] = component + + pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) + pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) + pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype) + pipe.vae = pipe.vae.to(dtype=vae_dtype) + + # The transformer should already be in the correct dtype when training, so we don't need to cast it here. + # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during + # DDP optimizer step. + if not is_training: + pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) + + if enable_slicing: + pipe.vae.enable_slicing() + if enable_tiling: + pipe.vae.enable_tiling() + + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload(device=device) + else: + pipe.to(device=device) + + return pipe + + +def prepare_conditions( + tokenizer: LlamaTokenizer, + text_encoder: LlamaModel, + tokenizer_2: CLIPTokenizer, + text_encoder_2: CLIPTextModel, + prompt: Union[str, List[str]], + guidance: float = 1.0, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + # TODO(aryan): make configurable + prompt_template: Dict[str, Any] = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, + }, + **kwargs, +) -> torch.Tensor: + device = device or text_encoder.device + dtype = dtype or text_encoder.dtype + + if isinstance(prompt, str): + prompt = [prompt] + + conditions = {} + conditions.update( + _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length) + ) + conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype)) + + guidance = torch.tensor([guidance], device=device, dtype=dtype) * 1000.0 + conditions["guidance"] = guidance + + return conditions + + +def prepare_latents( + vae: AutoencoderKLHunyuanVideo, + image_or_video: torch.Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + generator: Optional[torch.Generator] = None, + precompute: bool = False, + **kwargs, +) -> torch.Tensor: + device = device or vae.device + dtype = dtype or vae.dtype + + if image_or_video.ndim == 4: + image_or_video = image_or_video.unsqueeze(2) + assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" + + image_or_video = image_or_video.to(device=device, dtype=vae.dtype) + image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] + if not precompute: + latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) + latents = latents * vae.config.scaling_factor + latents = latents.to(dtype=dtype) + return {"latents": latents} + else: + if vae.use_slicing and image_or_video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] + h = torch.cat(encoded_slices) + else: + h = vae._encode(image_or_video) + return {"latents": h} + + +def post_latent_preparation( + vae_config: Dict[str, Any], + latents: torch.Tensor, + **kwargs, +) -> torch.Tensor: + latents = latents * vae_config.scaling_factor + return {"latents": latents} + + +def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + return { + "prompts": [x["prompt"] for x in batch[0]], + "videos": torch.stack([x["video"] for x in batch[0]]), + } + + +def forward_pass( + transformer: HunyuanVideoTransformer3DModel, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + prompt_attention_mask: torch.Tensor, + guidance: torch.Tensor, + latents: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.LongTensor, + **kwargs, +) -> torch.Tensor: + denoised_latents = transformer( + hidden_states=noisy_latents, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + return_dict=False, + )[0] + + return {"latents": denoised_latents} + + +def validation( + pipeline: HunyuanVideoPipeline, + prompt: str, + image: Optional[Image.Image] = None, + video: Optional[List[Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_videos_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + **kwargs, +): + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": 30, + "num_videos_per_prompt": num_videos_per_prompt, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} + output = pipeline(**generation_kwargs).frames[0] + return [("video", output)] + + +def _get_llama_prompt_embeds( + tokenizer: LlamaTokenizer, + text_encoder: LlamaModel, + prompt: List[str], + prompt_template: Dict[str, Any], + device: torch.device, + dtype: torch.dtype, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = len(prompt) + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask} + + +def _get_clip_prompt_embeds( + tokenizer_2: CLIPTokenizer, + text_encoder_2: CLIPTextModel, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + max_sequence_length: int = 77, +) -> torch.Tensor: + text_inputs = tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output + prompt_embeds = prompt_embeds.to(dtype=dtype) + + return {"pooled_prompt_embeds": prompt_embeds} + + +# TODO(aryan): refactor into model specs for better re-use +HUNYUAN_VIDEO_T2V_LORA_CONFIG = { + "pipeline_cls": HunyuanVideoPipeline, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, + "initialize_pipeline": initialize_pipeline, + "prepare_conditions": prepare_conditions, + "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, + "collate_fn": collate_fn_t2v, + "forward_pass": forward_pass, + "validation": validation, +} diff --git a/finetrainers/models/ltx_video/__init__.py b/finetrainers/models/ltx_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69391cdf9157831343a5fb73b237de618e8288bd --- /dev/null +++ b/finetrainers/models/ltx_video/__init__.py @@ -0,0 +1,2 @@ +from .full_finetune import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG +from .lora import LTX_VIDEO_T2V_LORA_CONFIG diff --git a/finetrainers/models/ltx_video/full_finetune.py b/finetrainers/models/ltx_video/full_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..ca799ea6f1b4b075efa9f2c27bb69564832bcb7d --- /dev/null +++ b/finetrainers/models/ltx_video/full_finetune.py @@ -0,0 +1,30 @@ +from diffusers import LTXPipeline + +from .lora import ( + collate_fn_t2v, + forward_pass, + initialize_pipeline, + load_condition_models, + load_diffusion_models, + load_latent_models, + post_latent_preparation, + prepare_conditions, + prepare_latents, + validation, +) + + +# TODO(aryan): refactor into model specs for better re-use +LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG = { + "pipeline_cls": LTXPipeline, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, + "initialize_pipeline": initialize_pipeline, + "prepare_conditions": prepare_conditions, + "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, + "collate_fn": collate_fn_t2v, + "forward_pass": forward_pass, + "validation": validation, +} diff --git a/finetrainers/models/ltx_video/lora.py b/finetrainers/models/ltx_video/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd6ffa3e3b91564ff88222a2314f66c6e465116 --- /dev/null +++ b/finetrainers/models/ltx_video/lora.py @@ -0,0 +1,331 @@ +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from accelerate.logging import get_logger +from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name + + +def load_condition_models( + model_id: str = "Lightricks/LTX-Video", + text_encoder_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: + tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) + text_encoder = T5EncoderModel.from_pretrained( + model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir + ) + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + +def load_latent_models( + model_id: str = "Lightricks/LTX-Video", + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: + vae = AutoencoderKLLTXVideo.from_pretrained( + model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir + ) + return {"vae": vae} + + +def load_diffusion_models( + model_id: str = "Lightricks/LTX-Video", + transformer_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: + transformer = LTXVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir + ) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} + + +def initialize_pipeline( + model_id: str = "Lightricks/LTX-Video", + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + tokenizer: Optional[T5Tokenizer] = None, + text_encoder: Optional[T5EncoderModel] = None, + transformer: Optional[LTXVideoTransformer3DModel] = None, + vae: Optional[AutoencoderKLLTXVideo] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + device: Optional[torch.device] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + is_training: bool = False, + **kwargs, +) -> LTXPipeline: + component_name_pairs = [ + ("tokenizer", tokenizer), + ("text_encoder", text_encoder), + ("transformer", transformer), + ("vae", vae), + ("scheduler", scheduler), + ] + components = {} + for name, component in component_name_pairs: + if component is not None: + components[name] = component + + pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) + pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) + pipe.vae = pipe.vae.to(dtype=vae_dtype) + # The transformer should already be in the correct dtype when training, so we don't need to cast it here. + # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during + # DDP optimizer step. + if not is_training: + pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) + + if enable_slicing: + pipe.vae.enable_slicing() + if enable_tiling: + pipe.vae.enable_tiling() + + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload(device=device) + else: + pipe.to(device=device) + + return pipe + + +def prepare_conditions( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 128, + **kwargs, +) -> torch.Tensor: + device = device or text_encoder.device + dtype = dtype or text_encoder.dtype + + if isinstance(prompt, str): + prompt = [prompt] + + return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length) + + +def prepare_latents( + vae: AutoencoderKLLTXVideo, + image_or_video: torch.Tensor, + patch_size: int = 1, + patch_size_t: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + generator: Optional[torch.Generator] = None, + precompute: bool = False, +) -> torch.Tensor: + device = device or vae.device + + if image_or_video.ndim == 4: + image_or_video = image_or_video.unsqueeze(2) + assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" + + image_or_video = image_or_video.to(device=device, dtype=vae.dtype) + image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] + if not precompute: + latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + _, _, num_frames, height, width = latents.shape + latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std) + latents = _pack_latents(latents, patch_size, patch_size_t) + return {"latents": latents, "num_frames": num_frames, "height": height, "width": width} + else: + if vae.use_slicing and image_or_video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] + h = torch.cat(encoded_slices) + else: + h = vae._encode(image_or_video) + _, _, num_frames, height, width = h.shape + + # TODO(aryan): This is very stupid that we might possibly be storing the latents_mean and latents_std in every file + # if precomputation is enabled. We should probably have a single file where re-usable properties like this are stored + # so as to reduce the disk memory requirements of the precomputed files. + return { + "latents": h, + "num_frames": num_frames, + "height": height, + "width": width, + "latents_mean": vae.latents_mean, + "latents_std": vae.latents_std, + } + + +def post_latent_preparation( + latents: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_frames: int, + height: int, + width: int, + patch_size: int = 1, + patch_size_t: int = 1, + **kwargs, +) -> torch.Tensor: + latents = _normalize_latents(latents, latents_mean, latents_std) + latents = _pack_latents(latents, patch_size, patch_size_t) + return {"latents": latents, "num_frames": num_frames, "height": height, "width": width} + + +def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + return { + "prompts": [x["prompt"] for x in batch[0]], + "videos": torch.stack([x["video"] for x in batch[0]]), + } + + +def forward_pass( + transformer: LTXVideoTransformer3DModel, + prompt_embeds: torch.Tensor, + prompt_attention_mask: torch.Tensor, + latents: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.LongTensor, + num_frames: int, + height: int, + width: int, + **kwargs, +) -> torch.Tensor: + # TODO(aryan): make configurable + frame_rate = 25 + latent_frame_rate = frame_rate / 8 + spatial_compression_ratio = 32 + rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio] + + denoised_latents = transformer( + hidden_states=noisy_latents, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + encoder_attention_mask=prompt_attention_mask, + num_frames=num_frames, + height=height, + width=width, + rope_interpolation_scale=rope_interpolation_scale, + return_dict=False, + )[0] + + return {"latents": denoised_latents} + + +def validation( + pipeline: LTXPipeline, + prompt: str, + image: Optional[Image.Image] = None, + video: Optional[List[Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + frame_rate: int = 24, + num_videos_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + **kwargs, +): + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "frame_rate": frame_rate, + "num_videos_per_prompt": num_videos_per_prompt, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} + video = pipeline(**generation_kwargs).frames[0] + return [("video", video)] + + +def _encode_prompt_t5( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: List[str], + device: torch.device, + dtype: torch.dtype, + max_sequence_length, +) -> torch.Tensor: + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask} + + +def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 +) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + +def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + +LTX_VIDEO_T2V_LORA_CONFIG = { + "pipeline_cls": LTXPipeline, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, + "initialize_pipeline": initialize_pipeline, + "prepare_conditions": prepare_conditions, + "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, + "collate_fn": collate_fn_t2v, + "forward_pass": forward_pass, + "validation": validation, +} diff --git a/finetrainers/patches.py b/finetrainers/patches.py new file mode 100644 index 0000000000000000000000000000000000000000..1faacbde58e2948a3aa83d8c848de2a9cc681583 --- /dev/null +++ b/finetrainers/patches.py @@ -0,0 +1,50 @@ +import functools + +import torch +from accelerate.logging import get_logger +from peft.tuners.tuners_utils import BaseTunerLayer + +from .constants import FINETRAINERS_LOG_LEVEL + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +def perform_peft_patches() -> None: + _perform_patch_move_adapter_to_device_of_base_layer() + + +def _perform_patch_move_adapter_to_device_of_base_layer() -> None: + # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights + # are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid conversion of + # LoRA weights from higher precision dtype. + BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer( + BaseTunerLayer._move_adapter_to_device_of_base_layer + ) + + +def _patched_move_adapter_to_device_of_base_layer(func) -> None: + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with DisableTensorToDtype(): + return func(self, *args, **kwargs) + + return wrapper + + +class DisableTensorToDtype: + def __enter__(self): + self.original_to = torch.Tensor.to + + def modified_to(tensor, *args, **kwargs): + # remove dtype from args if present + args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] + if "dtype" in kwargs: + kwargs.pop("dtype") + return self.original_to(tensor, *args, **kwargs) + + torch.Tensor.to = modified_to + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.Tensor.to = self.original_to diff --git a/finetrainers/state.py b/finetrainers/state.py new file mode 100644 index 0000000000000000000000000000000000000000..15a92e23da840af7e6920d20ea6cd4252feb47ed --- /dev/null +++ b/finetrainers/state.py @@ -0,0 +1,24 @@ +import torch +from accelerate import Accelerator + + +class State: + # Training state + seed: int = None + model_name: str = None + accelerator: Accelerator = None + weight_dtype: torch.dtype = None + train_epochs: int = None + train_steps: int = None + overwrote_max_train_steps: bool = False + num_trainable_parameters: int = 0 + learning_rate: float = None + train_batch_size: int = None + generator: torch.Generator = None + num_update_steps_per_epoch: int = None + + # Hub state + repo_id: str = None + + # Artifacts state + output_dir: str = None diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0fba79de9ea8a52b9f62d8f3826efe7dfcd50887 --- /dev/null +++ b/finetrainers/trainer.py @@ -0,0 +1,1207 @@ +import json +import logging +import math +import os +import random +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List + +import diffusers +import torch +import torch.backends +import transformers +import wandb +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + gather_object, + set_seed, +) +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video, load_image, load_video +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from tqdm import tqdm + +from .args import Args, validate_args +from .constants import ( + FINETRAINERS_LOG_LEVEL, + PRECOMPUTED_CONDITIONS_DIR_NAME, + PRECOMPUTED_DIR_NAME, + PRECOMPUTED_LATENTS_DIR_NAME, +) +from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset +from .hooks import apply_layerwise_upcasting +from .models import get_config_from_model_name +from .patches import perform_peft_patches +from .state import State +from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from +from .utils.data_utils import should_perform_precomputation +from .utils.diffusion_utils import ( + get_scheduler_alphas, + get_scheduler_sigmas, + prepare_loss_weights, + prepare_sigmas, + prepare_target, +) +from .utils.file_utils import string_to_filename +from .utils.hub_utils import save_model_card +from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous +from .utils.model_utils import resolve_vae_cls_from_ckpt_path +from .utils.optimizer_utils import get_optimizer +from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model + + +logger = get_logger("finetrainers") +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +class Trainer: + def __init__(self, args: Args) -> None: + validate_args(args) + + self.args = args + self.args.seed = self.args.seed or datetime.now().year + self.state = State() + + # Tokenizers + self.tokenizer = None + self.tokenizer_2 = None + self.tokenizer_3 = None + + # Text encoders + self.text_encoder = None + self.text_encoder_2 = None + self.text_encoder_3 = None + + # Denoisers + self.transformer = None + self.unet = None + + # Autoencoders + self.vae = None + + # Scheduler + self.scheduler = None + + self.transformer_config = None + self.vae_config = None + + self._init_distributed() + self._init_logging() + self._init_directories_and_repositories() + self._init_config_options() + + # Peform any patches needed for training + if len(self.args.layerwise_upcasting_modules) > 0: + perform_peft_patches() + # TODO(aryan): handle text encoders + # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]): + # perform_text_encoder_patches() + + self.state.model_name = self.args.model_name + self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type) + + def prepare_dataset(self) -> None: + # TODO(aryan): Make a background process for fetching + logger.info("Initializing dataset and dataloader") + + self.dataset = ImageOrVideoDatasetWithResizing( + data_root=self.args.data_root, + caption_column=self.args.caption_column, + video_column=self.args.video_column, + resolution_buckets=self.args.video_resolution_buckets, + dataset_file=self.args.dataset_file, + id_token=self.args.id_token, + remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes, + ) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=1, + sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True), + collate_fn=self.model_config.get("collate_fn"), + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.pin_memory, + ) + + def prepare_models(self) -> None: + logger.info("Initializing models") + + load_components_kwargs = self._get_load_components_kwargs() + condition_components, latent_components, diffusion_components = {}, {}, {} + if not self.args.precompute_conditions: + # To download the model files first on the main process (if not already present) + # and then load the cached files afterward from the other processes. + with self.state.accelerator.main_process_first(): + condition_components = self.model_config["load_condition_models"](**load_components_kwargs) + latent_components = self.model_config["load_latent_models"](**load_components_kwargs) + diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs) + + components = {} + components.update(condition_components) + components.update(latent_components) + components.update(diffusion_components) + self._set_components(components) + + if self.vae is not None: + if self.args.enable_slicing: + self.vae.enable_slicing() + if self.args.enable_tiling: + self.vae.enable_tiling() + + def prepare_precomputations(self) -> None: + if not self.args.precompute_conditions: + return + + logger.info("Initializing precomputations") + + if self.args.batch_size != 1: + raise ValueError("Precomputation is only supported with batch size 1. This will be supported in future.") + + def collate_fn(batch): + latent_conditions = [x["latent_conditions"] for x in batch] + text_conditions = [x["text_conditions"] for x in batch] + batched_latent_conditions = {} + batched_text_conditions = {} + for key in list(latent_conditions[0].keys()): + if torch.is_tensor(latent_conditions[0][key]): + batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0) + else: + # TODO(aryan): implement batch sampler for precomputed latents + batched_latent_conditions[key] = [x[key] for x in latent_conditions][0] + for key in list(text_conditions[0].keys()): + if torch.is_tensor(text_conditions[0][key]): + batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0) + else: + # TODO(aryan): implement batch sampler for precomputed latents + batched_text_conditions[key] = [x[key] for x in text_conditions][0] + return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions} + + cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path) + precomputation_dir = ( + Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}" + ) + should_precompute = should_perform_precomputation(precomputation_dir) + if not should_precompute: + logger.info("Precomputed conditions and latents found. Loading precomputed data.") + self.dataloader = torch.utils.data.DataLoader( + PrecomputedDataset( + data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id + ), + batch_size=self.args.batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.pin_memory, + ) + return + + logger.info("Precomputed conditions and latents not found. Running precomputation.") + + # At this point, no models are loaded, so we need to load and precompute conditions and latents + with self.state.accelerator.main_process_first(): + condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs()) + self._set_components(condition_components) + self._move_components_to_device() + self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3]) + + if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty": + logger.warning( + "Caption dropout is not supported with precomputation yet. This will be supported in the future." + ) + + conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME + latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME + conditions_dir.mkdir(parents=True, exist_ok=True) + latents_dir.mkdir(parents=True, exist_ok=True) + + accelerator = self.state.accelerator + + # Precompute conditions + progress_bar = tqdm( + range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes), + desc="Precomputing conditions", + disable=not accelerator.is_local_main_process, + ) + index = 0 + for i, data in enumerate(self.dataset): + if i % accelerator.num_processes != accelerator.process_index: + continue + + logger.debug( + f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}" + ) + + text_conditions = self.model_config["prepare_conditions"]( + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + tokenizer_3=self.tokenizer_3, + text_encoder=self.text_encoder, + text_encoder_2=self.text_encoder_2, + text_encoder_3=self.text_encoder_3, + prompt=data["prompt"], + device=accelerator.device, + dtype=self.args.transformer_dtype, + ) + filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt" + torch.save(text_conditions, filename.as_posix()) + index += 1 + progress_bar.update(1) + self._delete_components() + + memory_statistics = get_memory_statistics() + logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(accelerator.device) + + # Precompute latents + with self.state.accelerator.main_process_first(): + latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs()) + self._set_components(latent_components) + self._move_components_to_device() + self._disable_grad_for_components([self.vae]) + + if self.vae is not None: + if self.args.enable_slicing: + self.vae.enable_slicing() + if self.args.enable_tiling: + self.vae.enable_tiling() + + progress_bar = tqdm( + range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes), + desc="Precomputing latents", + disable=not accelerator.is_local_main_process, + ) + index = 0 + for i, data in enumerate(self.dataset): + if i % accelerator.num_processes != accelerator.process_index: + continue + + logger.debug( + f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}" + ) + + latent_conditions = self.model_config["prepare_latents"]( + vae=self.vae, + image_or_video=data["video"].unsqueeze(0), + device=accelerator.device, + dtype=self.args.transformer_dtype, + generator=self.state.generator, + precompute=True, + ) + filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt" + torch.save(latent_conditions, filename.as_posix()) + index += 1 + progress_bar.update(1) + self._delete_components() + + accelerator.wait_for_everyone() + logger.info("Precomputation complete") + + memory_statistics = get_memory_statistics() + logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(accelerator.device) + + # Update dataloader to use precomputed conditions and latents + self.dataloader = torch.utils.data.DataLoader( + PrecomputedDataset( + data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id + ), + batch_size=self.args.batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.pin_memory, + ) + + def prepare_trainable_parameters(self) -> None: + logger.info("Initializing trainable parameters") + + with self.state.accelerator.main_process_first(): + diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs()) + self._set_components(diffusion_components) + + components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae] + self._disable_grad_for_components(components) + + if self.args.training_type == "full-finetune": + logger.info("Finetuning transformer with no additional parameters") + self._enable_grad_for_components([self.transformer]) + else: + logger.info("Finetuning transformer with PEFT parameters") + self._disable_grad_for_components([self.transformer]) + + # Layerwise upcasting must be applied before adding the LoRA adapter. + # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on + # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly. + if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules: + apply_layerwise_upcasting( + self.transformer, + storage_dtype=self.args.layerwise_upcasting_storage_dtype, + compute_dtype=self.args.transformer_dtype, + skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, + non_blocking=True, + ) + + self._move_components_to_device() + + if self.args.gradient_checkpointing: + self.transformer.enable_gradient_checkpointing() + + if self.args.training_type == "lora": + transformer_lora_config = LoraConfig( + r=self.args.rank, + lora_alpha=self.args.lora_alpha, + init_lora_weights=True, + target_modules=self.args.target_modules, + ) + self.transformer.add_adapter(transformer_lora_config) + else: + transformer_lora_config = None + + # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32 + # even if layerwise upcasting. Would be nice to have a test as well + + self.register_saving_loading_hooks(transformer_lora_config) + + def register_saving_loading_hooks(self, transformer_lora_config): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if self.state.accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance( + unwrap_model(self.state.accelerator, model), + type(unwrap_model(self.state.accelerator, self.transformer)), + ): + model = unwrap_model(self.state.accelerator, model) + if self.args.training_type == "lora": + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + if self.args.training_type == "lora": + self.model_config["pipeline_cls"].save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + else: + model.save_pretrained(os.path.join(output_dir, "transformer")) + + # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need + # to able to load all diffusion components from a specific checkpoint folder during validation, we need to + # ensure the scheduler config is serialized as well. + self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler")) + + def load_model_hook(models, input_dir): + if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + if isinstance( + unwrap_model(self.state.accelerator, model), + type(unwrap_model(self.state.accelerator, self.transformer)), + ): + transformer_ = unwrap_model(self.state.accelerator, model) + else: + raise ValueError( + f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}" + ) + else: + transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__ + + if self.args.training_type == "lora": + transformer_ = transformer_cls_.from_pretrained( + self.args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir) + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("transformer.") + } + incompatible_keys = set_peft_model_state_dict( + transformer_, transformer_state_dict, adapter_name="default" + ) + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + else: + transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer")) + + self.state.accelerator.register_save_state_pre_hook(save_model_hook) + self.state.accelerator.register_load_state_pre_hook(load_model_hook) + + def prepare_optimizer(self) -> None: + logger.info("Initializing optimizer and lr scheduler") + + self.state.train_epochs = self.args.train_epochs + self.state.train_steps = self.args.train_steps + + # Make sure the trainable params are in float32 + if self.args.training_type == "lora": + cast_training_params([self.transformer], dtype=torch.float32) + + self.state.learning_rate = self.args.lr + if self.args.scale_lr: + self.state.learning_rate = ( + self.state.learning_rate + * self.args.gradient_accumulation_steps + * self.args.batch_size + * self.state.accelerator.num_processes + ) + + transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters())) + transformer_parameters_with_lr = { + "params": transformer_trainable_parameters, + "lr": self.state.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters) + + use_deepspeed_opt = ( + self.state.accelerator.state.deepspeed_plugin is not None + and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config + ) + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=self.args.optimizer, + learning_rate=self.state.learning_rate, + beta1=self.args.beta1, + beta2=self.args.beta2, + beta3=self.args.beta3, + epsilon=self.args.epsilon, + weight_decay=self.args.weight_decay, + use_8bit=self.args.use_8bit_bnb, + use_deepspeed=use_deepspeed_opt, + ) + + num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps) + if self.state.train_steps is None: + self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch + self.state.overwrote_max_train_steps = True + + use_deepspeed_lr_scheduler = ( + self.state.accelerator.state.deepspeed_plugin is not None + and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config + ) + total_training_steps = self.state.train_steps * self.state.accelerator.num_processes + num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes + + if use_deepspeed_lr_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=self.args.lr_scheduler, + optimizer=optimizer, + total_num_steps=total_training_steps, + num_warmup_steps=num_warmup_steps, + ) + else: + lr_scheduler = get_scheduler( + name=self.args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_training_steps, + num_cycles=self.args.lr_num_cycles, + power=self.args.lr_power, + ) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def prepare_for_training(self) -> None: + self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare( + self.transformer, self.optimizer, self.dataloader, self.lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps) + if self.state.overwrote_max_train_steps: + self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch) + self.state.num_update_steps_per_epoch = num_update_steps_per_epoch + + def prepare_trackers(self) -> None: + logger.info("Initializing trackers") + + tracker_name = self.args.tracker_name or "finetrainers-experiment" + self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info()) + + def train(self) -> None: + logger.info("Starting training") + + memory_statistics = get_memory_statistics() + logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") + + if self.vae_config is None: + # If we've precomputed conditions and latents already, and are now re-using it, we will never load + # the VAE so self.vae_config will not be set. So, we need to load it here. + vae_cls = resolve_vae_cls_from_ckpt_path( + self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir + ) + vae_config = vae_cls.load_config( + self.args.pretrained_model_name_or_path, + subfolder="vae", + revision=self.args.revision, + cache_dir=self.args.cache_dir, + ) + self.vae_config = FrozenDict(**vae_config) + + # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need + # to able to load all diffusion components from a specific checkpoint folder during validation, we need to + # ensure the scheduler config is serialized as well. + if self.args.training_type == "full-finetune": + self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler")) + + self.state.train_batch_size = ( + self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps + ) + info = { + "trainable parameters": self.state.num_trainable_parameters, + "total samples": len(self.dataset), + "train epochs": self.state.train_epochs, + "train steps": self.state.train_steps, + "batches per device": self.args.batch_size, + "total batches observed per epoch": len(self.dataloader), + "train batch size": self.state.train_batch_size, + "gradient accumulation steps": self.args.gradient_accumulation_steps, + } + logger.info(f"Training configuration: {json.dumps(info, indent=4)}") + + global_step = 0 + first_epoch = 0 + initial_global_step = 0 + + # Potentially load in the weights and states from a previous save + ( + resume_from_checkpoint_path, + initial_global_step, + global_step, + first_epoch, + ) = get_latest_ckpt_path_to_resume_from( + resume_from_checkpoint=self.args.resume_from_checkpoint, + num_update_steps_per_epoch=self.state.num_update_steps_per_epoch, + output_dir=self.args.output_dir, + ) + if resume_from_checkpoint_path: + self.state.accelerator.load_state(resume_from_checkpoint_path) + + progress_bar = tqdm( + range(0, self.state.train_steps), + initial=initial_global_step, + desc="Training steps", + disable=not self.state.accelerator.is_local_main_process, + ) + + accelerator = self.state.accelerator + generator = torch.Generator(device=accelerator.device) + if self.args.seed is not None: + generator = generator.manual_seed(self.args.seed) + self.state.generator = generator + + scheduler_sigmas = get_scheduler_sigmas(self.scheduler) + scheduler_sigmas = ( + scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32) + if scheduler_sigmas is not None + else None + ) + scheduler_alphas = get_scheduler_alphas(self.scheduler) + scheduler_alphas = ( + scheduler_alphas.to(device=accelerator.device, dtype=torch.float32) + if scheduler_alphas is not None + else None + ) + + for epoch in range(first_epoch, self.state.train_epochs): + logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})") + + self.transformer.train() + models_to_accumulate = [self.transformer] + epoch_loss = 0.0 + num_loss_updates = 0 + + for step, batch in enumerate(self.dataloader): + logger.debug(f"Starting step {step + 1}") + logs = {} + + with accelerator.accumulate(models_to_accumulate): + if not self.args.precompute_conditions: + videos = batch["videos"] + prompts = batch["prompts"] + batch_size = len(prompts) + + if self.args.caption_dropout_technique == "empty": + if random.random() < self.args.caption_dropout_p: + prompts = [""] * batch_size + + latent_conditions = self.model_config["prepare_latents"]( + vae=self.vae, + image_or_video=videos, + patch_size=self.transformer_config.patch_size, + patch_size_t=self.transformer_config.patch_size_t, + device=accelerator.device, + dtype=self.args.transformer_dtype, + generator=self.state.generator, + ) + text_conditions = self.model_config["prepare_conditions"]( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + tokenizer_2=self.tokenizer_2, + text_encoder_2=self.text_encoder_2, + prompt=prompts, + device=accelerator.device, + dtype=self.args.transformer_dtype, + ) + else: + latent_conditions = batch["latent_conditions"] + text_conditions = batch["text_conditions"] + latent_conditions["latents"] = DiagonalGaussianDistribution( + latent_conditions["latents"] + ).sample(self.state.generator) + + # This method should only be called for precomputed latents. + # TODO(aryan): rename this in separate PR + latent_conditions = self.model_config["post_latent_preparation"]( + vae_config=self.vae_config, + patch_size=self.transformer_config.patch_size, + patch_size_t=self.transformer_config.patch_size_t, + **latent_conditions, + ) + align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype) + align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype) + batch_size = latent_conditions["latents"].shape[0] + + latent_conditions = make_contiguous(latent_conditions) + text_conditions = make_contiguous(text_conditions) + + if self.args.caption_dropout_technique == "zero": + if random.random() < self.args.caption_dropout_p: + text_conditions["prompt_embeds"].fill_(0) + text_conditions["prompt_attention_mask"].fill_(False) + + # TODO(aryan): refactor later + if "pooled_prompt_embeds" in text_conditions: + text_conditions["pooled_prompt_embeds"].fill_(0) + + sigmas = prepare_sigmas( + scheduler=self.scheduler, + sigmas=scheduler_sigmas, + batch_size=batch_size, + num_train_timesteps=self.scheduler.config.num_train_timesteps, + flow_weighting_scheme=self.args.flow_weighting_scheme, + flow_logit_mean=self.args.flow_logit_mean, + flow_logit_std=self.args.flow_logit_std, + flow_mode_scale=self.args.flow_mode_scale, + device=accelerator.device, + generator=self.state.generator, + ) + timesteps = (sigmas * 1000.0).long() + + noise = torch.randn( + latent_conditions["latents"].shape, + generator=self.state.generator, + device=accelerator.device, + dtype=self.args.transformer_dtype, + ) + sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim) + + # TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of + # scheduler and calculate the noisy latents accordingly. Look into this later. + if "calculate_noisy_latents" in self.model_config.keys(): + noisy_latents = self.model_config["calculate_noisy_latents"]( + scheduler=self.scheduler, + noise=noise, + latents=latent_conditions["latents"], + timesteps=timesteps, + ) + else: + # Default to flow-matching noise addition + noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise + noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype) + + latent_conditions.update({"noisy_latents": noisy_latents}) + + weights = prepare_loss_weights( + scheduler=self.scheduler, + alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None, + sigmas=sigmas, + flow_weighting_scheme=self.args.flow_weighting_scheme, + ) + weights = expand_tensor_dims(weights, noise.ndim) + + pred = self.model_config["forward_pass"]( + transformer=self.transformer, + scheduler=self.scheduler, + timesteps=timesteps, + **latent_conditions, + **text_conditions, + ) + target = prepare_target( + scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"] + ) + + loss = weights.float() * (pred["latents"].float() - target.float()).pow(2) + # Average loss across all but batch dimension + loss = loss.mean(list(range(1, loss.ndim))) + # Average loss across batch dimension + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + if accelerator.distributed_type == DistributedType.DEEPSPEED: + grad_norm = self.transformer.get_global_grad_norm() + # In some cases the grad norm may not return a float + if torch.is_tensor(grad_norm): + grad_norm = grad_norm.item() + else: + grad_norm = accelerator.clip_grad_norm_( + self.transformer.parameters(), self.args.max_grad_norm + ) + if torch.is_tensor(grad_norm): + grad_norm = grad_norm.item() + + logs["grad_norm"] = grad_norm + + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # Checkpointing + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % self.args.checkpointing_steps == 0: + save_path = get_intermediate_ckpt_path( + checkpointing_limit=self.args.checkpointing_limit, + step=global_step, + output_dir=self.args.output_dir, + ) + accelerator.save_state(save_path) + + # Maybe run validation + should_run_validation = ( + self.args.validation_every_n_steps is not None + and global_step % self.args.validation_every_n_steps == 0 + ) + if should_run_validation: + self.validate(global_step) + + loss_item = loss.detach().item() + epoch_loss += loss_item + num_loss_updates += 1 + logs["step_loss"] = loss_item + logs["lr"] = self.lr_scheduler.get_last_lr()[0] + progress_bar.set_postfix(logs) + accelerator.log(logs, step=global_step) + + if global_step >= self.state.train_steps: + break + + if num_loss_updates > 0: + epoch_loss /= num_loss_updates + accelerator.log({"epoch_loss": epoch_loss}, step=global_step) + memory_statistics = get_memory_statistics() + logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}") + + # Maybe run validation + should_run_validation = ( + self.args.validation_every_n_epochs is not None + and (epoch + 1) % self.args.validation_every_n_epochs == 0 + ) + if should_run_validation: + self.validate(global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, self.transformer) + + if self.args.training_type == "lora": + transformer_lora_layers = get_peft_model_state_dict(transformer) + + self.model_config["pipeline_cls"].save_lora_weights( + save_directory=self.args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + else: + transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer")) + accelerator.wait_for_everyone() + self.validate(step=global_step, final_validation=True) + + if accelerator.is_main_process: + if self.args.push_to_hub: + upload_folder( + repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"] + ) + + self._delete_components() + memory_statistics = get_memory_statistics() + logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") + + accelerator.end_training() + + def validate(self, step: int, final_validation: bool = False) -> None: + logger.info("Starting validation") + + accelerator = self.state.accelerator + num_validation_samples = len(self.args.validation_prompts) + + if num_validation_samples == 0: + logger.warning("No validation samples found. Skipping validation.") + if accelerator.is_main_process: + if self.args.push_to_hub: + save_model_card( + args=self.args, + repo_id=self.state.repo_id, + videos=None, + validation_prompts=None, + ) + return + + self.transformer.eval() + + memory_statistics = get_memory_statistics() + logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") + + pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation) + + all_processes_artifacts = [] + prompts_to_filenames = {} + for i in range(num_validation_samples): + # Skip current validation on all processes but one + if i % accelerator.num_processes != accelerator.process_index: + continue + + prompt = self.args.validation_prompts[i] + image = self.args.validation_images[i] + video = self.args.validation_videos[i] + height = self.args.validation_heights[i] + width = self.args.validation_widths[i] + num_frames = self.args.validation_num_frames[i] + frame_rate = self.args.validation_frame_rate + if image is not None: + image = load_image(image) + if video is not None: + video = load_video(video) + + logger.debug( + f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}", + main_process_only=False, + ) + validation_artifacts = self.model_config["validation"]( + pipeline=pipeline, + prompt=prompt, + image=image, + video=video, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + num_videos_per_prompt=self.args.num_validation_videos_per_prompt, + generator=torch.Generator(device=accelerator.device).manual_seed( + self.args.seed if self.args.seed is not None else 0 + ), + # todo support passing `fps` for supported pipelines. + ) + + prompt_filename = string_to_filename(prompt)[:25] + artifacts = { + "image": {"type": "image", "value": image}, + "video": {"type": "video", "value": video}, + } + for i, (artifact_type, artifact_value) in enumerate(validation_artifacts): + if artifact_value: + artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}) + logger.debug( + f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}", + main_process_only=False, + ) + + for index, (key, value) in enumerate(list(artifacts.items())): + artifact_type = value["type"] + artifact_value = value["value"] + if artifact_type not in ["image", "video"] or artifact_value is None: + continue + + extension = "png" if artifact_type == "image" else "mp4" + filename = "validation-" if not final_validation else "final-" + filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}" + if accelerator.is_main_process and extension == "mp4": + prompts_to_filenames[prompt] = filename + filename = os.path.join(self.args.output_dir, filename) + + if artifact_type == "image" and artifact_value: + logger.debug(f"Saving image to {filename}") + artifact_value.save(filename) + artifact_value = wandb.Image(filename) + elif artifact_type == "video" and artifact_value: + logger.debug(f"Saving video to {filename}") + # TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`. + export_to_video(artifact_value, filename, fps=frame_rate) + artifact_value = wandb.Video(filename, caption=prompt) + + all_processes_artifacts.append(artifact_value) + + all_artifacts = gather_object(all_processes_artifacts) + + if accelerator.is_main_process: + tracker_key = "final" if final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "wandb": + artifact_log_dict = {} + + image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] + if len(image_artifacts) > 0: + artifact_log_dict["images"] = image_artifacts + video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] + if len(video_artifacts) > 0: + artifact_log_dict["videos"] = video_artifacts + tracker.log({tracker_key: artifact_log_dict}, step=step) + + if self.args.push_to_hub and final_validation: + video_filenames = list(prompts_to_filenames.values()) + prompts = list(prompts_to_filenames.keys()) + save_model_card( + args=self.args, + repo_id=self.state.repo_id, + videos=video_filenames, + validation_prompts=prompts, + ) + + # Remove all hooks that might have been added during pipeline initialization to the models + pipeline.remove_all_hooks() + del pipeline + + accelerator.wait_for_everyone() + + free_memory() + memory_statistics = get_memory_statistics() + logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(accelerator.device) + + if not final_validation: + self.transformer.train() + + def evaluate(self) -> None: + raise NotImplementedError("Evaluation has not been implemented yet.") + + def _init_distributed(self) -> None: + logging_dir = Path(self.args.output_dir, self.args.logging_dir) + project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs( + backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout) + ) + report_to = None if self.args.report_to.lower() == "none" else self.args.report_to + + accelerator = Accelerator( + project_config=project_config, + gradient_accumulation_steps=self.args.gradient_accumulation_steps, + log_with=report_to, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + self.state.accelerator = accelerator + + if self.args.seed is not None: + self.state.seed = self.args.seed + set_seed(self.args.seed) + + def _init_logging(self) -> None: + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=FINETRAINERS_LOG_LEVEL, + ) + if self.state.accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + logger.info("Initialized FineTrainers") + logger.info(self.state.accelerator.state, main_process_only=False) + + def _init_directories_and_repositories(self) -> None: + if self.state.accelerator.is_main_process: + self.args.output_dir = Path(self.args.output_dir) + self.args.output_dir.mkdir(parents=True, exist_ok=True) + self.state.output_dir = Path(self.args.output_dir) + + if self.args.push_to_hub: + repo_id = self.args.hub_model_id or Path(self.args.output_dir).name + self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id + + def _init_config_options(self) -> None: + # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + def _move_components_to_device(self): + if self.text_encoder is not None: + self.text_encoder = self.text_encoder.to(self.state.accelerator.device) + if self.text_encoder_2 is not None: + self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device) + if self.text_encoder_3 is not None: + self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device) + if self.transformer is not None: + self.transformer = self.transformer.to(self.state.accelerator.device) + if self.unet is not None: + self.unet = self.unet.to(self.state.accelerator.device) + if self.vae is not None: + self.vae = self.vae.to(self.state.accelerator.device) + + def _get_load_components_kwargs(self) -> Dict[str, Any]: + load_component_kwargs = { + "text_encoder_dtype": self.args.text_encoder_dtype, + "text_encoder_2_dtype": self.args.text_encoder_2_dtype, + "text_encoder_3_dtype": self.args.text_encoder_3_dtype, + "transformer_dtype": self.args.transformer_dtype, + "vae_dtype": self.args.vae_dtype, + "shift": self.args.flow_shift, + "revision": self.args.revision, + "cache_dir": self.args.cache_dir, + } + if self.args.pretrained_model_name_or_path is not None: + load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path + return load_component_kwargs + + def _set_components(self, components: Dict[str, Any]) -> None: + # Set models + self.tokenizer = components.get("tokenizer", self.tokenizer) + self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2) + self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3) + self.text_encoder = components.get("text_encoder", self.text_encoder) + self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2) + self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3) + self.transformer = components.get("transformer", self.transformer) + self.unet = components.get("unet", self.unet) + self.vae = components.get("vae", self.vae) + self.scheduler = components.get("scheduler", self.scheduler) + + # Set configs + self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config + self.vae_config = self.vae.config if self.vae is not None else self.vae_config + + def _delete_components(self) -> None: + self.tokenizer = None + self.tokenizer_2 = None + self.tokenizer_3 = None + self.text_encoder = None + self.text_encoder_2 = None + self.text_encoder_3 = None + self.transformer = None + self.unet = None + self.vae = None + self.scheduler = None + free_memory() + torch.cuda.synchronize(self.state.accelerator.device) + + def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline: + accelerator = self.state.accelerator + if not final_validation: + pipeline = self.model_config["initialize_pipeline"]( + model_id=self.args.pretrained_model_name_or_path, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + tokenizer_2=self.tokenizer_2, + text_encoder_2=self.text_encoder_2, + transformer=unwrap_model(accelerator, self.transformer), + vae=self.vae, + device=accelerator.device, + revision=self.args.revision, + cache_dir=self.args.cache_dir, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + is_training=True, + ) + else: + self._delete_components() + + # Load the transformer weights from the final checkpoint if performing full-finetune + transformer = None + if self.args.training_type == "full-finetune": + transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"] + + pipeline = self.model_config["initialize_pipeline"]( + model_id=self.args.pretrained_model_name_or_path, + transformer=transformer, + device=accelerator.device, + revision=self.args.revision, + cache_dir=self.args.cache_dir, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + is_training=False, + ) + + # Load the LoRA weights if performing LoRA finetuning + if self.args.training_type == "lora": + pipeline.load_lora_weights(self.args.output_dir) + + return pipeline + + def _disable_grad_for_components(self, components: List[torch.nn.Module]): + for component in components: + if component is not None: + component.requires_grad_(False) + + def _enable_grad_for_components(self, components: List[torch.nn.Module]): + for component in components: + if component is not None: + component.requires_grad_(True) + + def _get_training_info(self) -> dict: + args = self.args.to_dict() + + training_args = args.get("training_arguments", {}) + training_type = training_args.get("training_type", "") + + # LoRA/non-LoRA stuff. + if training_type == "full-finetune": + filtered_training_args = { + k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"} + } + else: + filtered_training_args = training_args + + # Diffusion/flow stuff. + diffusion_args = args.get("diffusion_arguments", {}) + scheduler_name = self.scheduler.__class__.__name__ + if scheduler_name != "FlowMatchEulerDiscreteScheduler": + filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k} + else: + filtered_diffusion_args = diffusion_args + + # Rest of the stuff. + updated_training_info = args.copy() + updated_training_info["training_arguments"] = filtered_training_args + updated_training_info["diffusion_arguments"] = filtered_diffusion_args + return updated_training_info diff --git a/finetrainers/utils/__init__.py b/finetrainers/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d06bd9def1f5d4c242cafdc03cd2d31414c1f169 --- /dev/null +++ b/finetrainers/utils/__init__.py @@ -0,0 +1,13 @@ +from .diffusion_utils import ( + default_flow_shift, + get_scheduler_alphas, + get_scheduler_sigmas, + prepare_loss_weights, + prepare_sigmas, + prepare_target, + resolution_dependent_timestep_flow_shift, +) +from .file_utils import delete_files, find_files +from .memory_utils import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous +from .optimizer_utils import get_optimizer, gradient_norm, max_gradient +from .torch_utils import unwrap_model diff --git a/finetrainers/utils/checkpointing.py b/finetrainers/utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..01dbea0a029144f53ff8587b887204515a7e3250 --- /dev/null +++ b/finetrainers/utils/checkpointing.py @@ -0,0 +1,64 @@ +import os +from typing import Tuple + +from accelerate.logging import get_logger + +from ..constants import FINETRAINERS_LOG_LEVEL +from ..utils.file_utils import delete_files, find_files + + +logger = get_logger("finetrainers") +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +def get_latest_ckpt_path_to_resume_from( + resume_from_checkpoint: str, num_update_steps_per_epoch: int, output_dir: str +) -> Tuple[str, int, int, int]: + if not resume_from_checkpoint: + initial_global_step = 0 + global_step = 0 + first_epoch = 0 + resume_from_checkpoint_path = None + else: + if resume_from_checkpoint != "latest": + path = os.path.basename(resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.") + resume_from_checkpoint = None + initial_global_step = 0 + global_step = 0 + first_epoch = 0 + resume_from_checkpoint_path = None + else: + logger.info(f"Resuming from checkpoint {path}") + resume_from_checkpoint_path = os.path.join(output_dir, path) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch + + +def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str: + # before saving state, check if this save would set us over the `checkpointing_limit` + if checkpointing_limit is not None: + checkpoints = find_files(output_dir, prefix="checkpoint") + + # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpointing_limit: + num_to_remove = len(checkpoints) - checkpointing_limit + 1 + checkpoints_to_remove = [os.path.join(output_dir, x) for x in checkpoints[0:num_to_remove]] + delete_files(checkpoints_to_remove) + + logger.info(f"Checkpointing at step {step}") + save_path = os.path.join(output_dir, f"checkpoint-{step}") + logger.info(f"Saving state to {save_path}") + return save_path diff --git a/finetrainers/utils/data_utils.py b/finetrainers/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..284dd1a9c9c1d91633dd0ae7e032a0cb507e8b2c --- /dev/null +++ b/finetrainers/utils/data_utils.py @@ -0,0 +1,35 @@ +from pathlib import Path +from typing import Union + +from accelerate.logging import get_logger + +from ..constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME + + +logger = get_logger("finetrainers") + + +def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool: + if isinstance(precomputation_dir, str): + precomputation_dir = Path(precomputation_dir) + conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME + latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME + if conditions_dir.exists() and latents_dir.exists(): + num_files_conditions = len(list(conditions_dir.glob("*.pt"))) + num_files_latents = len(list(latents_dir.glob("*.pt"))) + if num_files_conditions != num_files_latents: + logger.warning( + f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents})." + f"Cleaning up precomputed directories and re-running precomputation." + ) + # clean up precomputed directories + for file in conditions_dir.glob("*.pt"): + file.unlink() + for file in latents_dir.glob("*.pt"): + file.unlink() + return True + if num_files_conditions > 0: + logger.info(f"Found {num_files_conditions} precomputed conditions and latents.") + return False + logger.info("Precomputed data not found. Running precomputation.") + return True diff --git a/finetrainers/utils/diffusion_utils.py b/finetrainers/utils/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dbbccb28ca66e2b9f6f933180a857f45b7d7ae90 --- /dev/null +++ b/finetrainers/utils/diffusion_utils.py @@ -0,0 +1,145 @@ +import math +from typing import Optional, Union + +import torch +from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.training_utils import compute_loss_weighting_for_sd3 + + +# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47 +def resolution_dependent_timestep_flow_shift( + latents: torch.Tensor, + sigmas: torch.Tensor, + base_image_seq_len: int = 256, + max_image_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +) -> torch.Tensor: + image_or_video_sequence_length = 0 + if latents.ndim == 4: + image_or_video_sequence_length = latents.shape[2] * latents.shape[3] + elif latents.ndim == 5: + image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4] + else: + raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor") + + m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len) + b = base_shift - m * base_image_seq_len + mu = m * image_or_video_sequence_length + b + sigmas = default_flow_shift(latents, sigmas, shift=mu) + return sigmas + + +def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor: + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + return sigmas + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, + device: torch.device = torch.device("cpu"), + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + r""" + Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator) + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device=device, generator=generator) + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device=device, generator=generator) + return u + + +def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + return None + elif isinstance(scheduler, CogVideoXDDIMScheduler): + return scheduler.alphas_cumprod.clone() + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + +def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + return scheduler.sigmas.clone() + elif isinstance(scheduler, CogVideoXDDIMScheduler): + return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps) + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + +def prepare_sigmas( + scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], + sigmas: torch.Tensor, + batch_size: int, + num_train_timesteps: int, + flow_weighting_scheme: str = "none", + flow_logit_mean: float = 0.0, + flow_logit_std: float = 1.0, + flow_mode_scale: float = 1.29, + device: torch.device = torch.device("cpu"), + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + weights = compute_density_for_timestep_sampling( + weighting_scheme=flow_weighting_scheme, + batch_size=batch_size, + logit_mean=flow_logit_mean, + logit_std=flow_logit_std, + mode_scale=flow_mode_scale, + device=device, + generator=generator, + ) + indices = (weights * num_train_timesteps).long() + elif isinstance(scheduler, CogVideoXDDIMScheduler): + # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes. + weights = torch.rand(size=(batch_size,), device=device, generator=generator) + indices = (weights * num_train_timesteps).long() + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + return sigmas[indices] + + +def prepare_loss_weights( + scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], + alphas: Optional[torch.Tensor] = None, + sigmas: Optional[torch.Tensor] = None, + flow_weighting_scheme: str = "none", +) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme) + elif isinstance(scheduler, CogVideoXDDIMScheduler): + # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas). + # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results. + return 1 / (1 - alphas) + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + +def prepare_target( + scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], + noise: torch.Tensor, + latents: torch.Tensor, +) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + target = noise - latents + elif isinstance(scheduler, CogVideoXDDIMScheduler): + target = latents + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + return target diff --git a/finetrainers/utils/file_utils.py b/finetrainers/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb731771ba9bc7e07a273ca8947949dfd572b465 --- /dev/null +++ b/finetrainers/utils/file_utils.py @@ -0,0 +1,44 @@ +import logging +import os +import shutil +from pathlib import Path +from typing import List, Union + + +logger = logging.getLogger("finetrainers") +logger.setLevel(os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")) + + +def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]: + if not isinstance(dir, Path): + dir = Path(dir) + if not dir.exists(): + return [] + checkpoints = os.listdir(dir.as_posix()) + checkpoints = [c for c in checkpoints if c.startswith(prefix)] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + return checkpoints + + +def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None: + if not isinstance(dirs, list): + dirs = [dirs] + dirs = [Path(d) if isinstance(d, str) else d for d in dirs] + logger.info(f"Deleting files: {dirs}") + for dir in dirs: + if not dir.exists(): + continue + shutil.rmtree(dir, ignore_errors=True) + + +def string_to_filename(s: str) -> str: + return ( + s.replace(" ", "-") + .replace("/", "-") + .replace(":", "-") + .replace(".", "-") + .replace(",", "-") + .replace(";", "-") + .replace("!", "-") + .replace("?", "-") + ) diff --git a/finetrainers/utils/hub_utils.py b/finetrainers/utils/hub_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef865407ef9a1d41300c2e544d622a62b498989b --- /dev/null +++ b/finetrainers/utils/hub_utils.py @@ -0,0 +1,84 @@ +import os +from typing import List, Union + +import numpy as np +import wandb +from diffusers.utils import export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from PIL import Image + + +def save_model_card( + args, + repo_id: str, + videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]], + validation_prompts: List[str], + fps: int = 30, +) -> None: + widget_dict = [] + output_dir = str(args.output_dir) + if videos is not None and len(videos) > 0: + for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)): + if not isinstance(video, str): + export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"}, + } + ) + + training_type = "Full" if args.training_type == "full-finetune" else "LoRA" + model_description = f""" +# {training_type} Finetune + + + +## Model description + +This is a {training_type.lower()} finetune of model: `{args.pretrained_model_name_or_path}`. + +The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers). + +`id_token` used: {args.id_token} (if it's not `None`, it should be used in the prompts.) + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +TODO +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. +""" + if wandb.run and wandb.run.url: + model_description += f""" +Find out the wandb run URL and training configurations [here]({wandb.run.url}). +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + base_model=args.pretrained_model_name_or_path, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "finetrainers", + "template:sd-lora", + ] + if training_type == "Full": + tags.append("full-finetune") + else: + tags.append("lora") + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(args.output_dir, "README.md")) diff --git a/finetrainers/utils/memory_utils.py b/finetrainers/utils/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7616b190ebe474484e4aaf438b9a80eabf6ab66 --- /dev/null +++ b/finetrainers/utils/memory_utils.py @@ -0,0 +1,58 @@ +import gc +from typing import Any, Dict, Union + +import torch +from accelerate.logging import get_logger + + +logger = get_logger("finetrainers") + + +def get_memory_statistics(precision: int = 3) -> Dict[str, Any]: + memory_allocated = None + memory_reserved = None + max_memory_allocated = None + max_memory_reserved = None + + if torch.cuda.is_available(): + device = torch.cuda.current_device() + memory_allocated = torch.cuda.memory_allocated(device) + memory_reserved = torch.cuda.memory_reserved(device) + max_memory_allocated = torch.cuda.max_memory_allocated(device) + max_memory_reserved = torch.cuda.max_memory_reserved(device) + + elif torch.backends.mps.is_available(): + memory_allocated = torch.mps.current_allocated_memory() + + else: + logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.") + + return { + "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), + "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), + "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), + "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), + } + + +def bytes_to_gigabytes(x: int) -> float: + if x is not None: + return x / 1024**3 + + +def free_memory() -> None: + if torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + # TODO(aryan): handle non-cuda devices + + +def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if isinstance(x, torch.Tensor): + return x.contiguous() + elif isinstance(x, dict): + return {k: make_contiguous(v) for k, v in x.items()} + else: + return x diff --git a/finetrainers/utils/model_utils.py b/finetrainers/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1451ebff3d7c29e60d856cd41f56b358b044ffc9 --- /dev/null +++ b/finetrainers/utils/model_utils.py @@ -0,0 +1,25 @@ +import importlib +import json +import os + +from huggingface_hub import hf_hub_download + + +def resolve_vae_cls_from_ckpt_path(ckpt_path, **kwargs): + ckpt_path = str(ckpt_path) + if os.path.exists(str(ckpt_path)) and os.path.isdir(ckpt_path): + index_path = os.path.join(ckpt_path, "model_index.json") + else: + revision = kwargs.get("revision", None) + cache_dir = kwargs.get("cache_dir", None) + index_path = hf_hub_download( + repo_id=ckpt_path, filename="model_index.json", revision=revision, cache_dir=cache_dir + ) + + with open(index_path, "r") as f: + model_index_dict = json.load(f) + assert "vae" in model_index_dict, "No VAE found in the modelx index dict." + + vae_cls_config = model_index_dict["vae"] + library = importlib.import_module(vae_cls_config[0]) + return getattr(library, vae_cls_config[1]) diff --git a/finetrainers/utils/optimizer_utils.py b/finetrainers/utils/optimizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84c215b51e8e458ff3702fd65be223bce7a7aeb9 --- /dev/null +++ b/finetrainers/utils/optimizer_utils.py @@ -0,0 +1,178 @@ +import inspect + +import torch +from accelerate.logging import get_logger + + +logger = get_logger("finetrainers") + + +def get_optimizer( + params_to_optimize, + optimizer_name: str = "adam", + learning_rate: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.95, + beta3: float = 0.98, + epsilon: float = 1e-8, + weight_decay: float = 1e-4, + prodigy_decouple: bool = False, + prodigy_use_bias_correction: bool = False, + prodigy_safeguard_warmup: bool = False, + use_8bit: bool = False, + use_4bit: bool = False, + use_torchao: bool = False, + use_deepspeed: bool = False, + use_cpu_offload_optimizer: bool = False, + offload_gradients: bool = False, +) -> torch.optim.Optimizer: + optimizer_name = optimizer_name.lower() + + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=learning_rate, + betas=(beta1, beta2), + eps=epsilon, + weight_decay=weight_decay, + ) + + # TODO: consider moving the validation logic to `args.py` when we have torchao. + if use_8bit and use_4bit: + raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") + + if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: + try: + import torchao # noqa + + except ImportError: + raise ImportError( + "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." + ) + + if not use_torchao and use_4bit: + raise ValueError("4-bit Optimizers are only supported with torchao.") + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy", "came"] + if optimizer_name not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." + ) + optimizer_name = "adamw" + + if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: + raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") + + if use_8bit: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if optimizer_name == "adamw": + if use_torchao: + from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit + + optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW + else: + optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "adam": + if use_torchao: + from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit + + optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam + else: + optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + init_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "beta3": beta3, + "eps": epsilon, + "weight_decay": weight_decay, + "decouple": prodigy_decouple, + "use_bias_correction": prodigy_use_bias_correction, + "safeguard_warmup": prodigy_safeguard_warmup, + } + + elif optimizer_name == "came": + try: + import came_pytorch + except ImportError: + raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") + + optimizer_class = came_pytorch.CAME + + init_kwargs = { + "lr": learning_rate, + "eps": (1e-30, 1e-16), + "betas": (beta1, beta2, beta3), + "weight_decay": weight_decay, + } + + if use_cpu_offload_optimizer: + from torchao.prototype.low_bit_optim import CPUOffloadOptimizer + + if "fused" in inspect.signature(optimizer_class.__init__).parameters: + init_kwargs.update({"fused": True}) + + optimizer = CPUOffloadOptimizer( + params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs + ) + else: + optimizer = optimizer_class(params_to_optimize, **init_kwargs) + + return optimizer + + +def gradient_norm(parameters): + norm = 0 + for param in parameters: + if param.grad is None: + continue + local_norm = param.grad.detach().data.norm(2) + norm += local_norm.item() ** 2 + norm = norm**0.5 + return norm + + +def max_gradient(parameters): + max_grad_value = float("-inf") + for param in parameters: + if param.grad is None: + continue + local_max_grad = param.grad.detach().data.abs().max() + max_grad_value = max(max_grad_value, local_max_grad.item()) + return max_grad_value diff --git a/finetrainers/utils/torch_utils.py b/finetrainers/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c6ef5df9ea5a832e671cbaaff008cd6f48078b0 --- /dev/null +++ b/finetrainers/utils/torch_utils.py @@ -0,0 +1,35 @@ +from typing import Dict, Optional, Union + +import torch +from accelerate import Accelerator +from diffusers.utils.torch_utils import is_compiled_module + + +def unwrap_model(accelerator: Accelerator, model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + +def align_device_and_dtype( + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + if isinstance(x, torch.Tensor): + if device is not None: + x = x.to(device) + if dtype is not None: + x = x.to(dtype) + elif isinstance(x, dict): + if device is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + if dtype is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + return x + + +def expand_tensor_dims(tensor, ndim): + while len(tensor.shape) < ndim: + tensor = tensor.unsqueeze(-1) + return tensor diff --git a/finetrainers_utils.py b/finetrainers_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5931e0ec9ade1a650f62e62542a559440ea143c0 --- /dev/null +++ b/finetrainers_utils.py @@ -0,0 +1,126 @@ +import gradio as gr +from pathlib import Path +import logging +import shutil +from typing import Any, Optional, Dict, List, Union, Tuple +from config import STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES +from utils import extract_scene_info, make_archive, is_image_file, is_video_file + +logger = logging.getLogger(__name__) + +def prepare_finetrainers_dataset() -> Tuple[Path, Path]: + """make sure we have a Finetrainers-compatible dataset structure + + Checks that we have: + training/ + ├── prompt.txt # All captions, one per line + ├── videos.txt # All video paths, one per line + └── videos/ # Directory containing all mp4 files + ├── 00000.mp4 + ├── 00001.mp4 + └── ... + Returns: + Tuple of (videos_file_path, prompts_file_path) + """ + + # Verifies the videos subdirectory + TRAINING_VIDEOS_PATH.mkdir(exist_ok=True) + + # Clear existing training lists + for f in TRAINING_PATH.glob("*"): + if f.is_file(): + if f.name in ["videos.txt", "prompts.txt"]: + f.unlink() + + videos_file = TRAINING_PATH / "videos.txt" + prompts_file = TRAINING_PATH / "prompts.txt" # Note: Changed from prompt.txt to prompts.txt to match our config + + media_files = [] + captions = [] + # Process all video files from the videos subdirectory + for idx, file in enumerate(sorted(TRAINING_VIDEOS_PATH.glob("*.mp4"))): + caption_file = file.with_suffix('.txt') + if caption_file.exists(): + # Normalize caption to single line + caption = caption_file.read_text().strip() + caption = ' '.join(caption.split()) + + # Use relative path from training root + relative_path = f"videos/{file.name}" + media_files.append(relative_path) + captions.append(caption) + + # Clean up the caption file since it's now in prompts.txt + # EDIT well you know what, let's keep it, otherwise running the function + # twice might cause some errors + # caption_file.unlink() + + # Write files if we have content + if media_files and captions: + videos_file.write_text('\n'.join(media_files)) + prompts_file.write_text('\n'.join(captions)) + + else: + raise ValueError("No valid video/caption pairs found in training directory") + # Verify file contents + with open(videos_file) as vf: + video_lines = [l.strip() for l in vf.readlines() if l.strip()] + with open(prompts_file) as pf: + prompt_lines = [l.strip() for l in pf.readlines() if l.strip()] + + if len(video_lines) != len(prompt_lines): + raise ValueError(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts") + + return videos_file, prompts_file + +def copy_files_to_training_dir(prompt_prefix: str) -> int: + """Just copy files over, with no destruction""" + + gr.Info("Copying assets to the training dataset..") + + # Find files needing captions + video_files = list(STAGING_PATH.glob("*.mp4")) + image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)] + all_files = video_files + image_files + + nb_copied_pairs = 0 + + for file_path in all_files: + + caption = "" + file_caption_path = file_path.with_suffix('.txt') + if file_caption_path.exists(): + logger.debug(f"Found caption file: {file_caption_path}") + caption = file_caption_path.read_text() + + # Get parent caption if this is a clip + parent_caption = "" + if "___" in file_path.stem: + parent_name, _ = extract_scene_info(file_path.stem) + #print(f"parent_name is {parent_name}") + parent_caption_path = STAGING_PATH / f"{parent_name}.txt" + if parent_caption_path.exists(): + logger.debug(f"Found parent caption file: {parent_caption_path}") + parent_caption = parent_caption_path.read_text().strip() + + target_file_path = TRAINING_VIDEOS_PATH / file_path.name + + target_caption_path = target_file_path.with_suffix('.txt') + + if parent_caption and not caption.endswith(parent_caption): + caption = f"{caption}\n{parent_caption}" + + if prompt_prefix and not caption.startswith(prompt_prefix): + caption = f"{prompt_prefix}{caption}" + + # make sure we only copy over VALID pairs + if caption: + target_caption_path.write_text(caption) + shutil.copy2(file_path, target_file_path) + nb_copied_pairs += 1 + + prepare_finetrainers_dataset() + + gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)") + + return nb_copied_pairs diff --git a/image_preprocessing.py b/image_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..276a067837da9c1e34b550d935c7d1c41b7d08b3 --- /dev/null +++ b/image_preprocessing.py @@ -0,0 +1,116 @@ +import cv2 +import numpy as np +from pathlib import Path +from PIL import Image +import pillow_avif +import logging +from config import NORMALIZE_IMAGES_TO, JPEG_QUALITY + +logger = logging.getLogger(__name__) + +def normalize_image(input_path: Path, output_path: Path) -> bool: + """Convert image to normalized format (PNG or JPEG) and optionally remove black bars + + Args: + input_path: Source image path + output_path: Target path + + Returns: + bool: True if successful, False otherwise + """ + try: + # Open image with PIL + with Image.open(input_path) as img: + # Convert to RGB if needed + if img.mode in ('RGBA', 'LA'): + background = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'RGBA': + background.paste(img, mask=img.split()[3]) + else: + background.paste(img, mask=img.split()[1]) + img = background + elif img.mode != 'RGB': + img = img.convert('RGB') + + # Convert to numpy for black bar detection + img_np = np.array(img) + + # Detect black bars + top, bottom, left, right = detect_black_bars(img_np) + + # Crop if black bars detected + if any([top > 0, bottom < img_np.shape[0] - 1, + left > 0, right < img_np.shape[1] - 1]): + img = img.crop((left, top, right, bottom)) + + # Save as configured format + if NORMALIZE_IMAGES_TO == 'png': + img.save(output_path, 'PNG', optimize=True) + else: # jpg + img.save(output_path, 'JPEG', quality=JPEG_QUALITY, optimize=True) + return True + + except Exception as e: + logger.error(f"Error converting image {input_path}: {str(e)}") + return False + +def detect_black_bars(img: np.ndarray) -> tuple[int, int, int, int]: + """Detect black bars in image + + Args: + img: numpy array of image (HxWxC) + + Returns: + Tuple of (top, bottom, left, right) crop coordinates + """ + # Convert to grayscale if needed + if len(img.shape) == 3: + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + else: + gray = img + + # Threshold to detect black regions + threshold = 20 + black_mask = gray < threshold + + # Find black bars by analyzing row/column means + row_means = np.mean(black_mask, axis=1) + col_means = np.mean(black_mask, axis=0) + + # Detect edges where black bars end (95% threshold) + black_threshold = 0.95 + + # Find top and bottom crops + top = 0 + bottom = img.shape[0] + + for i, mean in enumerate(row_means): + if mean > black_threshold: + top = i + 1 + else: + break + + for i, mean in enumerate(reversed(row_means)): + if mean > black_threshold: + bottom = img.shape[0] - i - 1 + else: + break + + # Find left and right crops + left = 0 + right = img.shape[1] + + for i, mean in enumerate(col_means): + if mean > black_threshold: + left = i + 1 + else: + break + + for i, mean in enumerate(reversed(col_means)): + if mean > black_threshold: + right = img.shape[1] - i - 1 + else: + break + + return top, bottom, left, right + diff --git a/import_service.py b/import_service.py new file mode 100644 index 0000000000000000000000000000000000000000..044c53a0128ab56c06cdee44df9c3710f7b834e9 --- /dev/null +++ b/import_service.py @@ -0,0 +1,245 @@ +import os +import shutil +import zipfile +import tempfile +import gradio as gr +from pathlib import Path +from typing import List, Dict, Optional, Tuple +from pytubefix import YouTube +import logging +from utils import is_image_file, is_video_file, add_prefix_to_caption +from image_preprocessing import normalize_image + +from config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, TRAINING_PATH, DEFAULT_PROMPT_PREFIX + +logger = logging.getLogger(__name__) + +class ImportService: + def process_uploaded_files(self, file_paths: List[str]) -> str: + """Process uploaded file (ZIP, MP4, or image) + + Args: + file_paths: File paths to the ploaded files from Gradio + + Returns: + Status message string + """ + for file_path in file_paths: + file_path = Path(file_path) + try: + original_name = file_path.name + print("original_name = ", original_name) + + # Determine file type from name + file_ext = file_path.suffix.lower() + + if file_ext == '.zip': + return self.process_zip_file(file_path) + elif file_ext == '.mp4' or file_ext == '.webm': + return self.process_mp4_file(file_path, original_name) + elif is_image_file(file_path): + return self.process_image_file(file_path, original_name) + else: + raise gr.Error(f"Unsupported file type: {file_ext}") + + except Exception as e: + raise gr.Error(f"Error processing file: {str(e)}") + + def process_image_file(self, file_path: Path, original_name: str) -> str: + """Process a single image file + + Args: + file_path: Path to the image + original_name: Original filename + + Returns: + Status message string + """ + try: + # Create a unique filename with configured extension + stem = Path(original_name).stem + target_path = STAGING_PATH / f"{stem}.{NORMALIZE_IMAGES_TO}" + + # If file already exists, add number suffix + counter = 1 + while target_path.exists(): + target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}" + counter += 1 + + # Convert to normalized format and remove black bars + success = normalize_image(file_path, target_path) + + if not success: + raise gr.Error(f"Failed to process image: {original_name}") + + # Handle caption + src_caption_path = file_path.with_suffix('.txt') + if src_caption_path.exists(): + caption = src_caption_path.read_text() + caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX) + target_path.with_suffix('.txt').write_text(caption) + + logger.info(f"Successfully stored image: {target_path.name}") + gr.Info(f"Successfully stored image: {target_path.name}") + return f"Successfully stored image: {target_path.name}" + + except Exception as e: + raise gr.Error(f"Error processing image file: {str(e)}") + + def process_zip_file(self, file_path: Path) -> str: + """Process uploaded ZIP file containing media files + + Args: + file_path: Path to the uploaded ZIP file + + Returns: + Status message string + """ + try: + video_count = 0 + image_count = 0 + + # Create temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + # Extract ZIP + extract_dir = Path(temp_dir) / "extracted" + extract_dir.mkdir() + with zipfile.ZipFile(file_path, 'r') as zip_ref: + zip_ref.extractall(extract_dir) + + # Process each file + for root, _, files in os.walk(extract_dir): + for file in files: + if file.startswith('._'): # Skip Mac metadata + continue + + file_path = Path(root) / file + + try: + if is_video_file(file_path): + # Copy video to videos_to_split + target_path = VIDEOS_TO_SPLIT_PATH / file_path.name + counter = 1 + while target_path.exists(): + target_path = VIDEOS_TO_SPLIT_PATH / f"{file_path.stem}___{counter}{file_path.suffix}" + counter += 1 + shutil.copy2(file_path, target_path) + video_count += 1 + + elif is_image_file(file_path): + # Convert image and save to staging + target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}" + counter = 1 + while target_path.exists(): + target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}" + counter += 1 + if normalize_image(file_path, target_path): + image_count += 1 + + # Copy associated caption file if it exists + txt_path = file_path.with_suffix('.txt') + if txt_path.exists(): + if is_video_file(file_path): + shutil.copy2(txt_path, target_path.with_suffix('.txt')) + elif is_image_file(file_path): + shutil.copy2(txt_path, target_path.with_suffix('.txt')) + + except Exception as e: + logger.error(f"Error processing {file_path.name}: {str(e)}") + continue + + # Generate status message + parts = [] + if video_count > 0: + parts.append(f"{video_count} videos") + if image_count > 0: + parts.append(f"{image_count} images") + + if not parts: + return "No supported media files found in ZIP" + + status = f"Successfully stored {' and '.join(parts)}" + gr.Info(status) + return status + + except Exception as e: + raise gr.Error(f"Error processing ZIP: {str(e)}") + + def process_mp4_file(self, file_path: Path, original_name: str) -> str: + """Process a single video file + + Args: + file_path: Path to the file + original_name: Original filename + + Returns: + Status message string + """ + try: + # Create a unique filename + target_path = VIDEOS_TO_SPLIT_PATH / original_name + + # If file already exists, add number suffix + counter = 1 + while target_path.exists(): + stem = Path(original_name).stem + target_path = VIDEOS_TO_SPLIT_PATH / f"{stem}___{counter}.mp4" + counter += 1 + + # Copy the file to the target location + shutil.copy2(file_path, target_path) + + gr.Info(f"Successfully stored video: {target_path.name}") + return f"Successfully stored video: {target_path.name}" + + except Exception as e: + raise gr.Error(f"Error processing video file: {str(e)}") + + def download_youtube_video(self, url: str, progress=None) -> Dict: + """Download a video from YouTube + + Args: + url: YouTube video URL + progress: Optional Gradio progress indicator + + Returns: + Dict with status message and error (if any) + """ + try: + # Extract video ID and create YouTube object + yt = YouTube(url, on_progress_callback=lambda stream, chunk, bytes_remaining: + progress((1 - bytes_remaining / stream.filesize), desc="Downloading...") + if progress else None) + + video_id = yt.video_id + output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4" + + # Download highest quality progressive MP4 + if progress: + print("Getting video streams...") + progress(0, desc="Getting video streams...") + video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first() + + if not video: + print("Could not find a compatible video format") + gr.Error("Could not find a compatible video format") + return "Could not find a compatible video format" + + # Download the video + if progress: + print("Starting YouTube video download...") + progress(0, desc="Starting download...") + + video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4") + + # Update UI + if progress: + print("YouTube video download complete!") + gr.Info("YouTube video download complete!") + progress(1, desc="Download complete!") + return f"Successfully downloaded video: {yt.title}" + + except Exception as e: + print(e) + gr.Error(f"Error downloading video: {str(e)}") + return f"Error downloading video: {str(e)}" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e7f99a3fffb3b735a0845057f179075b3c18a8b4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +numpy>=1.26.4 + +# to quote a-r-r-o-w/finetrainers: +# It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested. +torch==2.5.1 +torchvision==0.20.1 +torchao==0.6.1 + +huggingface_hub +hf_transfer>=0.1.8 +diffusers>=0.30.3 +transformers>=4.45.2 + +accelerate +bitsandbytes +peft>=0.12.0 +eva-decord==0.6.1 +wandb +pandas +sentencepiece>=0.2.0 +imageio-ffmpeg>=0.5.1 + +flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + +# for youtube video download +pytube +pytubefix + +# for scene splitting +scenedetect[opencv] + +# for llava video / captionning +pillow +pillow-avif-plugin +polars +einops +open_clip_torch +av==14.1.0 +git+https://github.com/LLaVA-VL/LLaVA-NeXT.git + +# for our frontend +gradio==5.15.0 +gradio_toggle \ No newline at end of file diff --git a/requirements_without_flash_attention.txt b/requirements_without_flash_attention.txt new file mode 100644 index 0000000000000000000000000000000000000000..907b70254843082b3ff0f21fe4ff3bf1de6d6143 --- /dev/null +++ b/requirements_without_flash_attention.txt @@ -0,0 +1,42 @@ +numpy>=1.26.4 + +# to quote a-r-r-o-w/finetrainers: +# It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested. +torch==2.5.1 +torchvision==0.20.1 +torchao==0.6.1 + + +huggingface_hub +hf_transfer>=0.1.8 +diffusers>=0.30.3 +transformers>=4.45.2 + +accelerate +bitsandbytes +peft>=0.12.0 +eva-decord==0.6.1 +wandb +pandas +sentencepiece>=0.2.0 +imageio-ffmpeg>=0.5.1 + +# for youtube video download +pytube +pytubefix + +# for scene splitting +scenedetect[opencv] + +# for llava video / captionning +pillow +pillow-avif-plugin +polars +einops +open_clip_torch +av==14.1.0 +git+https://github.com/LLaVA-VL/LLaVA-NeXT.git + +# for our frontend +gradio==5.15.0 +gradio_toggle \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..90d6e5122303975200f039caf15459d66fdabd98 --- /dev/null +++ b/run.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +source .venv/bin/activate + +USE_MOCK_CAPTIONING_MODEL=True python app.py \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100755 index 0000000000000000000000000000000000000000..ba7b92bc9134b752e6b11ab9f067b9fa09ca84d0 --- /dev/null +++ b/setup.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +python -m venv .venv + +source .venv/bin/activate + +python -m pip install -r requirements.txt \ No newline at end of file diff --git a/setup_no_captions.sh b/setup_no_captions.sh new file mode 100755 index 0000000000000000000000000000000000000000..ff7b55e05b3b93ecd9af1321a93f1f0658e2d5e8 --- /dev/null +++ b/setup_no_captions.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +python -m venv .venv + +source .venv/bin/activate + +python -m pip install -r requirements_without_flash_attention.txt + +# if you require flash attention, please install it manually for your operating system + +# you can try this: +# python -m pip install wheel setuptools flash-attn --no-build-isolation --no-cache-dir \ No newline at end of file diff --git a/splitting_service.py b/splitting_service.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bda3bc4479bc0ed03186e5aafabc4b5edde878 --- /dev/null +++ b/splitting_service.py @@ -0,0 +1,242 @@ +import os +import hashlib +import shutil +from pathlib import Path +import asyncio +import tempfile +import logging +from functools import partial +from typing import Dict, List, Optional, Tuple +import gradio as gr + +from scenedetect import detect, ContentDetector, SceneManager, open_video +from scenedetect.video_splitter import split_video_ffmpeg + +from config import TRAINING_PATH, STORAGE_PATH, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX + +from image_preprocessing import detect_black_bars +from video_preprocessing import remove_black_bars +from utils import extract_scene_info, is_video_file, is_image_file, add_prefix_to_caption + +logger = logging.getLogger(__name__) + +class SplittingService: + def __init__(self): + # Track processing status + self.processing = False + self._current_file: Optional[str] = None + self._scene_counts: Dict[str, int] = {} + self._processing_status: Dict[str, str] = {} + + def compute_file_hash(self, file_path: Path) -> str: + """Compute SHA-256 hash of file""" + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + # Read file in chunks to handle large files + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + + def rename_with_hash(self, video_path: Path) -> Tuple[Path, str]: + """Rename video and caption files using hash + + Args: + video_path: Path to video file + + Returns: + Tuple of (new video path, hash) + """ + # Compute hash + file_hash = self.compute_file_hash(video_path) + + # Rename video file + new_video_path = video_path.parent / f"{file_hash}{video_path.suffix}" + video_path.rename(new_video_path) + + # Rename caption file if exists + caption_path = video_path.with_suffix('.txt') + if caption_path.exists(): + new_caption_path = caption_path.parent / f"{file_hash}.txt" + caption_path.rename(new_caption_path) + + return new_video_path, file_hash + + async def process_video(self, video_path: Path, enable_splitting: bool) -> int: + """Process a single video file to detect and split scenes""" + try: + self._processing_status[video_path.name] = f'Processing video "{video_path.name}"...' + + parent_caption_path = video_path.with_suffix('.txt') + # Create output path for split videos + base_name, _ = extract_scene_info(video_path.name) + # Create temporary directory for preprocessed video + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) / f"preprocessed_{video_path.name}" + + # Try to remove black bars + was_cropped = await asyncio.get_event_loop().run_in_executor( + None, + remove_black_bars, + video_path, + temp_path + ) + + # Use preprocessed video if cropping was done, otherwise use original + process_path = temp_path if was_cropped else video_path + + # Detect scenes if splitting is enabled + if enable_splitting: + video = open_video(str(process_path)) + scene_manager = SceneManager() + scene_manager.add_detector(ContentDetector()) + scene_manager.detect_scenes(video, show_progress=False) + scenes = scene_manager.get_scene_list() + else: + scenes = [] + + num_scenes = len(scenes) + + + + if not scenes: + print(f'video "{video_path.name}" is already a single-scene clip') + + # captioning is only required if some information is missing + + if parent_caption_path.exists(): + # if it's a single scene with a caption, we can directly promote it to the training/ dir + #output_video_path = TRAINING_VIDEOS_PATH / f"{base_name}___{1:03d}.mp4" + # WELL ACTUALLY, NOT. The training videos dir removes a lot of thing, + # so it has to stay a "last resort" thing + output_video_path = STAGING_PATH / f"{base_name}___{1:03d}.mp4" + + shutil.copy2(process_path, output_video_path) + + shutil.copy2(parent_caption_path, output_video_path.with_suffix('.txt')) + parent_caption_path.unlink() + else: + # otherwise it needs to go through the normal captioning process + output_video_path = STAGING_PATH / f"{base_name}___{1:03d}.mp4" + shutil.copy2(process_path, output_video_path) + + + else: + print(f'video "{video_path.name}" contains {num_scenes} scenes') + + # in this scenario, there are multiple subscenes + # even if we have a parent caption, we must caption each of them individually + # the first step is to preserve the parent caption for later use + if parent_caption_path.exists(): + output_caption_path = STAGING_PATH / f"{base_name}.txt" + shutil.copy2(parent_caption_path, output_caption_path) + parent_caption_path.unlink() + + + output_template = str(STAGING_PATH / f"{base_name}___$SCENE_NUMBER.mp4") + + # Split video into scenes using the preprocessed video if it exists + await asyncio.get_event_loop().run_in_executor( + None, + lambda: split_video_ffmpeg( + str(process_path), + scenes, + output_file_template=output_template, + show_progress=False + ) + ) + + # Update scene count and status + crop_status = " (black bars removed)" if was_cropped else "" + self._scene_counts[video_path.name] = num_scenes + self._processing_status[video_path.name] = f"{num_scenes} scenes{crop_status}" + + # Delete original video + video_path.unlink() + + if num_scenes: + gr.Info(f"Extracted {num_scenes} clips from {video_path.name}{crop_status}") + else: + gr.Info(f"Imported {video_path.name}{crop_status}") + + return num_scenes + + except Exception as e: + self._scene_counts[video_path.name] = 0 + self._processing_status[video_path.name] = f"Error: {str(e)}" + raise gr.Error(f"Error processing video {video_path}: {str(e)}") + + def get_scene_count(self, video_name: str) -> Optional[int]: + """Get number of detected scenes for a video + + Returns None if video hasn't been scanned + """ + return self._scene_counts.get(video_name) + + def get_current_file(self) -> Optional[str]: + """Get name of file currently being processed""" + return self._current_file + + def is_processing(self) -> bool: + """Check if background processing is running""" + return self.processing + + async def start_processing(self, enable_splitting: bool) -> None: + """Start background processing of unprocessed videos""" + if self.processing: + return + + self.processing = True + try: + # Process each video + for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"): + self._current_file = video_file.name + await self.process_video(video_file, enable_splitting) + + finally: + self.processing = False + self._current_file = None + + def get_processing_status(self, video_name: str) -> str: + """Get processing status for a video + + Args: + video_name: Name of the video file + + Returns: + Status string for the video + """ + if video_name in self._processing_status: + return self._processing_status[video_name] + return "not processed" + + def list_unprocessed_videos(self) -> List[List[str]]: + """List all unprocessed and processed videos with their status. + Images will be ignored. + + Returns: + List of lists containing [name, status] for each video + """ + videos = [] + + # Track processed videos by their base names + processed_videos = {} + for clip_path in STAGING_PATH.glob("*.mp4"): + base_name = clip_path.stem.rsplit('___', 1)[0] + '.mp4' + if base_name in processed_videos: + processed_videos[base_name] += 1 + else: + processed_videos[base_name] = 1 + + # List only video files in processing queue + for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"): + if is_video_file(video_file): # Only include video files + status = self.get_processing_status(video_file.name) + videos.append([video_file.name, status]) + + # Add processed videos + for video_name, clip_count in processed_videos.items(): + if not (VIDEOS_TO_SPLIT_PATH / video_name).exists(): + status = f"Processed ({clip_count} clips)" + videos.append([video_name, status]) + + return sorted(videos, key=lambda x: (x[1] != "Processing...", x[0].lower())) diff --git a/tests/scripts/dummy_cogvideox_lora.sh b/tests/scripts/dummy_cogvideox_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..8ac3d74117dbd215ce913a6a21df1eaad5d6a894 --- /dev/null +++ b/tests/scripts/dummy_cogvideox_lora.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +GPU_IDS="0,1" +DATA_ROOT="$ROOT_DIR/video-dataset-disney" +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +OUTPUT_DIR="cogvideox" +ID_TOKEN="BW_STYLE" + +# Model arguments +model_cmd="--model_name cogvideox \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b" + +# Dataset arguments +dataset_cmd="--data_root $DATA_ROOT \ + --video_column $VIDEO_COLUMN \ + --caption_column $CAPTION_COLUMN \ + --id_token $ID_TOKEN \ + --video_resolution_buckets 49x480x720 \ + --caption_dropout_p 0.05" + +# Dataloader arguments +dataloader_cmd="--dataloader_num_workers 0 --precompute_conditions" + +# Training arguments +training_cmd="--training_type lora \ + --seed 42 \ + --batch_size 1 \ + --precompute_conditions \ + --train_steps 10 \ + --rank 128 \ + --lora_alpha 128 \ + --target_modules to_q to_k to_v to_out.0 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --checkpointing_steps 5 \ + --checkpointing_limit 2 \ + --resume_from_checkpoint=latest \ + --enable_slicing \ + --enable_tiling" + +# Optimizer arguments +optimizer_cmd="--optimizer adamw \ + --lr 3e-5 \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 1e-4 \ + --epsilon 1e-8 \ + --max_grad_norm 1.0" + +# Validation arguments +validation_prompts=$(cat < + + + + + +**Update 29 Nov 2024**: We have added an experimental memory-efficient trainer for Mochi-1. Check it out [here](https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/mochi-1/)! + +## Quickstart + +Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install diffusers from source by `pip install git+https://github.com/huggingface/diffusers`. + +Then download a dataset: + +```bash +# install `huggingface_hub` +huggingface-cli download \ + --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \ + --local-dir video-dataset-disney +``` + +Then launch LoRA fine-tuning for text-to-video (modify the different hyperparameters, dataset root, and other configuration options as per your choice): + +```bash +# For LoRA finetuning of the text-to-video CogVideoX models +./train_text_to_video_lora.sh + +# For full finetuning of the text-to-video CogVideoX models +./train_text_to_video_sft.sh + +# For LoRA finetuning of the image-to-video CogVideoX models +./train_image_to_video_lora.sh +``` + +Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: + +```diff +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained( + "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16 +).to("cuda") ++ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora") ++ pipe.set_adapters(["cogvideox-lora"], [1.0]) + +video = pipe("").frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details): + +```python +from diffusers import CogVideoXImageToVideoPipeline + +pipe = CogVideoXImageToVideoPipeline.from_pretrained( + "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16 +).to("cuda") + +# ... + +del pipe.transformer.patch_embed.pos_embedding +pipe.transformer.patch_embed.use_learned_positional_embeddings = False +pipe.transformer.config.use_learned_positional_embeddings = False +``` + +You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py). + +Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible. + +## Prepare Dataset and Training + +Before starting the training, please check whether the dataset has been prepared according to the [dataset specifications](assets/dataset.md). We provide training scripts suitable for text-to-video and image-to-video generation, compatible with the [CogVideoX model family](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). Training can be started using the `train*.sh` scripts, depending on the task you want to train. Let's take LoRA fine-tuning for text-to-video as an example. + +- Configure environment variables as per your choice: + + ```bash + export TORCH_LOGS="+dynamo,recompiles,graph_breaks" + export TORCHDYNAMO_VERBOSE=1 + export WANDB_MODE="offline" + export NCCL_P2P_DISABLE=1 + export TORCH_NCCL_ENABLE_MONITORING=0 + ``` + +- Configure which GPUs to use for training: `GPU_IDS="0,1"` + +- Choose hyperparameters for training. Let's try to do a sweep on learning rate and optimizer type as an example: + + ```bash + LEARNING_RATES=("1e-4" "1e-3") + LR_SCHEDULES=("cosine_with_restarts") + OPTIMIZERS=("adamw" "adam") + MAX_TRAIN_STEPS=("3000") + ``` + +- Select which Accelerate configuration you would like to train with: `ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`. We provide some default configurations in the `accelerate_configs/` directory - single GPU uncompiled/compiled, 2x GPU DDP, DeepSpeed, etc. You can create your own config files with custom settings using `accelerate config --config_file my_config.yaml`. + +- Specify the absolute paths and columns/files for captions and videos. + + ```bash + DATA_ROOT="/path/to/my/datasets/video-dataset-disney" + CAPTION_COLUMN="prompt.txt" + VIDEO_COLUMN="videos.txt" + ``` + +- Launch experiments sweeping different hyperparameters: + ``` + for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done + done + ``` + + To understand what the different parameters mean, you could either take a look at the [args](./training/args.py) file or run the training script with `--help`. + +Note: Training scripts are untested on MPS, so performance and memory requirements can differ widely compared to the CUDA reports below. + +## Memory requirements + + + + + + + + + + + + + + + + + + + + + + + + + +
CogVideoX LoRA Finetuning
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
CogVideoX Full Finetuning
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
+ +Supported and verified memory optimizations for training include: + +- `CPUOffloadOptimizer` from [`torchao`](https://github.com/pytorch/ao). You can read about its capabilities and limitations [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload). In short, it allows you to use the CPU for storing trainable parameters and gradients. This results in the optimizer step happening on the CPU, which requires a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` or applying `torch.compile` on the optimizer step. Additionally, it is recommended not to `torch.compile` your model for training. Gradient clipping and accumulation is not supported yet either. +- Low-bit optimizers from [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers). TODO: to test and make [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) ones work +- DeepSpeed Zero2: Since we rely on `accelerate`, follow [this guide](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) to configure your `accelerate` installation to enable training with DeepSpeed Zero2 optimizations. + +> [!IMPORTANT] +> The memory requirements are reported after running the `training/prepare_dataset.py`, which converts the videos and captions to latents and embeddings. During training, we directly load the latents and embeddings, and do not require the VAE or the T5 text encoder. However, if you perform validation/testing, these must be loaded and increase the amount of required memory. Not performing validation/testing saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs. +> +> If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying `--enable_model_cpu_offload`. + +### LoRA finetuning + +> [!NOTE] +> The memory requirements for image-to-video lora finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly. +> +> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> or provide a URL to a valid and accessible image. + +
+ AdamW + +**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +With `train_batch_size = 1`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 | + +With `train_batch_size = 4`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 | + +
+ +
+ AdamW (8-bit bitsandbytes) + +**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +With `train_batch_size = 1`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 | +| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 | +| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 | + +With `train_batch_size = 4`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 | + +
+ +
+ AdamW + CPUOffloadOptimizer (with gradient offloading) + +**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +With `train_batch_size = 1`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 | +| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 | + +With `train_batch_size = 4`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 | + +
+ +
+ DeepSpeed (AdamW + CPU/Parameter offloading) + +**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100. + +With `train_batch_size = 1`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 | +| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 | + +With `train_batch_size = 4`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 | +| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 | + +
+ +### Full finetuning + +> [!NOTE] +> The memory requirements for image-to-video full finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly. +> +> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> or provide a URL to a valid and accessible image. + +> [!NOTE] +> Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +
+ AdamW + +With `train_batch_size = 1`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +With `train_batch_size = 4`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +
+ +
+ AdamW (8-bit bitsandbytes) + +With `train_batch_size = 1`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 | +| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 | + +With `train_batch_size = 4`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 | +| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 | + +
+ +
+ AdamW + CPUOffloadOptimizer (with gradient offloading) + +With `train_batch_size = 1`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 | +| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 | + +With `train_batch_size = 4`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 | +| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 | + +
+ +
+ DeepSpeed (AdamW + CPU/Parameter offloading) + +**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100. + +With `train_batch_size = 1`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 | +| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 | + +With `train_batch_size = 4`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 | +| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 | + +
+ +> [!NOTE] +> - `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose not to perform validation/testing as part of the training script. +> +> - `memory_before_validation` is the true indicator of the peak memory required for training if you choose to not perform validation/testing. + + + + + + + + +
Slaying OOMs with PyTorch
+ +## TODOs + +- [x] Make scripts compatible with DDP +- [ ] Make scripts compatible with FSDP +- [x] Make scripts compatible with DeepSpeed +- [ ] vLLM-powered captioning script +- [x] Multi-resolution/frame support in `prepare_dataset.py` +- [ ] Analyzing traces for potential speedups and removing as many syncs as possible +- [x] Test scripts with memory-efficient optimizer from bitsandbytes +- [x] Test scripts with CPUOffloadOptimizer, etc. +- [ ] Test scripts with torchao quantization, and low bit memory optimizers (Currently errors with AdamW (8/4-bit torchao)) +- [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (Currently errors out) +- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (work with the authors to support backward pass, and optimize for A100) + +> [!IMPORTANT] +> Since our goal is to make the scripts as memory-friendly as possible we don't guarantee multi-GPU training. diff --git a/training/README_zh.md b/training/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..f18a62218a1b3f3d3b388cf1dfcb29c273b74326 --- /dev/null +++ b/training/README_zh.md @@ -0,0 +1,455 @@ +# CogVideoX Factory 🧪 + +[Read in English](./README.md) + +在 24GB GPU 内存下对 Cog 系列视频模型进行微调以实现自定义视频生成,支持多分辨率 ⚡️📼 + + + + + +
+ +## 快速开始 + +克隆此仓库并确保安装了相关依赖:`pip install -r requirements.txt`。 + +接着下载数据集: + +``` +# 安装 `huggingface_hub` +huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney +``` + +然后启动 LoRA 微调进行文本到视频的生成(根据您的选择修改不同的超参数、数据集根目录以及其他配置选项): + +``` +# 对 CogVideoX 模型进行文本到视频的 LoRA 微调 +./train_text_to_video_lora.sh + +# 对 CogVideoX 模型进行文本到视频的完整微调 +./train_text_to_video_sft.sh + +# 对 CogVideoX 模型进行图像到视频的 LoRA 微调 +./train_image_to_video_lora.sh +``` + +假设您的 LoRA 已保存并推送到 HF Hub,并命名为 `my-awesome-name/my-awesome-lora`,现在我们可以使用微调模型进行推理: + +``` +import torch +from diffusers import CogVideoXPipeline +from diffusers import export_to_video + +pipe = CogVideoXPipeline.from_pretrained( + "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16 +).to("cuda") ++ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name=["cogvideox-lora"]) ++ pipe.set_adapters(["cogvideox-lora"], [1.0]) + +video = pipe("").frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +你也可以在[这里](tests/test_lora_inference.py)来检查你的Lora是否正常挂载。 + +**注意:** 对于图像到视频的微调,您必须从 [这个分支](https://github.com/huggingface/diffusers/pull/9482) 安装 +diffusers(该分支为 CogVideoX 的图像到视频添加了 LoRA 加载支持)直到它被合并。 + +以下我们提供了更多探索此仓库选项的额外部分。所有这些都旨在尽可能降低内存需求,使视频模型的微调变得更易于访问。 + +## 训练 + +在开始训练之前,请你检查是否按照[数据集规范](assets/dataset_zh.md)准备好了数据集。 我们提供了适用于文本到视频 (text-to-video) 和图像到视频 (image-to-video) 生成的训练脚本,兼容 [CogVideoX 模型家族](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce)。训练可以通过 `train*.sh` 脚本启动,具体取决于你想要训练的任务。让我们以文本到视频的 LoRA 微调为例。 + +- 根据你的需求配置环境变量: + + ``` + export TORCH_LOGS="+dynamo,recompiles,graph_breaks" + export TORCHDYNAMO_VERBOSE=1 + export WANDB_MODE="offline" + export NCCL_P2P_DISABLE=1 + export TORCH_NCCL_ENABLE_MONITORING=0 + ``` + +- 配置用于训练的 GPU:`GPU_IDS="0,1"` + +- 选择训练的超参数。让我们以学习率和优化器类型的超参数遍历为例: + + ``` + LEARNING_RATES=("1e-4" "1e-3") + LR_SCHEDULES=("cosine_with_restarts") + OPTIMIZERS=("adamw" "adam") + MAX_TRAIN_STEPS=("3000") + ``` + +- 选择用于训练的 Accelerate 配置文件:`ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"` + 。我们在 `accelerate_configs/` 目录中提供了一些默认配置 - 单 GPU 编译/未编译、2x GPU DDP、DeepSpeed + 等。你也可以使用 `accelerate config --config_file my_config.yaml` 自定义配置文件。 + +- 指定字幕和视频的绝对路径以及列/文件。 + + ``` + DATA_ROOT="/path/to/my/datasets/video-dataset-disney" + CAPTION_COLUMN="prompt.txt" + VIDEO_COLUMN="videos.txt" + ``` + +- 运行实验,遍历不同的超参数: + ``` + for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done + done + ``` + +要了解不同参数的含义,你可以查看 [args](./training/args.py) 文件,或者使用 `--help` 运行训练脚本。 + +注意:训练脚本尚未在 MPS 上测试,因此性能和内存要求可能与下面的 CUDA 报告差异很大。 + +## 内存需求 + + + + + + + + + + + + + + + + + + + + + + + + + +
CogVideoX LoRA 微调
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
CogVideoX 全量微调
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
+ +支持和验证的训练内存优化包括: + +- `CPUOffloadOptimizer` 来自 [`torchao`](https://github.com/pytorch/ao) + 。你可以在[这里](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) + 阅读它的能力和局限性。简而言之,它允许你将可训练参数和梯度存储在 CPU 中,从而在 CPU 上进行优化步骤。这需要快速的 CPU + 优化器,如 `torch.optim.AdamW(fused=True)`,或者在优化步骤中应用 `torch.compile` + 。此外,建议不要在训练时对模型应用 `torch.compile`。梯度裁剪和累积目前还不支持。 +- 来自 [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers) + 的低位优化器。TODO:测试并使 [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) 能正常工作。 +- DeepSpeed Zero2:由于我们依赖 `accelerate` + ,请按照[此指南](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) 配置 `accelerate` 以启用 DeepSpeed + Zero2 优化训练。 + +> [!重要提示] +> 内存需求是运行 `training/prepare_dataset.py` +> +后报告的,该脚本将视频和字幕转换为潜在向量和嵌入。在训练期间,我们直接加载这些潜在向量和嵌入,不需要VAE或T5文本编码器。然而,如果执行验证/测试,则必须加载这些模块,并且会增加所需内存的数量。不进行验证/测试可以节省大量内存,这些内存可以用于较小显存的GPU上专注于训练。 +> +> 如果选择运行验证/测试,可以通过指定 `--enable_model_cpu_offload` 来为较低显存的GPU节省一些内存。 + +### LoRA微调 + +> [!重要提示] +> 图像到视频的LoRA微调的内存需求与文本到视频上的 `THUDM/CogVideoX-5b` 类似,因此没有明确报告。 +> +> 此外,为了准备I2V微调的测试图像,可以通过修改脚本实时生成它们,或使用以下命令从训练数据中提取一些帧: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> 或提供一个有效且可访问的图像URL。 + +
+ AdamW + +**注意:** 尝试在没有梯度检查点的情况下运行 CogVideoX-5b 即使在 A100(80 GB)上也会导致 OOM(内存不足)错误,因此内存需求尚未列出。 + +当 `train_batch_size = 1` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 | + +当 `train_batch_size = 4` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 | + +
+ +
+ AdamW (8-bit bitsandbytes) + +**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。 + +当 `train_batch_size = 1` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 | +| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 | +| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 | + +当 `train_batch_size = 4` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 | + +
+ +
+ AdamW + CPUOffloadOptimizer (with gradient offloading) + +**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。 + +当 `train_batch_size = 1` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 | +| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 | + +当 `train_batch_size = 4` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 | + +
+ +
+ DeepSpeed (AdamW + CPU/Parameter offloading) + +**注意:** 结果是在启用梯度检查点的情况下,使用 2x A100 运行时记录的。 + +当 `train_batch_size = 1` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 | +| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 | + +当 `train_batch_size = 4` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 | +| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 | + +
+ +### Full finetuning + +> [!注意] +> 图像到视频的完整微调内存需求与 `THUDM/CogVideoX-5b` 的文本到视频微调相似,因此没有单独列出。 +> +> 此外,要准备用于 I2V 微调的测试图像,你可以通过修改脚本实时生成图像,或者从你的训练数据中提取一些帧: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> 或提供一个有效且可访问的图像 URL。 + +> [!注意] +> 在没有使用梯度检查点的情况下运行完整微调,即使是在 A100(80GB)上,也会出现 OOM(内存不足)错误,因此未列出内存需求。 + +
+ AdamW + +当 `train_batch_size = 1` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +当 `train_batch_size = 4` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +
+ +
+ AdamW (8-bit 量化) + +当 `train_batch_size = 1` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 | +| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 | + +当 `train_batch_size = 4` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 | +| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 | + +
+ +
+ AdamW + CPUOffloadOptimizer(带有梯度卸载) + +当 `train_batch_size = 1` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 | +| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 | + +当 `train_batch_size = 4` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 | +| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 | + +
+ +
+ DeepSpeed(AdamW + CPU/参数卸载) + +**注意:** 结果是在启用 `gradient_checkpointing`(梯度检查点)功能,并在 2 台 A100 显卡上运行时报告的。 + +当 `train_batch_size = 1` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 | +| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 | + +当 `train_batch_size = 4` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 | +| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 | + +
+ +> [!注意] +> - `memory_after_validation`(验证后内存) 表示训练所需的峰值内存。这是因为除了存储训练过程中需要的激活、参数和梯度之外,还需要加载 + VAE 和文本编码器到内存中,并且执行推理操作也会消耗一定内存。为了减少训练所需的总内存,您可以选择在训练脚本中不执行验证/测试。 +> +> - 如果选择不进行验证/测试,`memory_before_validation`(验证前内存) 才是训练所需内存的真实指示器。 + + + + + + + + +
Slaying OOMs with PyTorch
+ +## 待办事项 + +- [x] 使脚本兼容 DDP +- [ ] 使脚本兼容 FSDP +- [x] 使脚本兼容 DeepSpeed +- [ ] 基于 vLLM 的字幕脚本 +- [x] 在 `prepare_dataset.py` 中支持多分辨率/帧数 +- [ ] 分析性能瓶颈并尽可能减少同步操作 +- [ ] 支持 QLoRA(优先),以及其他高使用率的 LoRA 方法 +- [x] 使用 bitsandbytes 的节省内存优化器测试脚本 +- [x] 使用 CPUOffloadOptimizer 等测试脚本 +- [ ] 使用 torchao 量化和低位内存优化器测试脚本(目前在 AdamW(8/4-bit torchao)上报错) +- [ ] 使用 AdamW(8-bit bitsandbytes)+ CPUOffloadOptimizer(带有梯度卸载)的测试脚本(目前报错) +- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (与作者合作支持反向传播,并针对 A100 进行优化) + +> [!重要] +> 由于我们的目标是使脚本尽可能节省内存,因此我们不保证支持多 GPU 训练。 \ No newline at end of file diff --git a/training/cogvideox/__init__.py b/training/cogvideox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/cogvideox/args.py b/training/cogvideox/args.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fed7da6d85df5384bd7e29a7786e15616967ca --- /dev/null +++ b/training/cogvideox/args.py @@ -0,0 +1,484 @@ +import argparse + + +def _get_model_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + +def _get_dataset_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--dataset_file", + type=str, + default=None, + help=("Path to a CSV file if loading prompts/video paths using this format."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", + type=str, + default=None, + help="Identifier token appended to the start of each prompt if provided.", + ) + parser.add_argument( + "--height_buckets", + nargs="+", + type=int, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--width_buckets", + nargs="+", + type=int, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--frame_buckets", + nargs="+", + type=int, + default=[49], + help="CogVideoX1.5 need to guarantee that ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, such as 53" + ) + parser.add_argument( + "--load_tensors", + action="store_true", + help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.", + ) + parser.add_argument( + "--random_flip", + type=float, + default=None, + help="If random horizontal flip augmentation is to be used, this should be the flip probability.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + + +def _get_validation_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=None, + help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=None, + help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + parser.add_argument( + "--enable_model_cpu_offload", + action="store_true", + default=False, + help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", + ) + + +def _get_training_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.") + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. " + "Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this " + "argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-sft", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default=None, + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", + type=int, + default=49, + help="All input videos will be truncated to these many frames.", + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + parser.add_argument( + "--noised_image_dropout", + type=float, + default=0.05, + help="Image condition dropout probability when finetuning image-to-video.", + ) + parser.add_argument( + "--ignore_learned_positional_embeddings", + action="store_true", + default=False, + help=( + "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting " + "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it " + "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why." + ), + ) + + +def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy", "came"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit", + action="store_true", + help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.", + ) + parser.add_argument( + "--use_4bit", + action="store_true", + help="Whether or not to use 4-bit optimizers from `torchao`.", + ) + parser.add_argument( + "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers." + ) + parser.add_argument( + "--beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta2", + type=float, + default=0.95, + help="The beta2 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument( + "--prodigy_decouple", + action="store_true", + help="Use AdamW style decoupled weight decay.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for optimizer.", + ) + parser.add_argument( + "--epsilon", + type=float, + default=1e-8, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--prodigy_use_bias_correction", + action="store_true", + help="Turn on Adam's bias correction.", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + parser.add_argument( + "--use_cpu_offload_optimizer", + action="store_true", + help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.", + ) + parser.add_argument( + "--offload_gradients", + action="store_true", + help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.", + ) + + +def _get_configuration_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--nccl_timeout", + type=int, + default=600, + help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.", + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") + + _get_model_args(parser) + _get_dataset_args(parser) + _get_training_args(parser) + _get_validation_args(parser) + _get_optimizer_args(parser) + _get_configuration_args(parser) + + return parser.parse_args() diff --git a/training/cogvideox/cogvideox_image_to_video_lora.py b/training/cogvideox/cogvideox_image_to_video_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c6be5aa14a67381cd5bd9116866eec801464d4 --- /dev/null +++ b/training/cogvideox/cogvideox_image_to_video_lora.py @@ -0,0 +1,1016 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import random +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX LoRA Finetune + + + +## Model description + +This is a lora finetune of the CogVideoX model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +import torch +from diffusers import CogVideoXImageToVideoPipeline +from diffusers.utils import export_to_video, load_image + +pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora") + +# The LoRA adapter weights are determined by what was used for training. +# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. +# It can be made lower or higher from what was used in training to decrease or amplify the effect +# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. +pipe.set_adapters(["cogvideox-lora"], [32 / 64]) + +image = load_image("/path/to/image.png") +video = pipe(image=image, prompt="{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "image-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXImageToVideoPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +def run_validation( + args: Dict[str, Any], + accelerator: Accelerator, + transformer, + scheduler, + model_config: Dict[str, Any], + weight_dtype: torch.dtype, +) -> None: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + images = [x["image"] for x in data[0]] + images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "images": images, + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + # These changes will also be required when trying to run inference with the trained lora + if args.ignore_learned_positional_embeddings: + del transformer.patch_embed.pos_embedding + transformer.patch_embed.use_learned_positional_embeddings = False + transformer.config.use_learned_positional_embeddings = False + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model = unwrap_model(accelerator, model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + CogVideoXImageToVideoPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + transformer_ = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + "image_to_video": True, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + images = batch["images"].to(accelerator.device, non_blocking=True) + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + image_noise_sigma = torch.normal( + mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype + ) + image_noise_sigma = torch.exp(image_noise_sigma) + noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] + image_latent_dist = vae.encode(noisy_images).latent_dist + + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + image_latent_dist = DiagonalGaussianDistribution(images) + latent_dist = DiagonalGaussianDistribution(videos) + + image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR + image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + video_latents = latent_dist.sample() * VAE_SCALING_FACTOR + video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) + latent_padding = image_latents.new_zeros(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if random.random() < args.noised_image_dropout: + image_latents = torch.zeros_like(image_latents) + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(video_latents) + batch_size, num_frames, num_channels, height, width = video_latents.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=accelerator.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps) + noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) + + ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None, + ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0) + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = video_latents + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # Checkpointing + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # Validation + should_run_validation = args.validation_prompt is not None and ( + args.validation_steps is not None and global_step % args.validation_steps == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + should_run_validation = args.validation_prompt is not None and ( + args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + CogVideoXImageToVideoPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/training/cogvideox/cogvideox_image_to_video_sft.py b/training/cogvideox/cogvideox_image_to_video_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..c40b6803e24e1272aef72fc4c494c1356093dea6 --- /dev/null +++ b/training/cogvideox/cogvideox_image_to_video_sft.py @@ -0,0 +1,947 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import random +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType, init_empty_weights +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX Full Finetune + + + +## Model description + +This is a full finetune of the CogVideoX model `{base_model}`. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "image-to-video", + "diffusers-training", + "diffusers", + "cogvideox", + "cogvideox-diffusers", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXImageToVideoPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +def run_validation( + args: Dict[str, Any], + accelerator: Accelerator, + transformer, + scheduler, + model_config: Dict[str, Any], + weight_dtype: torch.dtype, +) -> None: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + images = [x["image"] for x in data[0]] + images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "images": images, + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + if args.ignore_learned_positional_embeddings: + del transformer.patch_embed.pos_embedding + transformer.patch_embed.use_learned_positional_embeddings = False + transformer.config.use_learned_positional_embeddings = False + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + text_encoder.requires_grad_(False) + vae.requires_grad_(False) + transformer.requires_grad_(True) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model = unwrap_model(accelerator, model) + model.save_pretrained( + os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB" + ) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + transformer_ = None + init_under_meta = False + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + with init_empty_weights(): + transformer_ = CogVideoXTransformer3DModel.from_config( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + init_under_meta = True + + load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer")) + transformer_.register_to_config(**load_model.config) + transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta) + del load_model + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + cast_training_params([transformer], dtype=torch.float32) + + transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + "image_to_video": True, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-sft" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + images = batch["images"].to(accelerator.device, non_blocking=True) + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + image_noise_sigma = torch.normal( + mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype + ) + image_noise_sigma = torch.exp(image_noise_sigma) + noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] + image_latent_dist = vae.encode(noisy_images).latent_dist + + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + image_latent_dist = DiagonalGaussianDistribution(images) + latent_dist = DiagonalGaussianDistribution(videos) + + image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR + image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + video_latents = latent_dist.sample() * VAE_SCALING_FACTOR + video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) + latent_padding = image_latents.new_zeros(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if random.random() < args.noised_image_dropout: + image_latents = torch.zeros_like(image_latents) + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(video_latents) + batch_size, num_frames, num_channels, height, width = video_latents.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=accelerator.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps) + noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) + model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None, + ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0) + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = video_latents + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # Checkpointing + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # Validation + should_run_validation = args.validation_prompt is not None and ( + args.validation_steps is not None and global_step % args.validation_steps == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + should_run_validation = args.validation_prompt is not None and ( + args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + + transformer.save_pretrained( + os.path.join(args.output_dir, "transformer"), + safe_serialization=True, + max_shard_size="5GB", + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/training/cogvideox/cogvideox_text_to_video_lora.py b/training/cogvideox/cogvideox_text_to_video_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f2c5d5adad7976e2ed257bfc3761b115a10fef --- /dev/null +++ b/training/cogvideox/cogvideox_text_to_video_lora.py @@ -0,0 +1,955 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) # isort:skip + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX LoRA Finetune + + + +## Model description + +This is a lora finetune of the CogVideoX model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora") + +# The LoRA adapter weights are determined by what was used for training. +# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. +# It can be made lower or higher from what was used in training to decrease or amplify the effect +# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. +pipe.set_adapters(["cogvideox-lora"], [32 / 64]) + +video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model = unwrap_model(accelerator, model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + CogVideoXPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + transformer_ = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + latent_dist = DiagonalGaussianDistribution(videos) + + videos = latent_dist.sample() * VAE_SCALING_FACTOR + videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + model_input = videos + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=model_input.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + CogVideoXPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/training/cogvideox/cogvideox_text_to_video_sft.py b/training/cogvideox/cogvideox_text_to_video_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..f0afb64bc1aec06fe5c9ece8c1d59587e0d9b597 --- /dev/null +++ b/training/cogvideox/cogvideox_text_to_video_sft.py @@ -0,0 +1,917 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType, init_empty_weights +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) # isort:skip + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX Full Finetune + + + +## Model description + +This is a full finetune of the CogVideoX model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained("{repo_id}", torch_dtype=torch.bfloat16).to("cuda") + +video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For more details, checkout the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) for CogVideoX. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "cogvideox", + "cogvideox-diffusers", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + text_encoder.requires_grad_(False) + vae.requires_grad_(False) + transformer.requires_grad_(True) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model: CogVideoXTransformer3DModel + model = unwrap_model(accelerator, model) + model.save_pretrained( + os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB" + ) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + transformer_ = None + init_under_meta = False + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + with init_empty_weights(): + transformer_ = CogVideoXTransformer3DModel.from_config( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + init_under_meta = True + + load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer")) + transformer_.register_to_config(**load_model.config) + transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta) + del load_model + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-sft" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + latent_dist = DiagonalGaussianDistribution(videos) + + videos = latent_dist.sample() * VAE_SCALING_FACTOR + videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + model_input = videos + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=model_input.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=False, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + + transformer.save_pretrained( + os.path.join(args.output_dir, "transformer"), + safe_serialization=True, + max_shard_size="5GB", + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/training/cogvideox/dataset.py b/training/cogvideox/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ec47b0b33e089cad3935d0dd7951137f27db9452 --- /dev/null +++ b/training/cogvideox/dataset.py @@ -0,0 +1,428 @@ +import random +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torchvision.transforms as TT +from accelerate.logging import get_logger +from torch.utils.data import Dataset, Sampler +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import resize + + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + +HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +FRAME_BUCKETS = [16, 24, 32, 48, 64, 80] + + +class VideoDataset(Dataset): + def __init__( + self, + data_root: str, + dataset_file: Optional[str] = None, + caption_column: str = "text", + video_column: str = "video", + max_num_frames: int = 49, + id_token: Optional[str] = None, + height_buckets: List[int] = None, + width_buckets: List[int] = None, + frame_buckets: List[int] = None, + load_tensors: bool = False, + random_flip: Optional[float] = None, + image_to_video: bool = False, + ) -> None: + super().__init__() + + self.data_root = Path(data_root) + self.dataset_file = dataset_file + self.caption_column = caption_column + self.video_column = video_column + self.max_num_frames = max_num_frames + self.id_token = f"{id_token.strip()} " if id_token else "" + self.height_buckets = height_buckets or HEIGHT_BUCKETS + self.width_buckets = width_buckets or WIDTH_BUCKETS + self.frame_buckets = frame_buckets or FRAME_BUCKETS + self.load_tensors = load_tensors + self.random_flip = random_flip + self.image_to_video = image_to_video + + self.resolutions = [ + (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets + ] + + # Two methods of loading data are supported. + # - Using a CSV: caption_column and video_column must be some column in the CSV. One could + # make use of other columns too, such as a motion score or aesthetic score, by modifying the + # logic in CSV processing. + # - Using two files containing line-separate captions and relative paths to videos. + # For a more detailed explanation about preparing dataset format, checkout the README. + if dataset_file is None: + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_local_path() + else: + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_csv() + + if len(self.video_paths) != len(self.prompts): + raise ValueError( + f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + self.video_transforms = transforms.Compose( + [ + transforms.RandomHorizontalFlip(random_flip) + if random_flip + else transforms.Lambda(self.identity_transform), + transforms.Lambda(self.scale_transform), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + @staticmethod + def identity_transform(x): + return x + + @staticmethod + def scale_transform(x): + return x / 255.0 + + def __len__(self) -> int: + return len(self.video_paths) + + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + # Here, index is actually a list of data objects that we need to return. + # The BucketSampler should ideally return indices. But, in the sampler, we'd like + # to have information about num_frames, height and width. Since this is not stored + # as metadata, we need to read the video to get this information. You could read this + # information without loading the full video in memory, but we do it anyway. In order + # to not load the video twice (once to get the metadata, and once to return the loaded video + # based on sampled indices), we cache it in the BucketSampler. When the sampler is + # to yield, we yield the cache data instead of indices. So, this special check ensures + # that data is not loaded a second time. PRs are welcome for improvements. + return index + + if self.load_tensors: + image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index]) + + # This is hardcoded for now. + # The VAE's temporal compression ratio is 4. + # The VAE's spatial compression ratio is 8. + latent_num_frames = video_latents.size(1) + if latent_num_frames % 2 == 0: + num_frames = latent_num_frames * 4 + else: + num_frames = (latent_num_frames - 1) * 4 + 1 + + height = video_latents.size(2) * 8 + width = video_latents.size(3) * 8 + + return { + "prompt": prompt_embeds, + "image": image_latents, + "video": video_latents, + "video_metadata": { + "num_frames": num_frames, + "height": height, + "width": width, + }, + } + else: + image, video, _ = self._preprocess_video(self.video_paths[index]) + + return { + "prompt": self.id_token + self.prompts[index], + "image": image, + "video": video, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } + + def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: + if not self.data_root.exists(): + raise ValueError("Root folder for videos does not exist") + + prompt_path = self.data_root.joinpath(self.caption_column) + video_path = self.data_root.joinpath(self.video_column) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] + + if not self.load_tensors and any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: + df = pd.read_csv(self.dataset_file) + prompts = df[self.caption_column].tolist() + video_paths = df[self.video_column].tolist() + video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + r""" + Loads a single video, or latent and prompt embedding, based on initialization parameters. + + If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here, + F, C, H and W are the frames, channels, height and width of the input video. + + If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D]. + F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length + and embedding dimension of prompt embeddings. + """ + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + + indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames)) + frames = video_reader.get_batch(indices) + frames = frames[: self.max_num_frames].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + filename_without_ext = path.name.split(".")[0] + pt_filename = f"{filename_without_ext}.pt" + + # The current path is something like: /a/b/c/d/videos/00001.mp4 + # We need to reach: /a/b/c/d/video_latents/00001.pt + image_latents_path = path.parent.parent.joinpath("image_latents") + video_latents_path = path.parent.parent.joinpath("video_latents") + embeds_path = path.parent.parent.joinpath("prompt_embeds") + + if ( + not video_latents_path.exists() + or not embeds_path.exists() + or (self.image_to_video and not image_latents_path.exists()) + ): + raise ValueError( + f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present." + ) + + if self.image_to_video: + image_latent_filepath = image_latents_path.joinpath(pt_filename) + video_latent_filepath = video_latents_path.joinpath(pt_filename) + embeds_filepath = embeds_path.joinpath(pt_filename) + + if not video_latent_filepath.is_file() or not embeds_filepath.is_file(): + if self.image_to_video: + image_latent_filepath = image_latent_filepath.as_posix() + video_latent_filepath = video_latent_filepath.as_posix() + embeds_filepath = embeds_filepath.as_posix() + raise ValueError( + f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." + ) + + images = ( + torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None + ) + latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) + embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) + + return images, latents, embeds + + +class VideoDatasetWithResizing(VideoDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + nearest_frame_bucket = min( + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class VideoDatasetWithResizeAndRectangleCrop(VideoDataset): + def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.video_reshape_mode = video_reshape_mode + + def _resize_for_rectangle_crop(self, arr, image_size): + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + nearest_frame_bucket = min( + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class BucketSampler(Sampler): + r""" + PyTorch Sampler that groups 3D data by height, width and frames. + + Args: + data_source (`VideoDataset`): + A PyTorch dataset object that is an instance of `VideoDataset`. + batch_size (`int`, defaults to `8`): + The batch size to use for training. + shuffle (`bool`, defaults to `True`): + Whether or not to shuffle the data in each batch before dispatching to dataloader. + drop_last (`bool`, defaults to `False`): + Whether or not to drop incomplete buckets of data after completely iterating over all data + in the dataset. If set to True, only batches that have `batch_size` number of entries will + be yielded. If set to False, it is guaranteed that all data in the dataset will be processed + and batches that do not have `batch_size` number of entries will also be yielded. + """ + + def __init__( + self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False + ) -> None: + self.data_source = data_source + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + + self.buckets = {resolution: [] for resolution in data_source.resolutions} + + self._raised_warning_for_drop_last = False + + def __len__(self): + if self.drop_last and not self._raised_warning_for_drop_last: + self._raised_warning_for_drop_last = True + logger.warning( + "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." + ) + return (len(self.data_source) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + for index, data in enumerate(self.data_source): + video_metadata = data["video_metadata"] + f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + + self.buckets[(f, h, w)].append(data) + if len(self.buckets[(f, h, w)]) == self.batch_size: + if self.shuffle: + random.shuffle(self.buckets[(f, h, w)]) + yield self.buckets[(f, h, w)] + del self.buckets[(f, h, w)] + self.buckets[(f, h, w)] = [] + + if self.drop_last: + return + + for fhw, bucket in list(self.buckets.items()): + if len(bucket) == 0: + continue + if self.shuffle: + random.shuffle(bucket) + yield bucket + del self.buckets[fhw] + self.buckets[fhw] = [] \ No newline at end of file diff --git a/training/cogvideox/prepare_dataset.py b/training/cogvideox/prepare_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..12b29fa3de938b8890fb8e12511900e74666e016 --- /dev/null +++ b/training/cogvideox/prepare_dataset.py @@ -0,0 +1,669 @@ +#!/usr/bin/env python3 + +import argparse +import functools +import json +import os +import pathlib +import queue +import traceback +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from diffusers import AutoencoderKLCogVideoX +from diffusers.training_utils import set_seed +from diffusers.utils import export_to_video, get_logger +from torch.utils.data import DataLoader +from torchvision import transforms +from tqdm import tqdm +from transformers import T5EncoderModel, T5Tokenizer + + +import decord # isort:skip + +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip + + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def check_height(x: Any) -> int: + x = int(x) + if x % 16 != 0: + raise argparse.ArgumentTypeError( + f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria." + ) + return x + + +def check_width(x: Any) -> int: + x = int(x) + if x % 16 != 0: + raise argparse.ArgumentTypeError( + f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria." + ) + return x + + +def check_frames(x: Any) -> int: + x = int(x) + if x % 4 != 0 and x % 4 != 1: + raise argparse.ArgumentTypeError( + f"`--frames_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria." + ) + return x + + +def get_args() -> Dict[str, Any]: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_id", + type=str, + default="THUDM/CogVideoX-2b", + help="Hugging Face model ID to use for tokenizer, text encoder and VAE.", + ) + parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.") + parser.add_argument( + "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data." + ) + parser.add_argument( + "--caption_column", + type=str, + default="caption", + help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).", + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).", + ) + parser.add_argument( + "--id_token", + type=str, + default=None, + help="Identifier token appended to the start of each prompt if provided.", + ) + parser.add_argument( + "--height_buckets", + nargs="+", + type=check_height, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--width_buckets", + nargs="+", + type=check_width, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--frame_buckets", + nargs="+", + type=check_frames, + default=[49], + ) + parser.add_argument( + "--random_flip", + type=float, + default=None, + help="If random horizontal flip augmentation is to be used, this should be the flip probability.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default=None, + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) + parser.add_argument( + "--save_image_latents", + action="store_true", + help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to output directory where preprocessed videos/latents/embeddings will be saved.", + ) + parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.") + parser.add_argument( + "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings." + ) + parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.") + parser.add_argument( + "--save_latents_and_embeddings", + action="store_true", + help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.", + ) + parser.add_argument( + "--use_slicing", + action="store_true", + help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", + ) + parser.add_argument( + "--use_tiling", + action="store_true", + help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", + ) + parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.") + parser.add_argument( + "--num_decode_threads", + type=int, + default=0, + help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="fp32", + help="Data type to use when generating latents and prompt embeddings.", + ) + parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.") + parser.add_argument( + "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts." + ) + return parser.parse_args() + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompts: List[str], + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, + requires_grad: bool = False, +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompts, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompts, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +to_pil_image = transforms.ToPILImage(mode="RGB") + + +def save_image(image: torch.Tensor, path: pathlib.Path) -> None: + image = image.to(dtype=torch.float32).clamp(-1, 1) + image = to_pil_image(image.float()) + image.save(path) + + +def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None: + video = video.to(dtype=torch.float32).clamp(-1, 1) + video = [to_pil_image(frame) for frame in video] + export_to_video(video, path, fps=fps) + + +def save_prompt(prompt: str, path: pathlib.Path) -> None: + with open(path, "w", encoding="utf-8") as file: + file.write(prompt) + + +def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None: + with open(path, "w", encoding="utf-8") as file: + file.write(json.dumps(metadata)) + + +@torch.no_grad() +def serialize_artifacts( + batch_size: int, + fps: int, + images_dir: Optional[pathlib.Path] = None, + image_latents_dir: Optional[pathlib.Path] = None, + videos_dir: Optional[pathlib.Path] = None, + video_latents_dir: Optional[pathlib.Path] = None, + prompts_dir: Optional[pathlib.Path] = None, + prompt_embeds_dir: Optional[pathlib.Path] = None, + images: Optional[torch.Tensor] = None, + image_latents: Optional[torch.Tensor] = None, + videos: Optional[torch.Tensor] = None, + video_latents: Optional[torch.Tensor] = None, + prompts: Optional[List[str]] = None, + prompt_embeds: Optional[torch.Tensor] = None, +) -> None: + num_frames, height, width = videos.size(1), videos.size(3), videos.size(4) + metadata = [{"num_frames": num_frames, "height": height, "width": width}] + + data_folder_mapper_list = [ + (images, images_dir, lambda img, path: save_image(img[0], path), "png"), + (image_latents, image_latents_dir, torch.save, "pt"), + (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"), + (video_latents, video_latents_dir, torch.save, "pt"), + (prompts, prompts_dir, save_prompt, "txt"), + (prompt_embeds, prompt_embeds_dir, torch.save, "pt"), + (metadata, videos_dir, save_metadata, "txt"), + ] + filenames = [uuid.uuid4() for _ in range(batch_size)] + + for data, folder, save_fn, extension in data_folder_mapper_list: + if data is None: + continue + for slice, filename in zip(data, filenames): + if isinstance(slice, torch.Tensor): + slice = slice.clone().to("cpu") + path = folder.joinpath(f"{filename}.{extension}") + save_fn(slice, path) + + +def save_intermediates(output_queue: queue.Queue) -> None: + while True: + try: + item = output_queue.get(timeout=30) + if item is None: + break + serialize_artifacts(**item) + + except queue.Empty: + continue + + +@torch.no_grad() +def main(): + args = get_args() + set_seed(args.seed) + + output_dir = pathlib.Path(args.output_dir) + tmp_dir = output_dir.joinpath("tmp") + + output_dir.mkdir(parents=True, exist_ok=True) + tmp_dir.mkdir(parents=True, exist_ok=True) + + # Create task queue for non-blocking serializing of artifacts + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers) + save_future = save_thread.submit(save_intermediates, output_queue) + + # Initialize distributed processing + if "LOCAL_RANK" in os.environ: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + # Single GPU + local_rank = 0 + world_size = 1 + rank = 0 + torch.cuda.set_device(rank) + + # Create folders where intermediate tensors from each rank will be saved + images_dir = tmp_dir.joinpath(f"images/{rank}") + image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}") + videos_dir = tmp_dir.joinpath(f"videos/{rank}") + video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}") + prompts_dir = tmp_dir.joinpath(f"prompts/{rank}") + prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}") + + images_dir.mkdir(parents=True, exist_ok=True) + image_latents_dir.mkdir(parents=True, exist_ok=True) + videos_dir.mkdir(parents=True, exist_ok=True) + video_latents_dir.mkdir(parents=True, exist_ok=True) + prompts_dir.mkdir(parents=True, exist_ok=True) + prompt_embeds_dir.mkdir(parents=True, exist_ok=True) + + weight_dtype = DTYPE_MAPPING[args.dtype] + target_fps = args.target_fps + + # 1. Dataset + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": False, + "random_flip": args.random_flip, + "image_to_video": args.save_image_latents, + } + if args.video_reshape_mode is None: + dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + original_dataset_size = len(dataset) + + # Split data among GPUs + if world_size > 1: + samples_per_gpu = original_dataset_size // world_size + start_index = rank * samples_per_gpu + end_index = start_index + samples_per_gpu + if rank == world_size - 1: + end_index = original_dataset_size # Make sure the last GPU gets the remaining data + + # Slice the data + dataset.prompts = dataset.prompts[start_index:end_index] + dataset.video_paths = dataset.video_paths[start_index:end_index] + else: + pass + + rank_dataset_size = len(dataset) + + # 2. Dataloader + def collate_fn(data): + prompts = [x["prompt"] for x in data[0]] + + images = None + if args.save_image_latents: + images = [x["image"] for x in data[0]] + images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True) + + return { + "images": images, + "videos": videos, + "prompts": prompts, + } + + dataloader = DataLoader( + dataset, + batch_size=1, + sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # 3. Prepare models + device = f"cuda:{rank}" + + if args.save_latents_and_embeddings: + tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype + ) + text_encoder = text_encoder.to(device) + + vae = AutoencoderKLCogVideoX.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype) + vae = vae.to(device) + + if args.use_slicing: + vae.enable_slicing() + if args.use_tiling: + vae.enable_tiling() + + # 4. Compute latents and embeddings and save + if rank == 0: + iterator = tqdm( + dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size + ) + else: + iterator = dataloader + + for step, batch in enumerate(iterator): + try: + images = None + image_latents = None + video_latents = None + prompt_embeds = None + + if args.save_image_latents: + images = batch["images"].to(device, non_blocking=True) + images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + + videos = batch["videos"].to(device, non_blocking=True) + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + + prompts = batch["prompts"] + + # Encode videos & images + if args.save_latents_and_embeddings: + if args.use_slicing: + if args.save_image_latents: + encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)] + image_latents = torch.cat(encoded_slices) + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)] + video_latents = torch.cat(encoded_slices) + + else: + if args.save_image_latents: + image_latents = vae._encode(images) + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + video_latents = vae._encode(videos) + + video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + # Encode prompts + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + args.max_sequence_length, + device, + weight_dtype, + requires_grad=False, + ) + + if images is not None: + images = (images.permute(0, 2, 1, 3, 4) + 1) / 2 + + videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2 + + output_queue.put( + { + "batch_size": len(prompts), + "fps": target_fps, + "images_dir": images_dir, + "image_latents_dir": image_latents_dir, + "videos_dir": videos_dir, + "video_latents_dir": video_latents_dir, + "prompts_dir": prompts_dir, + "prompt_embeds_dir": prompt_embeds_dir, + "images": images, + "image_latents": image_latents, + "videos": videos, + "video_latents": video_latents, + "prompts": prompts, + "prompt_embeds": prompt_embeds, + } + ) + + except Exception: + print("-------------------------") + print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}") + traceback.print_exc() + print("-------------------------") + + # 5. Complete distributed processing + if world_size > 1: + dist.barrier() + dist.destroy_process_group() + + output_queue.put(None) + save_thread.shutdown(wait=True) + save_future.result() + + # 6. Combine results from each rank + if rank == 0: + print( + f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`" + ) + + # Move files from each rank to common directory + for subfolder, extension in [ + ("images", "png"), + ("image_latents", "pt"), + ("videos", "mp4"), + ("video_latents", "pt"), + ("prompts", "txt"), + ("prompt_embeds", "pt"), + ("videos", "txt"), + ]: + tmp_subfolder = tmp_dir.joinpath(subfolder) + combined_subfolder = output_dir.joinpath(subfolder) + combined_subfolder.mkdir(parents=True, exist_ok=True) + pattern = f"*.{extension}" + + for file in tmp_subfolder.rglob(pattern): + file.replace(combined_subfolder / file.name) + + # Remove temporary directories + def rmdir_recursive(dir: pathlib.Path) -> None: + for child in dir.iterdir(): + if child.is_file(): + child.unlink() + else: + rmdir_recursive(child) + dir.rmdir() + + rmdir_recursive(tmp_dir) + + # Combine prompts and videos into individual text files and single jsonl + prompts_folder = output_dir.joinpath("prompts") + prompts = [] + stems = [] + + for filename in prompts_folder.rglob("*.txt"): + with open(filename, "r") as file: + prompts.append(file.read().strip()) + stems.append(filename.stem) + + prompts_txt = output_dir.joinpath("prompts.txt") + videos_txt = output_dir.joinpath("videos.txt") + data_jsonl = output_dir.joinpath("data.jsonl") + + with open(prompts_txt, "w") as file: + for prompt in prompts: + file.write(f"{prompt}\n") + + with open(videos_txt, "w") as file: + for stem in stems: + file.write(f"videos/{stem}.mp4\n") + + with open(data_jsonl, "w") as file: + for prompt, stem in zip(prompts, stems): + video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt") + with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file: + metadata = json.loads(metadata_file.read()) + + data = { + "prompt": prompt, + "prompt_embed": f"prompt_embeds/{stem}.pt", + "image": f"images/{stem}.png", + "image_latent": f"image_latents/{stem}.pt", + "video": f"videos/{stem}.mp4", + "video_latent": f"video_latents/{stem}.pt", + "metadata": metadata, + } + file.write(json.dumps(data) + "\n") + + print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`") + + +if __name__ == "__main__": + main() diff --git a/training/cogvideox/text_encoder/__init__.py b/training/cogvideox/text_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09f9e8cd8466c60a9c266879223f8fe4f304a524 --- /dev/null +++ b/training/cogvideox/text_encoder/__init__.py @@ -0,0 +1 @@ +from .text_encoder import compute_prompt_embeddings diff --git a/training/cogvideox/text_encoder/text_encoder.py b/training/cogvideox/text_encoder/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9237875dd621baca5979a5f685288486ef1532d5 --- /dev/null +++ b/training/cogvideox/text_encoder/text_encoder.py @@ -0,0 +1,99 @@ +from typing import List, Optional, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str, + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, + requires_grad: bool = False, +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds diff --git a/training/cogvideox/utils.py b/training/cogvideox/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d03b14217403908233405a2005acf2f8703431c3 --- /dev/null +++ b/training/cogvideox/utils.py @@ -0,0 +1,260 @@ +import gc +import inspect +from typing import Optional, Tuple, Union + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.utils.torch_utils import is_compiled_module + + +logger = get_logger(__name__) + + +def get_optimizer( + params_to_optimize, + optimizer_name: str = "adam", + learning_rate: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.95, + beta3: float = 0.98, + epsilon: float = 1e-8, + weight_decay: float = 1e-4, + prodigy_decouple: bool = False, + prodigy_use_bias_correction: bool = False, + prodigy_safeguard_warmup: bool = False, + use_8bit: bool = False, + use_4bit: bool = False, + use_torchao: bool = False, + use_deepspeed: bool = False, + use_cpu_offload_optimizer: bool = False, + offload_gradients: bool = False, +) -> torch.optim.Optimizer: + optimizer_name = optimizer_name.lower() + + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=learning_rate, + betas=(beta1, beta2), + eps=epsilon, + weight_decay=weight_decay, + ) + + if use_8bit and use_4bit: + raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") + + if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: + try: + import torchao + + torchao.__version__ + except ImportError: + raise ImportError( + "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." + ) + + if not use_torchao and use_4bit: + raise ValueError("4-bit Optimizers are only supported with torchao.") + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy", "came"] + if optimizer_name not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." + ) + optimizer_name = "adamw" + + if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: + raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") + + if use_8bit: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if optimizer_name == "adamw": + if use_torchao: + from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit + + optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW + else: + optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "adam": + if use_torchao: + from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit + + optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam + else: + optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + init_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "beta3": beta3, + "eps": epsilon, + "weight_decay": weight_decay, + "decouple": prodigy_decouple, + "use_bias_correction": prodigy_use_bias_correction, + "safeguard_warmup": prodigy_safeguard_warmup, + } + + elif optimizer_name == "came": + try: + import came_pytorch + except ImportError: + raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") + + optimizer_class = came_pytorch.CAME + + init_kwargs = { + "lr": learning_rate, + "eps": (1e-30, 1e-16), + "betas": (beta1, beta2, beta3), + "weight_decay": weight_decay, + } + + if use_cpu_offload_optimizer: + from torchao.prototype.low_bit_optim import CPUOffloadOptimizer + + if "fused" in inspect.signature(optimizer_class.__init__).parameters: + init_kwargs.update({"fused": True}) + + optimizer = CPUOffloadOptimizer( + params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs + ) + else: + optimizer = optimizer_class(params_to_optimize, **init_kwargs) + + return optimizer + + +def get_gradient_norm(parameters): + norm = 0 + for param in parameters: + if param.grad is None: + continue + local_norm = param.grad.detach().data.norm(2) + norm += local_norm.item() ** 2 + norm = norm**0.5 + return norm + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + patch_size_t: int = None, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + if patch_size_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def reset_memory(device: Union[str, torch.device]) -> None: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.reset_accumulated_memory_stats(device) + + +def print_memory(device: Union[str, torch.device]) -> None: + memory_allocated = torch.cuda.memory_allocated(device) / 1024**3 + max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 + max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 + print(f"{memory_allocated=:.3f} GB") + print(f"{max_memory_allocated=:.3f} GB") + print(f"{max_memory_reserved=:.3f} GB") + + +def unwrap_model(accelerator: Accelerator, model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model diff --git a/training/mochi-1/README.md b/training/mochi-1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..13d0c391a58597aaf34a6b5579515420c74893c0 --- /dev/null +++ b/training/mochi-1/README.md @@ -0,0 +1,111 @@ +# Simple Mochi-1 finetuner + + + + + + + + + + +
Dataset Sample Test Sample
+ +Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨 + +We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for things LoRA in our implementation. + +**Updates** + +December 1 2024: Support for checkpoint saving and loading. + +## Getting started + +Install the dependencies: `pip install -r requirements.txt`. Also make sure your `diffusers` installation is from the current `main`. + +Download a demo dataset: + +```bash +huggingface-cli download \ + --repo-type dataset sayakpaul/video-dataset-disney-organized \ + --local-dir video-dataset-disney-organized +``` + +The dataset follows the directory structure expected by the subsequent scripts. In particular, it follows what's prescribed [here](https://github.com/genmoai/mochi/tree/main/demos/fine_tuner#1-collect-your-videos-and-captions): + +```bash +video_1.mp4 +video_1.txt -- One-paragraph description of video_1 +video_2.mp4 +video_2.txt -- One-paragraph description of video_2 +... +``` + +Then run (be sure to check the paths accordingly): + +```bash +bash prepare_dataset.sh +``` + +We can adjust `num_frames` and `resolution`. By default, in `prepare_dataset.sh`, we use `--force_upsample`. This means if the original video resolution is smaller than the requested resolution, we will upsample the video. + +> [!IMPORTANT] +> It's important to have a resolution of at least 480x848 to satisy Mochi-1's requirements. + +Now, we're ready to fine-tune. To launch, run: + +```bash +bash train.sh +``` + +You can disable intermediate validation by: + +```diff +- --validation_prompt "..." \ +- --validation_prompt_separator ::: \ +- --num_validation_videos 1 \ +- --validation_epochs 1 \ +``` + +We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM. + +To use the LoRA checkpoint: + +```py +from diffusers import MochiPipeline +from diffusers.utils import export_to_video +import torch + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") +pipe.load_lora_weights("path-to-lora") +pipe.enable_model_cpu_offload() + +pipeline_args = { + "prompt": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions", + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": 480, + "width": 848, + "max_sequence_length": 256, + "output_type": "np", +} + +with torch.autocast("cuda", torch.bfloat16) + video = pipe(**pipeline_args).frames[0] +export_to_video(video) +``` + +## Known limitations + +(Contributions are welcome 🤗) + +Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below: + +* No support for distributed training. +* `train_batch_size > 1` are supported but can potentially lead to OOMs because we currently don't have gradient accumulation support. +* No support for 8bit optimizers (but should be relatively easy to add). + +**Misc**: + +* We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033). +* `embed.py` script is non-batched. diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb420e73540dee050a65a0bb62c5c8c77e8d9b0 --- /dev/null +++ b/training/mochi-1/args.py @@ -0,0 +1,268 @@ +""" +Default values taken from +https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml +when applicable. +""" + +import argparse + + +def _get_model_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--cast_dit", + action="store_true", + help="If we should cast DiT params to a lower precision.", + ) + parser.add_argument( + "--compile_dit", + action="store_true", + help="If we should compile the DiT.", + ) + + +def _get_dataset_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--caption_dropout", + type=float, + default=None, + help=("Probability to drop out captions randomly."), + ) + + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + + +def _get_validation_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + parser.add_argument( + "--enable_model_cpu_offload", + action="store_true", + default=False, + help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", + ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="FPS to use when serializing the output videos.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + ) + parser.add_argument( + "--width", + type=int, + default=848, + ) + + +def _get_training_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.") + parser.add_argument( + "--lora_alpha", + type=int, + default=16, + help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", + ) + parser.add_argument( + "--target_modules", + nargs="+", + type=str, + default=["to_k", "to_q", "to_v", "to_out.0"], + help="Target modules to train LoRA for.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="mochi-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=2e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=200, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=None, + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + ) + + +def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.01, + help="Weight decay to use for optimizer.", + ) + + +def _get_configuration_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--report_to", type=str, default=None, help="If logging to wandb.") + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.") + + _get_model_args(parser) + _get_dataset_args(parser) + _get_training_args(parser) + _get_validation_args(parser) + _get_optimizer_args(parser) + _get_configuration_args(parser) + + return parser.parse_args() diff --git a/training/mochi-1/dataset_simple.py b/training/mochi-1/dataset_simple.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc6153be09a6d4b18b0edb482897e53cab7411d --- /dev/null +++ b/training/mochi-1/dataset_simple.py @@ -0,0 +1,50 @@ +""" +Taken from +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py +""" + +from pathlib import Path + +import click +import torch +from torch.utils.data import DataLoader, Dataset + + +def load_to_cpu(x): + return torch.load(x, map_location=torch.device("cpu"), weights_only=True) + + +class LatentEmbedDataset(Dataset): + def __init__(self, file_paths, repeat=1): + self.items = [ + (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt")) + for p in file_paths + if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file() + ] + self.items = self.items * repeat + print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.") + + def __len__(self): + return len(self.items) + + def __getitem__(self, idx): + latent_path, embed_path = self.items[idx] + return load_to_cpu(latent_path), load_to_cpu(embed_path) + + +@click.command() +@click.argument("directory", type=click.Path(exists=True, file_okay=False)) +def process_videos(directory): + dir_path = Path(directory) + mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")] + assert mp4_files, f"No mp4 files found" + + dataset = LatentEmbedDataset(mp4_files) + dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + + for latents, embeds in dataloader: + print([(k, v.shape) for k, v in latents.items()]) + + +if __name__ == "__main__": + process_videos() diff --git a/training/mochi-1/embed.py b/training/mochi-1/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..ec35ebb061618e133f093f459df525a3cf4567b3 --- /dev/null +++ b/training/mochi-1/embed.py @@ -0,0 +1,111 @@ +""" +Adapted from: +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py +""" + +import click +import torch +import torchvision +from pathlib import Path +from diffusers import AutoencoderKLMochi, MochiPipeline +from transformers import T5EncoderModel, T5Tokenizer +from tqdm.auto import tqdm + + +def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str): + T, H, W = [int(s) for s in shape.split("x")] + assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6" + video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs") + fps = metadata["video_fps"] + video = video.permute(3, 0, 1, 2) + og_shape = video.shape + assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}" + assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}" + assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}" + if video.shape[1] > T: + video = video[:, :T] + print(f"Trimmed video from {og_shape[1]} to first {T} frames") + video = video.unsqueeze(0) + video = video.float() / 127.5 - 1.0 + video = video.to(model.device) + + assert video.ndim == 5 + + with torch.inference_mode(): + with torch.autocast("cuda", dtype=torch.bfloat16): + ldist = model._encode(video) + + torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt")) + + +@click.command() +@click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path)) +@click.option( + "--model_id", + type=str, + help="Repo id. Should be genmo/mochi-1-preview", + default="genmo/mochi-1-preview", +) +@click.option("--shape", default="163x480x848", help="Shape of the video to encode") +@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.") +def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None: + """Process all videos and captions in a directory using a single GPU.""" + # comment out when running on unsupported hardware + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Get all video paths + video_paths = list(output_dir.glob("**/*.mp4")) + if not video_paths: + print(f"No MP4 files found in {output_dir}") + return + + text_paths = list(output_dir.glob("**/*.txt")) + if not text_paths: + print(f"No text files found in {output_dir}") + return + + # load the models + vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda") + text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder") + tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer") + pipeline = MochiPipeline.from_pretrained( + model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None + ).to("cuda") + + for idx, video_path in tqdm(enumerate(sorted(video_paths))): + print(f"Processing {video_path}") + try: + if video_path.with_suffix(".latent.pt").exists() and not overwrite: + print(f"Skipping {video_path}") + continue + + # encode videos. + encode_videos(vae, vid_path=video_path, shape=shape) + + # embed captions. + prompt_path = Path("/".join(str(video_path).split(".")[:-1]) + ".txt") + embed_path = prompt_path.with_suffix(".embed.pt") + + if embed_path.exists() and not overwrite: + print(f"Skipping {prompt_path} - embeddings already exist") + continue + + with open(prompt_path) as f: + text = f.read().strip() + with torch.inference_mode(): + conditioning = pipeline.encode_prompt(prompt=[text]) + + conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]} + torch.save(conditioning, embed_path) + + except Exception as e: + import traceback + + traceback.print_exc() + print(f"Error processing {video_path}: {str(e)}") + + +if __name__ == "__main__": + batch_process() diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..c424b4e5913cbfc64172951573ec1d7eb578b5d5 --- /dev/null +++ b/training/mochi-1/prepare_dataset.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +GPU_ID=0 +VIDEO_DIR=video-dataset-disney-organized +OUTPUT_DIR=videos_prepared +NUM_FRAMES=37 +RESOLUTION=480x848 + +# Extract width and height from RESOLUTION +WIDTH=$(echo $RESOLUTION | cut -dx -f1) +HEIGHT=$(echo $RESOLUTION | cut -dx -f2) + +python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample + +CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT} diff --git a/training/mochi-1/requirements.txt b/training/mochi-1/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a03ceeb0ee286d80009b9e2dc39d801d18603e4c --- /dev/null +++ b/training/mochi-1/requirements.txt @@ -0,0 +1,8 @@ +peft +transformers +wandb +torch +torchvision +av==11.0.0 +moviepy==1.0.3 +click \ No newline at end of file diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..af1ce6268dbc8cc88d2b02d1b292c285ceb43b96 --- /dev/null +++ b/training/mochi-1/text_to_video_lora.py @@ -0,0 +1,592 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +from glob import glob +import math +import os +import torch.nn.functional as F +import numpy as np +from pathlib import Path +from typing import Any, Dict, Tuple, List + +import torch +import wandb +from diffusers import FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + + +from args import get_args # isort:skip +from dataset_simple import LatentEmbedDataset + +import sys +from utils import print_memory, reset_memory # isort:skip + + +# Taken from +# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139 +def get_cosine_annealing_lr_scheduler( + optimizer: torch.optim.Optimizer, + warmup_steps: int, + total_steps: int, +): + def lr_lambda(step): + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) + else: + return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps))) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=30, +): + widget_dict = [] + if videos is not None and len(videos) > 0: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"final_video_{i}.mp4"}, + } + ) + + model_description = f""" +# Mochi-1 Preview LoRA Finetune + + + +## Model description + +This is a lora finetune of the Mochi-1 preview model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +from diffusers import MochiPipeline +from diffusers.utils import export_to_video +import torch + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") +pipe.load_lora_weights("CHANGE_ME") +pipe.enable_model_cpu_offload() + +with torch.autocast("cuda", torch.bfloat16): + video = pipe( + prompt="CHANGE_ME", + guidance_scale=6.0, + num_inference_steps=64, + height=480, + width=848, + max_sequence_length=256, + output_type="np" + ).frames[0] +export_to_video(video) +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="apache-2.0", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "mochi-1-preview", + "mochi-1-preview-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipe: MochiPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + epoch, + wandb_run: str = None, + is_final_validation: bool = False, +): + print( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + phase_name = "test" if is_final_validation else "validation" + + if not args.enable_model_cpu_offload: + pipe = pipe.to("cuda") + + # run inference + generator = torch.manual_seed(args.seed) if args.seed else None + + videos = [] + with torch.autocast("cuda", torch.bfloat16, cache_enabled=False): + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=30) + video_filenames.append(filename) + + if wandb_run: + wandb.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30) + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +# Adapted from the original code: +# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578 +def cast_dit(model, dtype): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert any( + n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"] + ), f"Unexpected linear layer: {name}" + module.to(dtype=dtype) + elif isinstance(module, torch.nn.Conv2d): + module.to(dtype=dtype) + return model + + +def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path): + lora_state_dict = get_peft_model_state_dict(model) + torch.save( + { + "state_dict": lora_state_dict, + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "global_step": global_step, + }, + checkpoint_path, + ) + + +class CollateFunction: + def __init__(self, caption_dropout: float = None) -> None: + self.caption_dropout = caption_dropout + + def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]: + ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0) + z = DiagonalGaussianDistribution(ldists).sample() + assert torch.isfinite(z).all() + + # Sample noise which we will add to the samples. + eps = torch.randn_like(z) + sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32) + + prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0) + prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0) + if self.caption_dropout and random.random() < self.caption_dropout: + prompt_embeds.zero_() + prompt_attention_mask = prompt_attention_mask.long() + prompt_attention_mask.zero_() + prompt_attention_mask = prompt_attention_mask.bool() + + return dict( + z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask + ) + + +def main(args): + if not torch.cuda.is_available(): + raise ValueError("Not supported without CUDA.") + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + # Handle the repository creation + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + transformer = MochiTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + + transformer.requires_grad_(False) + transformer.to("cuda") + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.cast_dit: + transformer = cast_dit(transformer, torch.bfloat16) + if args.compile_dit: + transformer.compile() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights="gaussian", + target_modules=args.target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = args.learning_rate * args.train_batch_size + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + # Prepare optimizer + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters) + optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay) + + # Dataset and DataLoader + train_vids = list(sorted(glob(f"{args.data_root}/*.mp4"))) + train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")] + print(f"Found {len(train_vids)} training videos in {args.data_root}") + assert len(train_vids) > 0, f"No training data found in {args.data_root}" + + collate_fn = CollateFunction(caption_dropout=args.caption_dropout) + train_dataset = LatentEmbedDataset(train_vids, repeat=1) + train_dataloader = DataLoader( + train_dataset, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # LR scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = len(train_dataloader) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_cosine_annealing_lr_scheduler( + optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = len(train_dataloader) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + wandb_run = None + if args.report_to == "wandb": + tracker_name = args.tracker_name or "mochi-1-lora" + wandb_run = wandb.init(project=tracker_name, config=vars(args)) + + # Resume from checkpoint if specified + if args.resume_from_checkpoint: + checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu", weights_only=True) + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + if "optimizer" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer"]) + if "lr_scheduler" in checkpoint: + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + set_peft_model_state_dict(transformer, checkpoint["state_dict"]) + + print(f"Resuming from checkpoint: {args.resume_from_checkpoint}") + print(f"Resuming from global step: {global_step}") + else: + global_step = 0 + + print("===== Memory before training =====") + reset_memory("cuda") + print_memory("cuda") + + # Train! + total_batch_size = args.train_batch_size + print("***** Running training *****") + print(f" Num trainable parameters = {num_trainable_parameters}") + print(f" Num examples = {len(train_dataset)}") + print(f" Num batches each epoch = {len(train_dataloader)}") + print(f" Num epochs = {args.num_train_epochs}") + print(f" Instantaneous batch size per device = {args.train_batch_size}") + print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + print(f" Total optimization steps = {args.max_train_steps}") + + first_epoch = 0 + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=global_step, + desc="Steps", + ) + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + with torch.no_grad(): + z = batch["z"].to("cuda") + eps = batch["eps"].to("cuda") + sigma = batch["sigma"].to("cuda") + prompt_embeds = batch["prompt_embeds"].to("cuda") + prompt_attention_mask = batch["prompt_attention_mask"].to("cuda") + + sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1] + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps + ut = z - eps + + # (1 - sigma) because of + # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656 + # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation. + timesteps = (1 - sigma) * scheduler.config.num_train_timesteps + + with torch.autocast("cuda", torch.bfloat16): + model_pred = transformer( + hidden_states=z_sigma, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + return_dict=False, + )[0] + assert model_pred.shape == z.shape + loss = F.mse_loss(model_pred.float(), ut.float()) + loss.backward() + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + progress_bar.update(1) + global_step += 1 + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs = {"loss": loss.detach().item(), "lr": last_lr} + progress_bar.set_postfix(**logs) + if wandb_run: + wandb_run.log(logs, step=global_step) + + if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0: + print(f"Saving checkpoint at step {global_step}") + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt") + save_checkpoint( + transformer, + optimizer, + lr_scheduler, + global_step, + checkpoint_path, + ) + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + print("===== Memory before validation =====") + print_memory("cuda") + + transformer.eval() + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=transformer, + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": args.height, + "width": args.width, + "max_sequence_length": 256, + } + log_validation( + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + wandb_run=wandb_run, + ) + + print("===== Memory after validation =====") + print_memory("cuda") + reset_memory("cuda") + + del pipe.text_encoder + del pipe.vae + del pipe + gc.collect() + torch.cuda.empty_cache() + + transformer.train() + + transformer.eval() + transformer_lora_layers = get_peft_model_state_dict(transformer) + MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers) + + # Cleanup trained models to save memory + del transformer + + gc.collect() + torch.cuda.empty_cache() + + # Final test inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + print("===== Memory before testing =====") + print_memory("cuda") + reset_memory("cuda") + + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") + pipe.set_adapters(["mochi-lora"], [lora_scaling]) + + # Run inference + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": args.height, + "width": args.width, + "max_sequence_length": 256, + } + + video = log_validation( + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + wandb_run=wandb_run, + is_final_validation=True, + ) + validation_outputs.extend(video) + + print("===== Memory after testing =====") + print_memory("cuda") + reset_memory("cuda") + torch.cuda.synchronize("cuda") + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["*.bin"], + ) + print(f"Params pushed to {repo_id}.") + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..2c378e2e1b7c8bc262ce74dbea01e8b6ed4994e4 --- /dev/null +++ b/training/mochi-1/train.sh @@ -0,0 +1,37 @@ +#!/bin/bash +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +DATA_ROOT="videos_prepared" +MODEL="genmo/mochi-1-preview" +OUTPUT_PATH="mochi-lora" + +cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora.py \ + --pretrained_model_name_or_path $MODEL \ + --cast_dit \ + --data_root $DATA_ROOT \ + --seed 42 \ + --output_dir $OUTPUT_PATH \ + --train_batch_size 1 \ + --dataloader_num_workers 4 \ + --pin_memory \ + --caption_dropout 0.1 \ + --max_train_steps 2000 \ + --gradient_checkpointing \ + --enable_slicing \ + --enable_tiling \ + --enable_model_cpu_offload \ + --optimizer adamw \ + --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ + --allow_tf32 \ + --report_to wandb \ + --push_to_hub" + +echo "Running command: $cmd" +eval $cmd +echo -ne "-------------------- Finished executing script --------------------\n\n" \ No newline at end of file diff --git a/training/mochi-1/trim_and_crop_videos.py b/training/mochi-1/trim_and_crop_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6f411d9e5fc6133232405f38b0ec5d7a627765 --- /dev/null +++ b/training/mochi-1/trim_and_crop_videos.py @@ -0,0 +1,126 @@ +""" +Adapted from: +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py +""" + +from pathlib import Path +import shutil + +import click +from moviepy.editor import VideoFileClip +from tqdm import tqdm + + +@click.command() +@click.argument("folder", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_folder", type=click.Path(dir_okay=True)) +@click.option("--num_frames", "-f", type=float, default=30, help="Number of frames") +@click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution") +@click.option("--force_upsample", is_flag=True, help="Force upsample.") +def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample): + """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution""" + input_path = Path(folder) + output_path = Path(output_folder) + output_path.mkdir(parents=True, exist_ok=True) + + # Parse target resolution + target_height, target_width = map(int, resolution.split("x")) + + # Calculate duration + duration = (num_frames / 30) + 0.09 + + # Find all MP4 and MOV files + video_files = ( + list(input_path.rglob("*.mp4")) + + list(input_path.rglob("*.MOV")) + + list(input_path.rglob("*.mov")) + + list(input_path.rglob("*.MP4")) + ) + + for file_path in tqdm(video_files): + try: + relative_path = file_path.relative_to(input_path) + output_file = output_path / relative_path.with_suffix(".mp4") + output_file.parent.mkdir(parents=True, exist_ok=True) + + click.echo(f"Processing: {file_path}") + video = VideoFileClip(str(file_path)) + + # Skip if video is too short + if video.duration < duration: + click.echo(f"Skipping {file_path} as it is too short") + continue + + # Skip if target resolution is larger than input + if target_width > video.w or target_height > video.h: + if force_upsample: + click.echo( + f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video." + ) + video = video.resize(width=target_width, height=target_height) + else: + click.echo( + f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}" + ) + continue + + # First truncate duration + truncated = video.subclip(0, duration) + + # Calculate crop dimensions to maintain aspect ratio + target_ratio = target_width / target_height + current_ratio = truncated.w / truncated.h + + if current_ratio > target_ratio: + # Video is wider than target ratio - crop width + new_width = int(truncated.h * target_ratio) + x1 = (truncated.w - new_width) // 2 + final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height)) + else: + # Video is taller than target ratio - crop height + new_height = int(truncated.w / target_ratio) + y1 = (truncated.h - new_height) // 2 + final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height)) + + # Set output parameters for consistent MP4 encoding + output_params = { + "codec": "libx264", + "audio": False, # Disable audio + "preset": "medium", # Balance between speed and quality + "bitrate": "5000k", # Adjust as needed + } + + # Set FPS to 30 + final = final.set_fps(30) + + # Check for a corresponding .txt file + txt_file_path = file_path.with_suffix(".txt") + if txt_file_path.exists(): + output_txt_file = output_path / relative_path.with_suffix(".txt") + output_txt_file.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(txt_file_path, output_txt_file) + click.echo(f"Copied {txt_file_path} to {output_txt_file}") + else: + # Print warning in bold yellow with a warning emoji + click.echo( + f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m" + ) + output_txt_file = output_path / relative_path.with_suffix(".txt") + output_txt_file.parent.mkdir(parents=True, exist_ok=True) + output_txt_file.touch() + + # Write the output file + final.write_videofile(str(output_file), **output_params) + + # Clean up + video.close() + truncated.close() + final.close() + + except Exception as e: + click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True) + raise + + +if __name__ == "__main__": + truncate_videos() diff --git a/training/mochi-1/utils.py b/training/mochi-1/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76fe35c2036da2483aa431d354226ccb1c16b9bc --- /dev/null +++ b/training/mochi-1/utils.py @@ -0,0 +1,22 @@ +import gc +import inspect +from typing import Optional, Tuple, Union + +import torch + +logger = get_logger(__name__) + +def reset_memory(device: Union[str, torch.device]) -> None: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.reset_accumulated_memory_stats(device) + + +def print_memory(device: Union[str, torch.device]) -> None: + memory_allocated = torch.cuda.memory_allocated(device) / 1024**3 + max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 + max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 + print(f"{memory_allocated=:.3f} GB") + print(f"{max_memory_allocated=:.3f} GB") + print(f"{max_memory_reserved=:.3f} GB") diff --git a/training/prepare_dataset.sh b/training/prepare_dataset.sh new file mode 100755 index 0000000000000000000000000000000000000000..304786d309834e54f39d290af6eba770b30cdc03 --- /dev/null +++ b/training/prepare_dataset.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +MODEL_ID="THUDM/CogVideoX-2b" + +NUM_GPUS=8 + +# For more details on the expected data format, please refer to the README. +DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located. +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset" +HEIGHT_BUCKETS="480 720" +WIDTH_BUCKETS="720 960" +FRAME_BUCKETS="49" +MAX_NUM_FRAMES="49" +MAX_SEQUENCE_LENGTH=226 +TARGET_FPS=8 +BATCH_SIZE=1 +DTYPE=fp32 + +# To create a folder-style dataset structure without pre-encoding videos and captions +# For Image-to-Video finetuning, make sure to pass `--save_image_latents` +CMD_WITHOUT_PRE_ENCODING="\ + torchrun --nproc_per_node=$NUM_GPUS \ + training/prepare_dataset.py \ + --model_id $MODEL_ID \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --output_dir $OUTPUT_DIR \ + --height_buckets $HEIGHT_BUCKETS \ + --width_buckets $WIDTH_BUCKETS \ + --frame_buckets $FRAME_BUCKETS \ + --max_num_frames $MAX_NUM_FRAMES \ + --max_sequence_length $MAX_SEQUENCE_LENGTH \ + --target_fps $TARGET_FPS \ + --batch_size $BATCH_SIZE \ + --dtype $DTYPE +" + +CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings" + +# Select which you'd like to run +CMD=$CMD_WITH_PRE_ENCODING + +echo "===== Running \`$CMD\` =====" +eval $CMD +echo -ne "===== Finished running script =====\n" diff --git a/training/train_image_to_video_lora.sh b/training/train_image_to_video_lora.sh new file mode 100755 index 0000000000000000000000000000000000000000..8ff0111a88018fba1fea322e3c91feb4fc23516b --- /dev/null +++ b/training/train_image_to_video_lora.sh @@ -0,0 +1,82 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4" "1e-3") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw" "adam") +MAX_TRAIN_STEPS=("3000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset +DATA_ROOT="/path/to/my/datasets/disney-dataset" +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" + +# Launch experiments with different hyperparameters +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_image_to_video_lora.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b-I2V \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_images \"/path/to/image1.png:::/path/to/image2.png\" + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --noised_image_dropout 0.05 \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/training/train_image_to_video_sft.sh b/training/train_image_to_video_sft.sh new file mode 100755 index 0000000000000000000000000000000000000000..7cdbf338908dea16e0099bad3b7c124feba50678 --- /dev/null +++ b/training/train_image_to_video_sft.sh @@ -0,0 +1,87 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +# export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export TOKENIZERS_PARALLELISM=true +export OMP_NUM_THREADS=16 + +GPU_IDS="0,1" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw") +MAX_TRAIN_STEPS=("20000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/deepspeed.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset +DATA_ROOT="/path/to/my/datasets/video-dataset-disney" +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +MODEL_PATH="THUDM/CogVideoX1.5-5B-I2V" + +# Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process. +# Launch experiments with different hyperparameters + +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="./cogvideox-sft__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE \ + --gpu_ids $GPU_IDS \ + training/cogvideox/cogvideox_image_to_video_sft.py \ + --pretrained_model_name_or_path $MODEL_PATH \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 77 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_images \"/path/to/image1.png:::/path/to/image2.png\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ + --seed 42 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 77 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 2000 \ + --gradient_accumulation_steps 4 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 800 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --noised_image_dropout 0.05 \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/training/train_text_to_video_lora.sh b/training/train_text_to_video_lora.sh new file mode 100755 index 0000000000000000000000000000000000000000..e7239f56242108023280ed9533e731297edf216d --- /dev/null +++ b/training/train_text_to_video_lora.sh @@ -0,0 +1,86 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4" "1e-3") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw" "adam") +MAX_TRAIN_STEPS=("3000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset +DATA_ROOT="/path/to/my/datasets/disney-dataset" + +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +MODEL_PATH="THUDM/CogVideoX-5b" + +# Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process. +# Launch experiments with different hyperparameters + +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="./cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \ + --pretrained_model_name_or_path $MODEL_PATH \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --enable_model_cpu_offload \ + --load_tensors \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/training/train_text_to_video_sft.sh b/training/train_text_to_video_sft.sh new file mode 100755 index 0000000000000000000000000000000000000000..b4de76caa4fa035959451f440e5757e33a88c9f6 --- /dev/null +++ b/training/train_text_to_video_sft.sh @@ -0,0 +1,77 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw") +MAX_TRAIN_STEPS=("20000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Tom-and-Jerry-VideoGeneration-Dataset --local-dir /path/to/my/datasets/tom-and-jerry-dataset +DATA_ROOT="/path/to/my/datasets/tom-and-jerry-dataset" +CAPTION_COLUMN="captions.txt" +VIDEO_COLUMN="videos.txt" + +# Launch experiments with different hyperparameters +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-sft__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_sft.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"Tom, the mischievous gray cat, is sprawled out on a vibrant red pillow, his body relaxed and his eyes half-closed, as if he's just woken up or is about to doze off. His white paws are stretched out in front of him, and his tail is casually draped over the edge of the pillow. The setting appears to be a cozy corner of a room, with a warm yellow wall in the background and a hint of a wooden floor. The scene captures a rare moment of tranquility for Tom, contrasting with his usual energetic and playful demeanor:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ + --seed 42 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 2000 \ + --gradient_accumulation_steps 4 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 800 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/training_log_parser.py b/training_log_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..c29042c2cf34fd47b1b285c432c24de7332b1d5b --- /dev/null +++ b/training_log_parser.py @@ -0,0 +1,183 @@ +import re +import logging +from dataclasses import dataclass +from typing import Optional, Dict, Any +from datetime import datetime, timedelta + +logger = logging.getLogger(__name__) + +@dataclass +class TrainingState: + """Represents the current state of training""" + status: str = "idle" # idle, initializing, training, completed, error, stopped + current_step: int = 0 + total_steps: int = 0 + current_epoch: int = 0 + total_epochs: int = 0 + step_loss: float = 0.0 + learning_rate: float = 0.0 + grad_norm: float = 0.0 + memory_allocated: float = 0.0 + memory_reserved: float = 0.0 + start_time: Optional[datetime] = None + last_step_time: Optional[datetime] = None + estimated_remaining: Optional[timedelta] = None + error_message: Optional[str] = None + initialization_stage: str = "" + download_progress: float = 0.0 + + def calculate_progress(self) -> float: + """Calculate overall progress as percentage""" + if self.total_steps == 0: + return 0.0 + return (self.current_step / self.total_steps) * 100 + + def to_dict(self) -> Dict[str, Any]: + """Convert state to dictionary for UI updates""" + elapsed = str(datetime.now() - self.start_time) if self.start_time else "0:00:00" + remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..." + + return { + "status": self.status, + "progress": f"{self.calculate_progress():.1f}%", + "current_step": self.current_step, + "total_steps": self.total_steps, + "current_epoch": self.current_epoch, + "total_epochs": self.total_epochs, + "step_loss": f"{self.step_loss:.4f}", + "learning_rate": f"{self.learning_rate:.2e}", + "grad_norm": f"{self.grad_norm:.4f}", + "memory": f"{self.memory_allocated:.1f}GB allocated, {self.memory_reserved:.1f}GB reserved", + "elapsed": elapsed, + "remaining": remaining, + "initialization_stage": self.initialization_stage, + "error_message": self.error_message, + "download_progress": self.download_progress + } + +class TrainingLogParser: + """Parser for training logs with state management""" + + def __init__(self): + self.state = TrainingState() + self._last_update_time = None + + def parse_line(self, line: str) -> Optional[Dict[str, Any]]: + """Parse a single log line and update state""" + try: + # For debugging + logger.info(f"Parsing line: {line[:100]}...") + + # Training step progress line example: + # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7] + if "Training steps:" in line: + # Set status to training if we see this + self.state.status = "training" + if not self.state.start_time: + self.state.start_time = datetime.now() + + # Extract step numbers + steps_match = re.search(r"(\d+)/(\d+)", line) + if steps_match: + self.state.current_step = int(steps_match.group(1)) + self.state.total_steps = int(steps_match.group(2)) + + # Extract metrics + for pattern, attr in [ + (r"step_loss=([0-9.e-]+)", "step_loss"), + (r"lr=([0-9.e-]+)", "learning_rate"), + (r"grad_norm=([0-9.e-]+)", "grad_norm") + ]: + match = re.search(pattern, line) + if match: + setattr(self.state, attr, float(match.group(1))) + + # Calculate time estimates based on total elapsed time + now = datetime.now() + if self.state.start_time and self.state.current_step > 0: + # Calculate elapsed time and average time per step + elapsed_seconds = (now - self.state.start_time).total_seconds() + avg_time_per_step = elapsed_seconds / self.state.current_step + + # Calculate remaining time + remaining_steps = self.state.total_steps - self.state.current_step + estimated_remaining_seconds = avg_time_per_step * remaining_steps + + # Format as days, hours, minutes, seconds + days = int(estimated_remaining_seconds // (24 * 3600)) + hours = int((estimated_remaining_seconds % (24 * 3600)) // 3600) + minutes = int((estimated_remaining_seconds % 3600) // 60) + seconds = int(estimated_remaining_seconds % 60) + + # Create formatted timedelta + if days > 0: + formatted_time = f"{days}d {hours}h {minutes}m {seconds}s" + elif hours > 0: + formatted_time = f"{hours}h {minutes}m {seconds}s" + elif minutes > 0: + formatted_time = f"{minutes}m {seconds}s" + else: + formatted_time = f"{seconds}s" + + self.state.estimated_remaining = formatted_time + self.state.last_step_time = now + + logger.info(f"Updated training state: step={self.state.current_step}/{self.state.total_steps}, loss={self.state.step_loss}") + return self.state.to_dict() + + # Epoch information + # there is an issue with how epoch is reported because we display: + # Progress: 96.9%, Step: 872/900, Epoch: 12/50 + # we should probably just show the steps + epoch_match = re.search(r"Starting epoch \((\d+)/(\d+)\)", line) + if epoch_match: + self.state.current_epoch = int(epoch_match.group(1)) + self.state.total_epochs = int(epoch_match.group(2)) + logger.info(f"Updated epoch: {self.state.current_epoch}/{self.state.total_epochs}") + return self.state.to_dict() + + # Initialization stages + if "Initializing" in line: + self.state.status = "initializing" + self.state.initialization_stage = line.split("Initializing")[1].strip() + logger.info(f"Initialization stage: {self.state.initialization_stage}") + return self.state.to_dict() + + # Memory usage + if "memory_allocated" in line: + mem_match = re.search(r'"memory_allocated":\s*([0-9.]+)', line) + if mem_match: + self.state.memory_allocated = float(mem_match.group(1)) + + reserved_match = re.search(r'"memory_reserved":\s*([0-9.]+)', line) + if reserved_match: + self.state.memory_reserved = float(reserved_match.group(1)) + logger.info(f"Updated memory: allocated={self.state.memory_allocated}GB, reserved={self.state.memory_reserved}GB") + return self.state.to_dict() + + # Completion states + if "Training completed successfully" in line: + self.state.status = "completed" + logger.info("Training completed") + return self.state.to_dict() + + if any(x in line for x in ["Training process stopped", "Training stopped"]): + self.state.status = "stopped" + logger.info("Training stopped") + return self.state.to_dict() + + if "Error during training:" in line: + self.state.status = "error" + self.state.error_message = line.split("Error during training:")[1].strip() + logger.info(f"Training error: {self.state.error_message}") + return self.state.to_dict() + + except Exception as e: + logger.error(f"Error parsing line: {str(e)}") + + return None + + def reset(self): + """Reset parser state""" + self.state = TrainingState() + self._last_update_time = None \ No newline at end of file diff --git a/training_service.py b/training_service.py new file mode 100644 index 0000000000000000000000000000000000000000..3aaea5322caa86df4bbcd8277662c99d31669fdd --- /dev/null +++ b/training_service.py @@ -0,0 +1,578 @@ +import os +import sys +import json +import time +import shutil +import gradio as gr +from pathlib import Path +from datetime import datetime +import subprocess +import signal +import psutil +import tempfile +import zipfile +import logging +import traceback +import threading +import select + +from typing import Any, Optional, Dict, List, Union, Tuple + +from huggingface_hub import upload_folder, create_repo +from config import TrainingConfig, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES +from utils import make_archive, parse_training_log, is_image_file, is_video_file +from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler('training_service.log') + ] +) +logger = logging.getLogger(__name__) + +class TrainingService: + def __init__(self): + # State and log files + self.session_file = OUTPUT_PATH / "session.json" + self.status_file = OUTPUT_PATH / "status.json" + self.pid_file = OUTPUT_PATH / "training.pid" + self.log_file = OUTPUT_PATH / "training.log" + logger.info("Training service initialized") + + def save_session(self, params: Dict) -> None: + """Save training session parameters""" + session_data = { + "timestamp": datetime.now().isoformat(), + "params": params, + "status": self.get_status() + } + with open(self.session_file, 'w') as f: + json.dump(session_data, f, indent=2) + + def load_session(self) -> Optional[Dict]: + """Load saved training session""" + if self.session_file.exists(): + try: + with open(self.session_file, 'r') as f: + return json.load(f) + except json.JSONDecodeError: + return None + return None + + def get_status(self) -> Dict: + """Get current training status""" + default_status = {'state': 'stopped', 'message': 'No training in progress'} + + if not self.status_file.exists(): + return default_status + + try: + with open(self.status_file, 'r') as f: + status = json.load(f) + + # Check if process is actually running + if self.pid_file.exists(): + with open(self.pid_file, 'r') as f: + pid = int(f.read().strip()) + if not psutil.pid_exists(pid): + # Process died unexpectedly + if status['state'] == 'running': + status['state'] = 'error' + status['message'] = 'Training process terminated unexpectedly' + self.append_log("Training process terminated unexpectedly") + else: + status['state'] = 'stopped' + status['message'] = 'Training process not found' + return status + + except (json.JSONDecodeError, ValueError): + return default_status + + def get_logs(self, max_lines: int = 100) -> str: + """Get training logs with line limit""" + if self.log_file.exists(): + with open(self.log_file, 'r') as f: + lines = f.readlines() + return ''.join(lines[-max_lines:]) + return "" + + def append_log(self, message: str) -> None: + """Append message to log file and logger""" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + with open(self.log_file, 'a') as f: + f.write(f"[{timestamp}] {message}\n") + logger.info(message) + + def clear_logs(self) -> None: + """Clear log file""" + if self.log_file.exists(): + self.log_file.unlink() + self.append_log("Log file cleared") + + def validate_training_config(self, config: TrainingConfig, model_type: str) -> Optional[str]: + """Validate training configuration""" + logger.info(f"Validating config for {model_type}") + + try: + # Basic validation + if not config.data_root or not Path(config.data_root).exists(): + return f"Invalid data root path: {config.data_root}" + + if not config.output_dir: + return "Output directory not specified" + + # Check for required files + videos_file = Path(config.data_root) / "videos.txt" + prompts_file = Path(config.data_root) / "prompts.txt" + + if not videos_file.exists(): + return f"Missing videos list file: {videos_file}" + if not prompts_file.exists(): + return f"Missing prompts list file: {prompts_file}" + + # Validate file counts match + video_lines = [l.strip() for l in open(videos_file) if l.strip()] + prompt_lines = [l.strip() for l in open(prompts_file) if l.strip()] + + if not video_lines: + return "No training files found" + if len(video_lines) != len(prompt_lines): + return f"Mismatch between video count ({len(video_lines)}) and prompt count ({len(prompt_lines)})" + + # Model-specific validation + if model_type == "hunyuan_video": + if config.batch_size > 2: + return "Hunyuan model recommended batch size is 1-2" + if not config.gradient_checkpointing: + return "Gradient checkpointing is required for Hunyuan model" + elif model_type == "ltx_video": + if config.batch_size > 4: + return "LTX model recommended batch size is 1-4" + + logger.info(f"Config validation passed with {len(video_lines)} training files") + return None + + except Exception as e: + logger.error(f"Error during config validation: {str(e)}") + return f"Configuration validation failed: {str(e)}" + + + def start_training(self, model_type: str, lora_rank: str, lora_alpha: str, num_epochs: int, batch_size: int, + learning_rate: float, save_iterations: int, repo_id: str) -> Tuple[str, str]: + """Start training with finetrainers""" + + self.clear_logs() + + if not model_type: + raise ValueError("model_type cannot be empty") + if model_type not in MODEL_TYPES.values(): + raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}") + + + logger.info(f"Initializing training with model_type={model_type}") + + try: + # Get absolute paths + current_dir = Path(__file__).parent.absolute() + train_script = current_dir / "train.py" + + if not train_script.exists(): + error_msg = f"Training script not found at {train_script}" + logger.error(error_msg) + return error_msg, "Training script not found" + + # Log paths for debugging + logger.info("Current working directory: %s", current_dir) + logger.info("Training script path: %s", train_script) + logger.info("Training data path: %s", TRAINING_PATH) + + videos_file, prompts_file = prepare_finetrainers_dataset() + if videos_file is None or prompts_file is None: + error_msg = "Failed to generate training lists" + logger.error(error_msg) + return error_msg, "Training preparation failed" + + video_count = sum(1 for _ in open(videos_file)) + logger.info(f"Generated training lists with {video_count} files") + + if video_count == 0: + error_msg = "No training files found" + logger.error(error_msg) + return error_msg, "No training data available" + + # Get config for selected model type + if model_type == "hunyuan_video": + config = TrainingConfig.hunyuan_video_lora( + data_path=str(TRAINING_PATH), + output_path=str(OUTPUT_PATH) + ) + else: # ltx_video + config = TrainingConfig.ltx_video_lora( + data_path=str(TRAINING_PATH), + output_path=str(OUTPUT_PATH) + ) + + # Update with UI parameters + config.train_epochs = int(num_epochs) + config.lora_rank = int(lora_rank) + config.lora_alpha = int(lora_alpha) + config.batch_size = int(batch_size) + config.lr = float(learning_rate) + config.checkpointing_steps = int(save_iterations) + + # Common settings for both models + config.mixed_precision = "bf16" + config.seed = 42 + config.gradient_checkpointing = True + config.enable_slicing = True + config.enable_tiling = True + config.caption_dropout_p = 0.05 + + validation_error = self.validate_training_config(config, model_type) + if validation_error: + error_msg = f"Configuration validation failed: {validation_error}" + logger.error(error_msg) + return "Error: Invalid configuration", error_msg + + # Configure accelerate parameters + accelerate_args = [ + "accelerate", "launch", + "--mixed_precision=bf16", + "--num_processes=1", + "--num_machines=1", + "--dynamo_backend=no" + ] + + accelerate_args.append(str(train_script)) + + # Convert config to command line arguments + config_args = config.to_args_list() + + + logger.debug("Generated args list: %s", config_args) + + # Log the full command for debugging + command_str = ' '.join(accelerate_args + config_args) + self.append_log(f"Command: {command_str}") + logger.info(f"Executing command: {command_str}") + + # Set environment variables + env = os.environ.copy() + env["NCCL_P2P_DISABLE"] = "1" + env["TORCH_NCCL_ENABLE_MONITORING"] = "0" + env["WANDB_MODE"] = "offline" + env["HF_API_TOKEN"] = HF_API_TOKEN + env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging + + # Start the training process + process = subprocess.Popen( + accelerate_args + config_args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + start_new_session=True, + env=env, + cwd=str(current_dir), + bufsize=1, + universal_newlines=True + ) + + logger.info(f"Started process with PID: {process.pid}") + + with open(self.pid_file, 'w') as f: + f.write(str(process.pid)) + + # Save session info including repo_id for later hub upload + self.save_session({ + "model_type": model_type, + "lora_rank": lora_rank, + "lora_alpha": lora_alpha, + "num_epochs": num_epochs, + "batch_size": batch_size, + "learning_rate": learning_rate, + "save_iterations": save_iterations, + "repo_id": repo_id, + "start_time": datetime.now().isoformat() + }) + + # Update initial training status + total_steps = num_epochs * (max(1, video_count) // batch_size) + self.save_status( + state='running', + epoch=0, + step=0, + total_steps=total_steps, + loss=0.0, + total_epochs=num_epochs, + message='Training started', + repo_id=repo_id, + model_type=model_type + ) + + # Start monitoring process output + self._start_log_monitor(process) + + success_msg = f"Started training {model_type} model" + self.append_log(success_msg) + logger.info(success_msg) + + return success_msg, self.get_logs() + + except Exception as e: + error_msg = f"Error starting training: {str(e)}" + self.append_log(error_msg) + logger.exception("Training startup failed") + traceback.print_exc() # Added for better error debugging + return "Error starting training", error_msg + + + def stop_training(self) -> Tuple[str, str]: + """Stop training process""" + if not self.pid_file.exists(): + return "No training process found", self.get_logs() + + try: + with open(self.pid_file, 'r') as f: + pid = int(f.read().strip()) + + if psutil.pid_exists(pid): + os.killpg(os.getpgid(pid), signal.SIGTERM) + + if self.pid_file.exists(): + self.pid_file.unlink() + + self.append_log("Training process stopped") + self.save_status(state='stopped', message='Training stopped') + + return "Training stopped successfully", self.get_logs() + + except Exception as e: + error_msg = f"Error stopping training: {str(e)}" + self.append_log(error_msg) + if self.pid_file.exists(): + self.pid_file.unlink() + return "Error stopping training", error_msg + + def pause_training(self) -> Tuple[str, str]: + """Pause training process by sending SIGUSR1""" + if not self.is_training_running(): + return "No training process found", self.get_logs() + + try: + with open(self.pid_file, 'r') as f: + pid = int(f.read().strip()) + + if psutil.pid_exists(pid): + os.kill(pid, signal.SIGUSR1) # Signal to pause + self.save_status(state='paused', message='Training paused') + self.append_log("Training paused") + + return "Training paused", self.get_logs() + + except Exception as e: + error_msg = f"Error pausing training: {str(e)}" + self.append_log(error_msg) + return "Error pausing training", error_msg + + def resume_training(self) -> Tuple[str, str]: + """Resume training process by sending SIGUSR2""" + if not self.is_training_running(): + return "No training process found", self.get_logs() + + try: + with open(self.pid_file, 'r') as f: + pid = int(f.read().strip()) + + if psutil.pid_exists(pid): + os.kill(pid, signal.SIGUSR2) # Signal to resume + self.save_status(state='running', message='Training resumed') + self.append_log("Training resumed") + + return "Training resumed", self.get_logs() + + except Exception as e: + error_msg = f"Error resuming training: {str(e)}" + self.append_log(error_msg) + return "Error resuming training", error_msg + + def is_training_running(self) -> bool: + """Check if training is currently running""" + if not self.pid_file.exists(): + return False + + try: + with open(self.pid_file, 'r') as f: + pid = int(f.read().strip()) + return psutil.pid_exists(pid) + except: + return False + + def clear_training_data(self) -> str: + """Clear all training data""" + if self.is_training_running(): + return gr.Error("Cannot clear data while training is running") + + try: + for file in TRAINING_VIDEOS_PATH.glob("*.*"): + file.unlink() + for file in TRAINING_PATH.glob("*.*"): + file.unlink() + + self.append_log("Cleared all training data") + return "Training data cleared successfully" + + except Exception as e: + error_msg = f"Error clearing training data: {str(e)}" + self.append_log(error_msg) + return error_msg + + def save_status(self, state: str, **kwargs) -> None: + """Save current training status""" + status = { + 'state': state, + 'timestamp': datetime.now().isoformat(), + **kwargs + } + with open(self.status_file, 'w') as f: + json.dump(status, f, indent=2) + + def _start_log_monitor(self, process: subprocess.Popen) -> None: + """Start monitoring process output for logs""" + + + def monitor(): + self.append_log("Starting log monitor thread") + + def read_stream(stream, is_error=False): + if stream: + output = stream.readline() + if output: + # Remove decode() since output is already a string due to universal_newlines=True + line = output.strip() + if is_error: + #self.append_log(f"ERROR: {line}") + #logger.error(line) + #logger.info(line) + self.append_log(line) + else: + self.append_log(line) + # Parse metrics only from stdout + metrics = parse_training_log(line) + if metrics: + status = self.get_status() + status.update(metrics) + self.save_status(**status) + return True + return False + + # Use select to monitor both stdout and stderr + while process.poll() is None: + outputs = [process.stdout, process.stderr] + readable, _, _ = select.select(outputs, [], [], 1.0) + + for stream in readable: + is_error = (stream == process.stderr) + read_stream(stream, is_error) + + # Process any remaining output after process ends + while read_stream(process.stdout): + pass + while read_stream(process.stderr, True): + pass + + # Process finished + return_code = process.poll() + if return_code == 0: + success_msg = "Training completed successfully" + self.append_log(success_msg) + gr.Info(success_msg) + self.save_status(state='completed', message=success_msg) + + # Upload final model if repository was specified + session = self.load_session() + if session and session['params'].get('repo_id'): + repo_id = session['params']['repo_id'] + latest_run = max(Path(OUTPUT_PATH).glob('*'), key=os.path.getmtime) + if self.upload_to_hub(latest_run, repo_id): + self.append_log(f"Model uploaded to {repo_id}") + else: + self.append_log("Failed to upload model to hub") + else: + error_msg = f"Training failed with return code {return_code}" + self.append_log(error_msg) + logger.error(error_msg) + self.save_status(state='error', message=error_msg) + + # Clean up PID file + if self.pid_file.exists(): + self.pid_file.unlink() + + monitor_thread = threading.Thread(target=monitor) + monitor_thread.daemon = True + monitor_thread.start() + + def upload_to_hub(self, model_path: Path, repo_id: str) -> bool: + """Upload model to Hugging Face Hub + + Args: + model_path: Path to model files + repo_id: Repository ID (username/model-name) + + Returns: + bool: Whether upload was successful + """ + try: + token = os.getenv("HF_API_TOKEN") + if not token: + self.append_log("Error: HF_API_TOKEN not set") + return False + + # Create or get repo + create_repo(repo_id, token=token, repo_type="model", exist_ok=True) + + # Upload files + upload_folder( + folder_path=str(OUTPUT_PATH), + repo_id=repo_id, + repo_type="model", + commit_message="Training completed" + ) + + return True + except Exception as e: + self.append_log(f"Error uploading to hub: {str(e)}") + return False + + def get_model_output_safetensors(self) -> str: + """Return the path to the model safetensors + + + Returns: + Path to created ZIP file + """ + + model_output_safetensors_path = OUTPUT_PATH / "pytorch_lora_weights.safetensors" + return str(model_output_safetensors_path) + + def create_training_dataset_zip(self) -> str: + """Create a ZIP file containing all training data + + + Returns: + Path to created ZIP file + """ + # Create temporary zip file + with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip: + temp_zip_path = str(temp_zip.name) + print(f"Creating zip file for {TRAINING_PATH}..") + try: + make_archive(TRAINING_PATH, temp_zip_path) + print(f"Zip file created!") + return temp_zip_path + except Exception as e: + print(f"Failed to create zip: {str(e)}") + raise gr.Error(f"Failed to create zip: {str(e)}") + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94b644e4eb2a2addc0e71d415279ec84716ecf8e --- /dev/null +++ b/utils.py @@ -0,0 +1,270 @@ +import os +import shutil +from huggingface_hub import HfApi, create_repo +from pathlib import Path +import json +import re +from typing import Any, Optional, Dict, List, Union, Tuple + +def make_archive(source: str | Path, destination: str | Path): + source = str(source) + destination = str(destination) + #print(f"make_archive({source}, {destination})") + base = os.path.basename(destination) + name = base.split('.')[0] + format = base.split('.')[1] + archive_from = os.path.dirname(source) + archive_to = os.path.basename(source.strip(os.sep)) + shutil.make_archive(name, format, archive_from, archive_to) + shutil.move('%s.%s'%(name,format), destination) + +def extract_scene_info(filename: str) -> Tuple[str, Optional[int]]: + """Extract base name and scene number from filename + + Args: + filename: Input filename like "my_cool_video_1___001.mp4" + + Returns: + Tuple of (base_name, scene_number) + e.g. ("my_cool_video_1", 1) + """ + # Match numbers at the end of the filename before extension + match = re.search(r'(.+?)___(\d+)$', Path(filename).stem) + if match: + return match.group(1), int(match.group(2)) + return Path(filename).stem, None + +def is_image_file(file_path: Path) -> bool: + """Check if file is an image based on extension + + Args: + file_path: Path to check + + Returns: + bool: True if file has image extension + """ + image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic'} + return file_path.suffix.lower() in image_extensions + +def is_video_file(file_path: Path) -> bool: + """Check if file is a video based on extension + + Args: + file_path: Path to check + + Returns: + bool: True if file has video extension + """ + video_extensions = {'.mp4', '.webm'} + return file_path.suffix.lower() in video_extensions + +def parse_bool_env(env_value: Optional[str]) -> bool: + """Parse environment variable string to boolean + + Handles various true/false string representations: + - True: "true", "True", "TRUE", "1", etc + - False: "false", "False", "FALSE", "0", "", None + """ + if not env_value: + return False + return str(env_value).lower() in ('true', '1', 't', 'y', 'yes') + +def validate_model_repo(repo_id: str) -> Dict[str, str]: + """Validate HuggingFace model repository name + + Args: + repo_id: Repository ID in format "username/model-name" + + Returns: + Dict with error message if invalid, or None if valid + """ + if not repo_id: + return {"error": "Repository name is required"} + + if "/" not in repo_id: + return {"error": "Repository name must be in format username/model-name"} + + # Check characters + invalid_chars = set('<>:"/\\|?*') + if any(c in repo_id for c in invalid_chars): + return {"error": "Repository name contains invalid characters"} + + return {"error": None} + +def save_to_hub(model_path: Path, repo_id: str, token: str, commit_message: str = "Update model") -> bool: + """Save model files to Hugging Face Hub + + Args: + model_path: Path to model files + repo_id: Repository ID (username/model-name) + token: HuggingFace API token + commit_message: Optional commit message + + Returns: + bool: True if successful, False if failed + """ + try: + api = HfApi(token=token) + + # Validate repo_id + validation = validate_model_repo(repo_id) + if validation["error"]: + return False + + # Create or get repo + try: + create_repo(repo_id, token=token, repo_type="model", exist_ok=True) + except Exception as e: + print(f"Error creating repo: {e}") + return False + + # Upload all files + api.upload_folder( + folder_path=str(model_path), + repo_id=repo_id, + repo_type="model", + commit_message=commit_message + ) + + return True + except Exception as e: + print(f"Error uploading to hub: {e}") + return False + +def parse_training_log(line: str) -> Dict: + """Parse a training log line for metrics + + Args: + line: Log line from training output + + Returns: + Dict with parsed metrics (epoch, step, loss, etc) + """ + metrics = {} + + try: + # Extract step/epoch info + if "step=" in line: + step = int(line.split("step=")[1].split()[0].strip(",")) + metrics["step"] = step + + if "epoch=" in line: + epoch = int(line.split("epoch=")[1].split()[0].strip(",")) + metrics["epoch"] = epoch + + if "loss=" in line: + loss = float(line.split("loss=")[1].split()[0].strip(",")) + metrics["loss"] = loss + + if "lr=" in line: + lr = float(line.split("lr=")[1].split()[0].strip(",")) + metrics["learning_rate"] = lr + except: + pass + + return metrics + +def format_size(size_bytes: int) -> str: + """Format bytes into human readable string with appropriate unit + + Args: + size_bytes: Size in bytes + + Returns: + Formatted string (e.g. "1.5 Gb") + """ + units = ['bytes', 'Kb', 'Mb', 'Gb', 'Tb'] + unit_index = 0 + size = float(size_bytes) + + while size >= 1024 and unit_index < len(units) - 1: + size /= 1024 + unit_index += 1 + + # Special case for bytes - no decimal places + if unit_index == 0: + return f"{int(size)} {units[unit_index]}" + + return f"{size:.1f} {units[unit_index]}" + + +def count_media_files(path: Path) -> Tuple[int, int, int]: + """Count videos and images in directory + + Args: + path: Directory to scan + + Returns: + Tuple of (video_count, image_count, total_size) + """ + video_count = 0 + image_count = 0 + total_size = 0 + + for file in path.glob("*"): + # Skip hidden files and caption files + if file.name.startswith('.') or file.suffix.lower() == '.txt': + continue + + if is_video_file(file): + video_count += 1 + total_size += file.stat().st_size + elif is_image_file(file): + image_count += 1 + total_size += file.stat().st_size + + return video_count, image_count, total_size + +def format_media_title(action: str, video_count: int, image_count: int, total_size: int) -> str: + """Format title with media counts and size + + Args: + action: Action (eg "split", "caption") + video_count: Number of videos + image_count: Number of images + total_size: Total size in bytes + + Returns: + Formatted title string + """ + parts = [] + if image_count > 0: + parts.append(f"{image_count:,} photo{'s' if image_count != 1 else ''}") + if video_count > 0: + parts.append(f"{video_count:,} video{'s' if video_count != 1 else ''}") + + if not parts: + return f"## 0 files to {action} (0 bytes)" + + return f"## {' and '.join(parts)} to {action} ({format_size(total_size)})" + +def add_prefix_to_caption(caption: str, prefix: str) -> str: + """Add prefix to caption if not already present""" + if not prefix or not caption: + return caption + if caption.startswith(prefix): + return caption + return f"{prefix}{caption}" + +def format_time(seconds: float) -> str: + """Format time duration in seconds to human readable string + + Args: + seconds: Time in seconds + + Returns: + Formatted string (e.g. "2h 30m 45s") + """ + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + parts = [] + if hours > 0: + parts.append(f"{hours}h") + if minutes > 0: + parts.append(f"{minutes}m") + if secs > 0 or not parts: + parts.append(f"{secs}s") + + return " ".join(parts) \ No newline at end of file diff --git a/video_preprocessing.py b/video_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..37fbd9d2e8d26523416670fd293ea7c046b549c6 --- /dev/null +++ b/video_preprocessing.py @@ -0,0 +1,132 @@ +import cv2 +import numpy as np +from pathlib import Path +import subprocess + +def detect_black_bars(video_path: Path) -> tuple[int, int, int, int]: + """Detect black bars in video by analyzing first few frames + + Args: + video_path: Path to video file + + Returns: + Tuple of (top, bottom, left, right) crop values + """ + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise ValueError(f"Could not open video: {video_path}") + + # Read first few frames to get stable detection + frames_to_check = 5 + frames = [] + + for _ in range(frames_to_check): + ret, frame = cap.read() + if not ret: + break + frames.append(frame) + + cap.release() + + if not frames: + raise ValueError("Could not read any frames from video") + + # Convert frames to grayscale and find average + gray_frames = [cv2.cvtColor(f, cv2.COLOR_BGR2GRAY) for f in frames] + avg_frame = np.mean(gray_frames, axis=0) + + # Threshold to detect black regions (adjust sensitivity if needed) + threshold = 20 + black_mask = avg_frame < threshold + + # Find black bars by analyzing row/column means + row_means = np.mean(black_mask, axis=1) + col_means = np.mean(black_mask, axis=0) + + # Detect edges where black bars end (using high threshold to avoid false positives) + black_threshold = 0.95 # 95% of pixels in row/col must be black + + # Find top and bottom crops + top_crop = 0 + bottom_crop = black_mask.shape[0] + + for i, mean in enumerate(row_means): + if mean > black_threshold: + top_crop = i + 1 + else: + break + + for i, mean in enumerate(reversed(row_means)): + if mean > black_threshold: + bottom_crop = black_mask.shape[0] - i - 1 + else: + break + + # Find left and right crops + left_crop = 0 + right_crop = black_mask.shape[1] + + for i, mean in enumerate(col_means): + if mean > black_threshold: + left_crop = i + 1 + else: + break + + for i, mean in enumerate(reversed(col_means)): + if mean > black_threshold: + right_crop = black_mask.shape[1] - i - 1 + else: + break + + return top_crop, bottom_crop, left_crop, right_crop + +def remove_black_bars(input_path: Path, output_path: Path) -> bool: + """Remove black bars from video using FFmpeg + + Args: + input_path: Path to input video + output_path: Path to save processed video + + Returns: + bool: True if successful, False if no cropping needed + """ + try: + # Detect black bars + top, bottom, left, right = detect_black_bars(input_path) + + # Get video dimensions using OpenCV + cap = cv2.VideoCapture(str(input_path)) + if not cap.isOpened(): + raise ValueError(f"Could not open video: {input_path}") + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + # If no significant black bars detected, return False + if top < 10 and bottom > height - 10 and \ + left < 10 and right > width - 10: + return False + + # Calculate crop dimensions + crop_height = bottom - top + crop_width = right - left + + if crop_height <= 0 or crop_width <= 0: + return False + + # Use FFmpeg to crop and save video + cmd = [ + 'ffmpeg', '-i', str(input_path), + '-vf', f'crop={crop_width}:{crop_height}:{left}:{top}', + '-c:a', 'copy', # Copy audio stream + '-y', # Overwrite output + str(output_path) + ] + + subprocess.run(cmd, check=True, capture_output=True) + return True + + except Exception as e: + print(f"Error removing black bars: {e}") + return False \ No newline at end of file